#使用SepalLengthCm、SepalWidthCm、PetalLengthCm预测PetalWidthCm
import numpy as np
import pandas as pd
data = pd.read_csv("iris.csv")
#去掉不需要的id列和Species列
data.drop(['id','Species'],axis=1,inplace=True)
#去掉重复数据
data.drop_duplicates(inplace=True)
display(data)
class KNN:
"""使用python实现K近邻算法(回归预测)
该算法用于回归预测,根据前3个特征属性,寻找最近的k个邻居,然后再根据k各邻居的第四个特征属性,
去预测当前样本的第四个特征
"""
def __init__(self,k):
"""初始化方法
Parameters
----
k:int
邻居的个数
"""
self.k = k
def fit(self,X,y):
"""训练方法
Parameters
----
X:类似数组类型(特征矩阵),形状为[样本数量,特征数量]
待训练的样本特征(属性)
y:类似数组类型(目标标签),形状为[样本数量]
每个样本的目标值(标签)
"""
#将X与y转换成ndarray数组的形式,方便统一进行操作
self.X = np.asarray(X)
self.y = np.asarray(y)
def predict(self,X):
"""根据参数传递的X,对样本进行预测
Parameters:
----
X:类似数组的类型,形状为[样本数量,特征数量]
带预测样本的特征(属性)
Return:
----
result:数组类型
预测的结果
"""
#将X转换成数组类型
X = np.asarray(X)
#保存预测的结果值
result = []
for x in X:
#计算距离(计算与训练集中每个X的距离)
dis = np.sqrt(np.sum((x - self.X) ** 2,axis = 1))
#数组排序后,每个元素在原数组(排序之前的数组)中的索引
index = dis.argsort()
#取前k个距离最近的索引(在原数组中的索引)
index = index[:self.k]
#计算均值,加入到返回的记过列表中
result.append(np.mean(self.y[index]))
return np.asarray(result)
def predict2(self,X):
"""根据参数传递的X,对样本进行预测,(考虑权重)
权重的计算方式:每个节点(邻居)距离的倒数 / 所有节点距离倒数之和
Parameters:
----
X:类似数组的类型,形状为[样本数量,特征数量]
带预测样本的特征(属性)
Return:
----
result:数组类型
预测的结果
"""
#将X转换成数组类型
X = np.asarray(X)
#保存预测的结果值
result = []
for x in X:
#计算距离(计算与训练集中每个X的距离)
dis = np.sqrt(np.sum((x - self.X) ** 2,axis = 1))
#数组排序后,每个元素在原数组(排序之前的数组)中的索引
index = dis.argsort()
#取前k个距离最近的索引(在原数组中的索引)
index = index[:self.k]
#求所有邻居节点距离的倒数之和。注意最后加上一个很小的值,是为了避免除数为0的情况
s = np.sum(1 / (dis[index] + 0.0001))
weight = (1 / (dis[index] + 0.0001)) / s
#计算均值,加入到返回的记过列表中
result.append(np.sum(self.y[index] * weight))
return np.asarray(result)
#测试KNN预测情况
t = data.sample(len(data),random_state=0)
train_X = t.iloc[:120,:-1]
train_y = t.iloc[:120,-1]
test_X = t.iloc[120:,:-1]
test_y = t.iloc[120:,-1]
knn = KNN(3)
knn.fit(train_X,train_y)
result = knn.predict(test_X)
#衡量回归准确性
m = np.mean((result - test_y) **2)
display(m)
display(result)
display(test_y.values)
#可视化展示
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rcParams['font.family']="simHei"
mpl.rcParams['axes.unicode_minus'] = False
plt.figure(figsize=(10,10))
plt.plot(result,"ro-",label="预测值")
plt.plot(test_y.values,'go--',label="真实值")
plt.title("KNN连续值预测")
plt.xlabel("节点序号")
plt.ylabel("花瓣宽度")
plt.legend()
plt.show()