Datawhale《深度学习-NLP》Task1-NLP-召回率、准确率、ROC曲线、AUC、PR曲线学习理解
1.下载数据
中文数据集:THUCNews THUCNews数据子集:https://pan.baidu.com/s/1hugrfRu 密码:qfud
2.基于CNN神经网络进行分类
https://github.com/gaussic/text-classification-cnn-rnn GITHUB地址
下载完成之后导入pycharm进行模型训练,此代码运行时候需要在Edit-Configurations Parametes中设置train or test参数才能启动不然会报错。
代码处理数据流程大致分为:
2.1.判断是训练模型还是测试模型。
2.2.获取CNN模型参数。
2.3.创建词汇表,没有的话会根据训练数据生成。此作用就是把每个词转换成ID用。
2.4.获取文本分类的类目,并把文本分类转成ID。
2.5.定义CNN模型,然后开始训练模型。
2.6.训练模型中数据也进行处理,分别将词变成数字ID,相当于每个词的序号并用keras进行长度控制到600范围,没有用到w2c模型进行处理,后续可以尝试。对于分类目录则是进行的one-hot编码。
2.7.共训练10次,每次训练如果高于上一次模型保存的正确率,则会保存概率高的模型。如果每次训练次数大于1000,模型正确率没有提高则会提前结束训练。
3. 运行完成之后会生成此目录和模型
4.模型测试
将运行参数改为test,我的模型结果如下:
Testing...
Test Loss: 0.99, Test Acc: 73.23%
Precision, Recall and F1-Score...
precision recall f1-score support
体育 0.89 0.99 0.93 1000
财经 0.80 0.97 0.88 1000
房产 0.68 0.93 0.79 1000
家居 0.48 0.07 0.13 1000
教育 0.53 0.77 0.63 1000
科技 0.82 0.42 0.55 1000
时尚 0.75 0.82 0.78 1000
时政 0.57 0.56 0.56 1000
游戏 0.90 0.84 0.87 1000
娱乐 0.80 0.95 0.87 1000
avg / total 0.72 0.73 0.70 10000
Confusion Matrix...
[[986 1 0 2 6 0 0 4 0 1]
[ 3 975 5 0 6 0 0 8 3 0]
[ 6 3 928 11 26 0 1 7 12 6]
[ 41 112 279 75 184 15 44 199 16 35]
[ 15 26 38 19 768 20 36 59 11 8]
[ 8 2 18 6 144 418 166 103 23 112]
[ 16 5 9 10 41 39 819 12 12 37]
[ 20 74 58 20 226 5 2 563 10 22]
[ 16 16 17 6 28 10 14 39 836 18]
[ 2 0 4 8 8 4 13 0 6 955]]
Time usage: 0:00:22
输出下面结果则CNN模型训练成功。
下面回归到正题这里都可以用sklean里面的方法《召回率、准确率、ROC曲线、AUC、PR曲线》
from sklearn.metrics
1.召回率:也就是模型输入结果里面的recall。算0和1的概率
recall = recall_score(y_test, y_predict)
#recall得到的是一个list,是每一类的召回率
2.准确率:所有样本中被预测正确的样本的比率
分类模型总体判断的准确率(包括了所有class的总体准确率)
accuracy = accuracy_score(y_test, y_predict)
3.ROC曲线,AUC ,PR曲线
总结一下,对于计算ROC,最重要的三个概念就是TPR, FPR, 截断点。
ROC曲线
ROC曲线越接近左上角,代表模型越好,即AUC接近1
from sklearn.metrics import roc_auc_score, auc
import matplotlib.pyplot as plt
y_predict = model.predict(x_test)
y_probs = model.predict_proba(x_test) #模型的预测得分
fpr, tpr, thresholds = metrics.roc_curve(y_test,y_probs)
roc_auc = auc(fpr, tpr) #auc为Roc曲线下的面积
#开始画ROC曲线
plt.plot(fpr, tpr, 'b',label='AUC = %0.2f'% roc_auc)
plt.legend(loc='lower right')
plt.plot([0,1],[0,1],'r--')
plt.xlim([-0.1,1.1])
plt.ylim([-0.1,1.1])
plt.xlabel('False Positive Rate') #横坐标是fpr
plt.ylabel('True Positive Rate') #纵坐标是tpr
plt.title('Receiver operating characteristic example')
plt.show()
参考:https://www.jianshu.com/p/5df19746daf9