鸢尾花分类之网格搜索与交叉验证
小编有话说:其实网格搜索与交叉验证已经被封装到方法中,只需要我们在合适的场景中去运用即可,同时原理我们要搞清楚,这样会使我们能够更加灵活的使用。
#导入鸢尾花数据 from sklearn.datasets import load_iris #导入划分数据集方法 from sklearn.model_selection import train_test_split #导入标准化方法 from sklearn.preprocessing import StandardScaler #导入knn算法模块 from sklearn.neighbors import KNeighborsClassifier #导入网格搜索模块 from sklearn.model_selection import GridSearchCV def knn(): """ 通过knn对鸢尾花进行分类 :return: None """ # 评1,获取数据 iris = load_iris() # 2,数据集划分 x_train,x_test,y_train,y_test = train_test_split(iris.data,iris.target,random_state=6) # 3,特征工程:标准化 tansfer = StandardScaler() x_train = tansfer.fit_transform(x_train) x_test = tansfer.transform(x_test) # 4,knn预测流程 k = KNeighborsClassifier() # k.fit(x_train,y_train) # # 5,模型估 # #方法一 直接比对真实值与预测值 # predict = k.predict(x_test) # print("predict: \n" , predict) # print("比对真实值与预测值:\n", y_test == predict) # #方法二 计算准确率 # score = k.score(x_test,y_test) # print(score) #网格搜索 #构造一些k值的参数 param = {"n_neighbors" : [3,5,10,7]} #实例化 ss = GridSearchCV(k,param_grid=param,cv=2) ss.fit(x_train,y_train) #预测准确率 print("在测试集上的准确率:", ss.score(x_test,y_test)) print("在交叉验证中最好的结果:", ss.best_score_) print("选择的最好的模型:", ss.best_estimator_) print("每个超参数每次交叉验证的结果:", ss.cv_results_) return None if __name__ == "__main__": knn()
实验结果: