python3机器学习经典实例-学习笔记-分类算法

文章转载于知乎saya:https://zhuanlan.zhihu.com/p/35261741

可视化混淆矩阵

混淆矩阵是我们用来理解分类模型性能的表格。 这有助于我们理解如何将测试数据分类到不同的类中。 当我们想微调我们的算法时,我们需要了解在做出这些更改之前数据是如何被错误分类的。 有些种类比其他课程更糟糕,混淆矩阵将帮助我们理解这一点。 我们来看看下图:
python3机器学习经典实例-学习笔记-分类算法
在前面的图表中,我们可以看到我们如何将数据分类到不同的类中。 理想情况下,我们希望所有非对角线元素都为0.这表明完美的分类!让我们考虑class 0。总体而言,52个项目实际上属于class 0。如果我们总结第一行中的数字,则得到52。 现在,这些项目中有45项被正确预测,但是分类器说其中4项属于class 1,3项属于class 2。我们可以对其余两行应用相同的分析。值得注意的是,来自class 1的11个项被错误分类为class 0。这构成了该类中约16%的数据点。 这是我们可以用来优化模型的见解。
  • 导入必要的数据库
import numpy as npimport matplotlib.pyplot as pltfrom sklearn.metrics import confusion_matrix

  • 生成数据调用confusion_matrix模块
y_true = [1, 0, 0, 2, 1, 0, 3, 3, 3]y_pred = [1, 1, 0, 2, 1, 0, 1, 3, 3]confusion_mat = confusion_matrix(y_true, y_pred)

  • 定义显示的结
# Show confusion matrixdef plot_confusion_matrix(confusion_mat):    plt.imshow(confusion_mat, interpolation='nearest', cmap=plt.cm.gray)    plt.title('Confusion matrix')    plt.colorbar()    tick_marks = np.arange(4)    plt.xticks(tick_marks, tick_marks)    plt.yticks(tick_marks, tick_marks)    plt.ylabel('True label')    plt.xlabel('Predicted label')    plt.show()

我们使用imshow函数来绘制混淆矩阵。 其他功能都很简单! 我们只需使用相关功能设置标题,颜色条,标记和标签。 tick_marks参数的范围从0到3,因为我们在数据集中有四个不同的标签。 np.arangefunction给了我们这个numpy数组。
  • 进行显示结果


plot_confusion_matrix(confusion_mat)
输出结果:
python3机器学习经典实例-学习笔记-分类算法
对角线的颜色很强烈,我们希望它们的颜色变得深。 浅黄色表示零。 非对角线空间中有几个绿色,表示错误分类。 例如,当真实标签为0时,预测标签为1,如我们在第一行中所看到的。 事实上,所有的错误分类属于第一类,因为第二列包含三个非零的行。 从图中很容易看到这一点。
  • 提取性能报告


# Print classification reportfrom sklearn.metrics import classification_reporttarget_names = ['Class-0', 'Class-1', 'Class-2', 'Class-3']print (classification_report(y_true, y_pred, target_names=target_names))
输出的结果:


precision recall f1-score support Class-0 1.00 0.67 0.80 3 Class-1 0.50 1.00 0.67 2 Class-2 1.00 1.00 1.00 1 Class-3 1.00 0.67 0.80 3avg / total 0.89 0.78 0.79 9
结果分析