K近邻算法3 识别手写数字 from机器学习实战
补档最后一个例子
识别手写数字
前面KNN算法已经写好了,这个栗子也没有用到归一化处理
我们的数据集、测试集是3232的文本文档,用0/1表示
文件名0-1.txt代表数字0的第一个数据,因此可以通过字符串切片来获得label标签
处理有两部分:
1.将图像格式化处理为一个向量
把3232的向量变成1*1024
'''
来解决手写识别系统部分
把32*32处理成1*1024的向量
'''
def getfile2(filename):
returnVect = np.zeros((1,1024))
fr=open(filename)
for i in range(32):
linestr = fr.readline()
for j in range(32):
returnVect[0,i*32+j]=int(linestr[j])
return returnVect
2.将大批量的文件导入,形成np.matrix类型的训练集和测试集
这里用到了os模块中的listdir函数,作用是获得该目录下的所有的文件名
from os import listdir
总体的代码如下
import numpy as np
import matplotlib.pyplot as plt
import operator
from os import listdir
'''
K近邻算法实现
inX:用于分类的向量
dataSet数据集
labels标签
k 就是取前K个
'''
def calssify0(inX, dataSet, labels, k):
datasetsize = dataSet.shape[0]
diffmat = np.tile(inX, (datasetsize, 1)) - dataSet
sqdiffmat = diffmat ** 2
# 每一行求和,得到的是每一行的和(也就是距离平方)
sqdistance = sqdiffmat.sum(axis=1)
distance = sqdistance ** 0.5
# 上述过程计算出了距离
# 排序
# 建立一个空字典
sortedDistanceIndices = distance.argsort()
classCount = {}
for i in range(k):
voteIlabel = labels[sortedDistanceIndices[i]]
# get(key[, default])方法,返回序号(否则返回0)这样,就建立了最近的排名的一个映射
classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
'''
来解决手写识别系统部分
把32*32处理成1*1024的向量
'''
def getfile2(filename):
returnVect = np.zeros((1,1024))
fr=open(filename)
for i in range(32):
linestr = fr.readline()
for j in range(32):
returnVect[0,i*32+j]=int(linestr[j])
return returnVect
'''
手写分类的测试
这里要做的额外工作是
读取每个文件,从文件名中提取出labels来
把每组向量合并在一起成为矩阵,
这样获得了测试集和训练集
'''
def handwritingclasstest():
#空列表记录标签
hwlabel=[]
trainingfilelist= listdir("D:\\pyt_example\\01\\venv\\digits\\trainingDigits")
m =len(trainingfilelist)
trainMat =np.zeros((m,1024))
for i in range(m):
filename= trainingfilelist[i]
filestr = filename.split('.')[0]
labelnumber =int(filename.split('_')[0])
hwlabel.append(labelnumber)
trainMat[i,:]=getfile2("digits\\trainingDigits\\%s" % filename )
#处理测试集
testfilelist =listdir("D:\\pyt_example\\01\\venv\\digits\\testDigits")
errorcount=0.0 #计算错误率
mtest =len(testfilelist)
for i in range(mtest):
testname =testfilelist[i]
teststr =testname.split('.')[0]
testlabel =int(testname.split('_')[0])
testMat =getfile2("digits\\testDigits\\%s" % testname)
classifyresult=calssify0(testMat, trainMat, hwlabel, 3)
print("case %d: the classifier came back with :%d,the real answer is %d" \
% (i+1,classifyresult,testlabel))
if(classifyresult!=testlabel):
errorcount+=1
print("error rate = %.2f%%" % (errorcount*100/float(mtest)))
运行效果如下图:
可以看到,利用KNN算法,我们得到了错误率1.06%的识别手写数字的算法
改进思路:
1.导入文件很慢
2.每个距离运算有1024个维度的浮点计算,每个测试用例需要算2000次,需要执行900次,运算量大