使用kNN算法对魔方颜色进行分类
1. 数据处理
使用opencv中的方法读取魔方6个面的图片数据,对图像数据进行分割得到魔方6个面的像素数据。由于图像有噪点,故使用高斯平滑对图像进行平滑去噪得到魔方图像数据。
图一 原始图片
图二 处理后图片
2. 数据获取
对魔方每个面中心点尽心随机采样得到训练数据,对其他各个色块进行随机采样作为需预测数据。
3. KNN算法
所谓K近邻算法,即是给定一个训练数据集,对新的输入实例,在训练数据集中找到与该实例最邻近的K个实例(也就是上面所说的K个邻居),这K个实例的多数属于某个类,就把该输入实例分类到这个类中。
图三 kNN示意图
本文所使用具体方法:对每个像素点进行曼哈顿距离计算,作为数据的相似性;统计每个色块的识别数量并进行判断得到最终的识别结果判断。
4. 代码实现
#URFDLB
import numpy as np
import cv2
import time
#高斯模糊核
Kernel = (31, 31)
#伽马
R = 2
#邻近数据个数
K = 5
#中心色块采样个数
N = 25
#魔方每个面中心坐标
face_XY = ((280, 380), (280, 380), (280, 380), (280, 380), (280, 380), (280, 380))
#魔方单个色块大小
block_len = (150, 150, 150, 150, 150, 150)
#BGR/HSV权重
w = (1, 1, 1)
np.random.seed() #随机数种子
#初始化
def init():
imgU = cv2.imread('imgU.jpg')
imgR = cv2.imread('imgR.jpg')
imgF = cv2.imread('imgF.jpg')
imgD = cv2.imread('imgD.jpg')
imgL = cv2.imread('imgL.jpg')
imgB = cv2.imread('imgB.jpg')
#截取坐标计算
pic_XY = np.zeros((6, 4), np.uint16)
for i in range(6):
pic_XY[i][0] = int(face_XY[i][0] - block_len[i]*1.5)
pic_XY[i][1] = int(face_XY[i][0] + block_len[i]*1.5)
pic_XY[i][2] = int(face_XY[i][1] - block_len[i]*1.5)
pic_XY[i][3] = int(face_XY[i][1] + block_len[i]*1.5)
#截取图片
img_cube = []
img_cube.append(cv2.GaussianBlur(imgU[pic_XY[0][0]:pic_XY[0][1], pic_XY[0][2]:pic_XY[0][3]], Kernel, R))
img_cube.append(cv2.GaussianBlur(imgR[pic_XY[1][0]:pic_XY[1][1], pic_XY[1][2]:pic_XY[1][3]], Kernel, R))
img_cube.append(cv2.GaussianBlur(imgF[pic_XY[2][0]:pic_XY[2][1], pic_XY[2][2]:pic_XY[2][3]], Kernel, R))
img_cube.append(cv2.GaussianBlur(imgD[pic_XY[3][0]:pic_XY[3][1], pic_XY[3][2]:pic_XY[3][3]], Kernel, R))
img_cube.append(cv2.GaussianBlur(imgL[pic_XY[4][0]:pic_XY[4][1], pic_XY[4][2]:pic_XY[4][3]], Kernel, R))
img_cube.append(cv2.GaussianBlur(imgB[pic_XY[5][0]:pic_XY[5][1], pic_XY[5][2]:pic_XY[5][3]], Kernel, R))
return np.array(img_cube)
#中心色块随机采样,urfdlb
def sample_center_block(img_cube):
sample = np.zeros((6, N, 3), np.uint8)
for i in range(6):
loc = np.random.randint(block_len[i]/(-4), block_len[i]/4, (N, 2), np.int16) #随机采样坐标偏移
for j in range(N):
x = int(loc[j][0] + img_cube[i].shape[0]/2)
y = int(loc[j][1] + img_cube[i].shape[1]/2)
sample[i][j] = img_cube[i][x][y]
return sample
#return cv2.cvtColor(sample, cv2.COLOR_BGR2HSV)
#待识别色块采样
def sample_identify_block(img_cube):
sample = np.zeros((6, 8, 3), np.uint8)
for i in range(6):
#中心坐标
x_cen = img_cube[i].shape[0]/2
y_cen = img_cube[i].shape[1]/2
loc = np.random.randint(block_len[i]/(-4), block_len[i]/4, (8, 2), np.int16) #随机采样坐标偏移
sample[i][0] = img_cube[i][int(x_cen-block_len[i]+loc[0][0])][int(y_cen-block_len[i]+loc[0][1])] #左上
sample[i][1] = img_cube[i][int(x_cen-block_len[i]+loc[1][0])][int(y_cen+loc[1][1])] #上
sample[i][2] = img_cube[i][int(x_cen-block_len[i]+loc[2][0])][int(y_cen+block_len[i]+loc[2][1])] #右上
sample[i][3] = img_cube[i][int(x_cen+loc[3][0])][int(y_cen-block_len[i]+loc[3][1])] #左
sample[i][4] = img_cube[i][int(x_cen+loc[4][0])][int(y_cen+block_len[i]+loc[4][1])] #右
sample[i][5] = img_cube[i][int(x_cen+block_len[i]+loc[5][0])][int(y_cen-block_len[i]+loc[5][1])] #左下
sample[i][6] = img_cube[i][int(x_cen+block_len[i])+loc[6][0]][int(y_cen+loc[6][1])] #下
sample[i][7] = img_cube[i][int(x_cen+block_len[i]+loc[7][0])][int(y_cen+block_len[i]+loc[7][1])] #右下
return sample
#return cv2.cvtColor(sample, cv2.COLOR_BGR2HSV)
#颜色分类
def identify_color(reference, identify):
result = '' #识别结果
test = np.zeros(6, np.uint8) #结果检测
order = 'URFDLB' #魔方顺序
for i in range(6): #6个面
for j in range(8): #8个色块
#色块对比
tmp = [] #曼哈顿距离列表
for r in range(6): #对比6个参照面
for k in range(N): #对比N个参考采样
len = 0
for c in range(3): #BGR
if identify[i][j][c] > reference[r][k][c]:
len += w[c] * (identify[i][j][c] - reference[r][k][c])
else:
len += w[c] * (reference[r][k][c] - identify[i][j][c])
tmp.append((len, r))
#K-最邻近统计
tmp.sort() #升序,前K个
k_len = np.zeros(6, np.uint8) #K-最邻近方块统计数组
for m in range(K):
k_len[tmp[m][1]] += 1
#找最大下标
max_index = 0
for mi in range(6):
if k_len[max_index] < k_len[mi]:
max_index = mi
#结果记录
result += order[max_index]
test[max_index] += 1
#中心色块
if j == 3:
result += order[i]
test[i] += 1
#结果检测
if min(test) == 9 and max(test) == 9:
return True, result
else:
return False, result
def main():
a = time.time()
cube = init()
c = 0
for i in range(100):
reference_cube = sample_center_block(cube)
identify_cube = sample_identify_block(cube)
flag, cube_res = identify_color(reference_cube, identify_cube)
if flag == True:
c += 1
print(c/100.0)
#print(time.time() - a)
if __name__ == '__main__':
main()
5. 结果分析
图四 预测准确率
引申:改变K值可能会得到不同结果;由于随机采样数据,结果可能与所不同。