论文笔记:使用基于Attention的卷积神经网络进行12导联的心电异常的多分类检测

论文地址:Multi-class Arrhythmia detection from 12-lead varied-length ECG using Attention-based Time-Incremental Convolutional Neural Network


一、背景

心电数据往往存在者个体差异和噪音,给心率识别与分析带来很大困难。现有深度学习算法虽多,却没有专门为生理信号设计的模型,生理信号有以下几个特点:1.周期性波动,2.存在异常信号,3.现有12导联的心电数据提供了丰富的信息,如何利用信息的空间分布是个需要考虑的问题。

本文亮点:

  • 提出新颖的ATI-CNN模型,将心电数据处理分为两部分:用CNN捕捉空间信息,RNN捕捉时域信息,并基于Attention机制。
  • 开发循环单元的unwrapping ability来处理不定长的输入信号,不像传统CNN需要实现对信号进行补全和或截断,模型具有较好的鲁棒性。
  • 引入注意力机制,文章实验说明注意力在其中发挥的作用。

本文方法在第一届中国心电挑战赛的数据集上进行实验,对12导联心电数据进行9分类。

二、方法

模型结构如下:

论文笔记:使用基于Attention的卷积神经网络进行12导联的心电异常的多分类检测
论文笔记:使用基于Attention的卷积神经网络进行12导联的心电异常的多分类检测

每个CNN层下都接上BN和ReLU层。
所用Attention机制结构图如下:

论文笔记:使用基于Attention的卷积神经网络进行12导联的心电异常的多分类检测

这里文章没有给出计算公式,个人感觉不太好。
损失函数为:
loss(X,r)=log(exp(p(X,r))jexp(p(X,j))) loss(X, r) = -log(\frac{exp(p(X, r))}{\sum_j exp(p(X, j))})
p(X,j)p(X, j)是模型对输入的XX预测为jj标签的概率,rr为正确标签。

三、实验

3.1 环境

Xeon E5 2650 CPU,128G内存,四块Titan Xp 显卡(我柠檬了)。Ubuntu 16.04版本,Pytorch 0.4.1。

3.2 数据

train_set 1
train_set 2
train_set 3
训练集标签

3.3 预处理

  1. 训练集信号在时间维度上乘以一个随机因子进行压缩或扩张,随机因子服从[1, 1.2]的均匀分布。这个操作会引入噪声,但是能帮助模型达到更好的表现(好奇怪啊这里)。
  2. 对每个训练数据,随机将一部分变为0,片段长度最多1.5秒。
  3. 信号标准化
  4. 所有数据全部补全或截断到60秒(这和前面亮点2不是相矛盾么……)
  5. 随机给样本增加反直觉的错误标签

3.4 实验设定

CNN和Dense层全部使用kaiming initializer进行权重初始化(现在的pytorch,cnn的默认初始化就是kaiming initializer,用tf的朋友可以自定义这个初始化类),lstm使用orthogonal初始化。
Adam优化器,初始学习率为0。0001,每过50个epoch乘以0.1,总计150个epoch。
使用l2正则约束参数,l2损失乘以0.004加入到训练损失上。
Batchsize为128。

四、实验结果

论文笔记:使用基于Attention的卷积神经网络进行12导联的心电异常的多分类检测

五、讨论

ATI-CNN模型分类的混淆矩阵如下:

论文笔记:使用基于Attention的卷积神经网络进行12导联的心电异常的多分类检测

后面的讨论部分文字有点多,有兴趣研究的看原文就好,注重工程和复现的朋友们到这里就可以了。