用鸢尾花数据集以及原生python实现knn分类算法。
- 作业题目
用鸢尾花数据集以及原生python实现knn分类算法。
- 算法设计
KNN算法中使用的是欧氏距离。
二维空间两点欧氏距离计算公式:
多维空间欧氏距离计算公式:
- 源代码
"""
@文件:鸢尾花识别
@说明:鸢尾花识别
@作者:王佳磊
@时间:2019.10.1
"""
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
class KNN:
# 设置近邻数
def __init__(self, k):
self.n_neighbors = k
# 训练
def fit(self, x_train, y_train):
# X是NXD的数组,其中每一行代表一个样本,Y是N行的一维数组,对应X的标签
# 最近邻分类器就是简单的记住所有的数据
self.xtr = x_train
self.ytr = y_train
# 预测
def predict(self, x_test):
num_test = x_test.shape[0]
# 确认输出的结果类型符合输入的类型
ypred = np.zeros(num_test, dtype=self.ytr.dtype)
# 循环每一行,也就是每一个样本
for i in range(num_test):
# 计算待分类数据与训练集各数据点的距离(计算欧式距离)
distances = np.sqrt(np.sum(np.square(self.xtr - x_test[i, :]), axis=1))
min_index = np.argmin(distances) # 拿到最小那个距离的索引
ypred[i] = self.ytr[min_index] # 预测样本的标签,其实就是跟他最近的训练数据样本的标签
return ypred
if __name__ == '__main__':
# 获取鸢尾花数据
iris_dataset = pd.read_csv("iris.data.csv")
# 载入特征和标签集
X = iris_dataset[['E_Length', 'E_Width', 'B_Length', 'B_Width']]
y = iris_dataset['Species']
# 划分数据集
X_train, X_test, Y_train, Y_test = train_test_split(X, y, test_size=0.3, random_state=0)
# 创建KNN实例,k为1
knn = KNN(1)
# 训练数据
knn.fit(X_train, Y_train)
# 预测
e_length = float(input("请输入花萼长:"))
e_width = float(input("请输入花萼宽:"))
b_length = float(input("请输入花瓣长:"))
b_width = float(input("请输入花瓣宽:"))
X_new = np.array([[e_length, e_width, b_length, b_width]])
print("X_new.shape:{}".format(X_new.shape))
prediction = knn.predict(X_new)
print("Prediction:{}".format(prediction))
if prediction == 0:
print("Prediction target name:setosa")
elif prediction == 1:
print("Prediction target name:versicolour")
elif prediction == 2:
print("Prediction target name:virginica")
- 调试、测试及运行结果
调试:
数据集划分
设置k值,初始化近邻数
模型训练
模型预测
测试及运行结果:
1.测试样本信息:花萼长:5cm宽2.9cm,
花瓣长:1cm宽0.2cm.
结果如下:
2.测试样本信息:花萼长:6.2cm宽3.4cm,
花瓣长:5.4cm宽2.3cm.
结果如下:
- 总结
1.knn原理的简单理解
当要判断一个新的值的类别时,根据他最近的k个点的类别进行判断。即“近朱者赤、近墨者黑”。
2.本次实验遇到了的问题:
在模块导入的情况下出现此错误:ImportError: DLL load failed: 找不到指定的模块。
原因:被使用模块以及被使用模块的依赖模块未从同一地方下载,导致下载文件的文件后缀不同,python在此种情况下易出现上述异常。将程序中使用到的模块全部卸载并统一来源重新安装,即可通过。
3.未解决的问题:
参考链接:https://www.cnblogs.com/douzujun/p/9035923.html