EM算法处理鸢尾花数据实战
今天我们用EM算法对鸢尾花数据进行分类处理,EM算法的原理较为复杂,我会总结之后再发出来。我们先来实战看一下EM算法的强大之处。
EM算法是无监督的分类,而我们的鸢尾花数据是已知类别的,所以我们在处理时直接忽略掉类别之一列,任务三个特征是符合三个独立的高斯分布混合得到,仅仅通过分析特征数据的均值、方差,来判断出这三个类别。
1.首先导入包
import numpy as np
from sklearn.mixture import GaussianMixture
import matplotlib as mpl
import matplotlib.colors
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import pairwise_distances_argmin
2.导入数据并进行初始化
if __name__ == '__main__':
path = '8.iris.data' # 数据文件路径
data = np.loadtxt(path, dtype=float, delimiter=',', converters={4: iris_type})
# 将数据的0到3列组成x,第4列得到y
x_prime, y = np.split(data, (4,), axis=1)
y = y.ravel()
#print(x_prime)
n_components = 3
feature_pairs = [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]
plt.figure(figsize=(10, 9), facecolor='#FFFFFF')
3.对四个特征两两组合进行预测
for k, pair in enumerate(feature_pairs):
x = x_prime[:, pair]
#print(x) # y是目标值的列向量 它等于分类0,1,2时的值对应的X的位置 就可以算出每一类的实际均值
m = np.array([np.mean(x[y == i], axis=0) for i in range(3)]) # 均值的实际值
print ('实际均值 = \n', m)
gmm = GaussianMixture(n_components=n_components, covariance_type='full', random_state=0)
gmm.fit(x)
print ('预测均值 = \n', gmm.means_)
print ('预测方差 = \n', gmm.covariances_)
y_hat = gmm.predict(x)
order = pairwise_distances_argmin(m, gmm.means_, axis=1, metric='euclidean')
print ('顺序:\t', order)
这里主要算出实际均值的方法,然后带入高斯混合模型 算出训练后的预测均值和方差
但是由于算出的均值和实际均值的顺序可能不能一一对应 所以要看看顺序 即order
得到顺序为
顺序: [0 2 1]
可以看到并不是0.1.2 即预测的第二个均值 实则是我们实际的第三个特征的均值 ,所以我们要做一点小转换
n_sample = y.size
n_types = 3
change = np.empty((n_types, n_sample), dtype=np.bool)
for i in range(n_types):
change[i] = y_hat == order[i]
for i in range(n_types):
y_hat[change[i]] = i
acc = u'准确率:%.2f%%' % (100*np.mean(y_hat == y))
print (acc)
通过此操作可以将顺序变正并且打印出准确率
4.画图
cm_light = mpl.colors.ListedColormap(['#FF8080', '#77E0A0', '#A0A0FF'])
cm_dark = mpl.colors.ListedColormap(['r', 'g', '#6060FF'])
x1_min, x1_max = x[:, 0].min(), x[:, 0].max()
x2_min, x2_max = x[:, 1].min(), x[:, 1].max()
x1_min, x1_max = expand(x1_min, x1_max)
x2_min, x2_max = expand(x2_min, x2_max)
x1, x2 = np.mgrid[x1_min:x1_max:500j, x2_min:x2_max:500j]
grid_test = np.stack((x1.flat, x2.flat), axis=1)
grid_hat = gmm.predict(grid_test)
change = np.empty((n_types, grid_hat.size), dtype=np.bool)
for i in range(n_types):
change[i] = grid_hat == order[i]
for i in range(n_types):
grid_hat[change[i]] = i
grid_hat = grid_hat.reshape(x1.shape)
plt.subplot(3, 2, k+1)
plt.pcolormesh(x1, x2, grid_hat, cmap=cm_light)
plt.scatter(x[:, 0], x[:, 1], s=30, c=y, marker='o', cmap=cm_dark, edgecolors='k')
xx = 0.95 * x1_min + 0.05 * x1_max
yy = 0.1 * x2_min + 0.9 * x2_max
plt.text(xx, yy, acc, fontsize=14)
plt.xlim((x1_min, x1_max))
plt.ylim((x2_min, x2_max))
plt.xlabel(iris_feature[pair[0]], fontsize=14)
plt.ylabel(iris_feature[pair[1]], fontsize=14)
plt.grid()
plt.tight_layout(2)
plt.suptitle(u'EM算法无监督分类鸢尾花数据', fontsize=20)
plt.subplots_adjust(top=0.92)
plt.show()
之后可以得到打印结果
实际均值 =
[[1.464 0.244]
[4.26 1.326]
[5.552 2.026]]
预测均值 =
[[1.4639995 0.24399977]
[5.57721357 2.04303223]
[4.30594389 1.34787855]]
预测方差 =
[[[0.02950483 0.00558393]
[0.00558393 0.01126496]]
[[0.30034404 0.04402642]
[0.04402642 0.07200287]]
[[0.24667106 0.08489917]
[0.08489917 0.04585074]]]
顺序: [0 2 1]
准确率:97.33%