EM算法处理鸢尾花数据实战

今天我们用EM算法对鸢尾花数据进行分类处理,EM算法的原理较为复杂,我会总结之后再发出来。我们先来实战看一下EM算法的强大之处。
EM算法是无监督的分类,而我们的鸢尾花数据是已知类别的,所以我们在处理时直接忽略掉类别之一列,任务三个特征是符合三个独立的高斯分布混合得到,仅仅通过分析特征数据的均值、方差,来判断出这三个类别。

1.首先导入包

import numpy as np
from sklearn.mixture import GaussianMixture
import matplotlib as mpl
import matplotlib.colors
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import pairwise_distances_argmin

2.导入数据并进行初始化

if __name__ == '__main__':
    path = '8.iris.data'  # 数据文件路径
    data = np.loadtxt(path, dtype=float, delimiter=',', converters={4: iris_type})
    # 将数据的0到3列组成x,第4列得到y
    x_prime, y = np.split(data, (4,), axis=1)
    y = y.ravel()
    #print(x_prime)
    n_components = 3
    feature_pairs = [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]
    plt.figure(figsize=(10, 9), facecolor='#FFFFFF')

3.对四个特征两两组合进行预测

 for k, pair in enumerate(feature_pairs):
        x = x_prime[:, pair]
        #print(x)   # y是目标值的列向量 它等于分类0,1,2时的值对应的X的位置 就可以算出每一类的实际均值
        m = np.array([np.mean(x[y == i], axis=0) for i in range(3)])  # 均值的实际值
        print ('实际均值 = \n', m)

        gmm = GaussianMixture(n_components=n_components, covariance_type='full', random_state=0)
        gmm.fit(x)
        print ('预测均值 = \n', gmm.means_)
        print ('预测方差 = \n', gmm.covariances_)
        y_hat = gmm.predict(x)
        order = pairwise_distances_argmin(m, gmm.means_, axis=1, metric='euclidean')
        print ('顺序:\t', order)

这里主要算出实际均值的方法,然后带入高斯混合模型 算出训练后的预测均值和方差
但是由于算出的均值和实际均值的顺序可能不能一一对应 所以要看看顺序 即order
得到顺序为

顺序:	 [0 2 1]

可以看到并不是0.1.2 即预测的第二个均值 实则是我们实际的第三个特征的均值 ,所以我们要做一点小转换

        n_sample = y.size
        n_types = 3
        change = np.empty((n_types, n_sample), dtype=np.bool)
        for i in range(n_types):
            change[i] = y_hat == order[i]
        for i in range(n_types):
            y_hat[change[i]] = i
        acc = u'准确率:%.2f%%' % (100*np.mean(y_hat == y))
        print (acc)

通过此操作可以将顺序变正并且打印出准确率

4.画图

cm_light = mpl.colors.ListedColormap(['#FF8080', '#77E0A0', '#A0A0FF'])
        cm_dark = mpl.colors.ListedColormap(['r', 'g', '#6060FF'])
        x1_min, x1_max = x[:, 0].min(), x[:, 0].max()
        x2_min, x2_max = x[:, 1].min(), x[:, 1].max()
        x1_min, x1_max = expand(x1_min, x1_max)
        x2_min, x2_max = expand(x2_min, x2_max)
        x1, x2 = np.mgrid[x1_min:x1_max:500j, x2_min:x2_max:500j]
        grid_test = np.stack((x1.flat, x2.flat), axis=1)
        grid_hat = gmm.predict(grid_test)

        change = np.empty((n_types, grid_hat.size), dtype=np.bool)
        for i in range(n_types):
            change[i] = grid_hat == order[i]
        for i in range(n_types):
            grid_hat[change[i]] = i

        grid_hat = grid_hat.reshape(x1.shape)
        plt.subplot(3, 2, k+1)
        plt.pcolormesh(x1, x2, grid_hat, cmap=cm_light)
        plt.scatter(x[:, 0], x[:, 1], s=30, c=y, marker='o', cmap=cm_dark, edgecolors='k')
        xx = 0.95 * x1_min + 0.05 * x1_max
        yy = 0.1 * x2_min + 0.9 * x2_max
        plt.text(xx, yy, acc, fontsize=14)
        plt.xlim((x1_min, x1_max))
        plt.ylim((x2_min, x2_max))
        plt.xlabel(iris_feature[pair[0]], fontsize=14)
        plt.ylabel(iris_feature[pair[1]], fontsize=14)
        plt.grid()
    plt.tight_layout(2)
    plt.suptitle(u'EM算法无监督分类鸢尾花数据', fontsize=20)
    plt.subplots_adjust(top=0.92)
    plt.show()

之后可以得到打印结果

实际均值 = 
 [[1.464 0.244]
 [4.26  1.326]
 [5.552 2.026]]
预测均值 = 
 [[1.4639995  0.24399977]
 [5.57721357 2.04303223]
 [4.30594389 1.34787855]]
预测方差 = 
 [[[0.02950483 0.00558393]
  [0.00558393 0.01126496]]

 [[0.30034404 0.04402642]
  [0.04402642 0.07200287]]

 [[0.24667106 0.08489917]
  [0.08489917 0.04585074]]]
顺序:	 [0 2 1]
准确率:97.33%

EM算法处理鸢尾花数据实战