【paper reading】Interpret Neural Networks by Identifying Critical Data Routing Paths

【paper reading】Interpret Neural Networks by Identifying Critical Data Routing Paths


现在深度学习依靠大数据加上现在比较充足的计算能力,神经网络十分火热,也在很多方面有很好的应用。现在cvpr之类的*会议很多论文都是基于神经网络的研究。但现在有一个问题,就是神经网络究竟是怎么工作的,它运行的机理到底是什么样的,大家其实不是特别清楚。

对于深度学习模型的可解释性,也就是我们想要知道它内部的工作机理是什么样的,为什么会发生错误,什么情况下会发生错误,这方面的研究也是很受关注。包括美国的一些相关科研机构都把可解释性的机器学习作为深度学习下一代的比较有代表性的模型。

1.简介

这篇论文是cvpr2018上的一篇论文,是清华大学苏航老师的工作,是关于神经网络的可解释性研究。举一个简单的例子,现在用机器学习来给人看病在技术上没有太大的问题,但是在诊断出你这个病以后,它不像医生那样,不能给出依据,这就让人觉得不是很靠谱。
【paper reading】Interpret Neural Networks by Identifying Critical Data Routing Paths
同样,由于它是黑箱,我们输入一个数据,它最后判断错误了,我们也不知道它为什么会做出这样的判断。例如下面的例子中,将阿尔卑斯山的照片加上一些随机产生的噪音,最后网络以99.99%的置信度将它判断成是狗,将河豚的照片加上一些随机产生的照片,最后网络以100%的置信度将它判断成是螃蟹。明明加上噪音之前神经网络的判断都是没有问题的,问什么加上这些对人眼来讲没有太大干扰的噪声会对神经网络的输出产生这么大的影响,能搞懂这一点能让我们更加理解神经网络的工作原理,从而在后面能够更好的改进。
【paper reading】Interpret Neural Networks by Identifying Critical Data Routing Paths
对于机器学习的可解释性大体可以从三个方面去理解:

第一个就是数据的可解释性。我们想要知道数据上究竟是哪些维度上面的信息对任务起到了作用。
第二个角度是预测的可解释性。我们想要知道为什么模型会把输入的数据分到这个类别。
第三个角度就是模型的可解释性。模型内部如何达到这样的结果,它的每一部分到底在干什么。

从本质上来讲现在所有的机器学习模型,特别是现在以统计方法为主的机器学习模型,我们都把数据映射到特征空间中,然后后面的工作都是在特征空间中进行的。这里的关键问题是人其实不理解这里的特征空间的。
我们人能理解的就是原始的数据空间,另外我们人还有一个高层的语义空间。
【paper reading】Interpret Neural Networks by Identifying Critical Data Routing Paths
要是现在的机器学习模型能被人理解,第一种方式就是将特征空间把它追溯到原始的数据空间,一种方式是把特征空间和更高级的语义空间想联系。这也是苏航老师在17年和18年的两个工作。他认为像图像识别这方面的内容。比如我们看到一只猫,马上认出这是一只猫,这种就比较适用于将特征空间追溯到原始的数据空间。而涉及到一些更加高级的逻辑推理,比如图像或者视频描述任务,就更适用于将特种空间与高层的语义空间想联系。

2.神经网络的可解释性

作者的思想比较简单,也就是说一张图像输进神经网络之后,它通过了不同的层可以看做是信息在一层一层往下流动。我们希望知道最后这个决策到底是哪个信息流对他起到了最关键的作用。我们就想找到这个样本在神经网络中它的数据流究竟激发了哪些神经元,或者说哪些神经元最终对这个结果产生了影响。这些重要的神经元就是critical neural node,把他们连接生成的路径就叫做critical routing path。
【paper reading】Interpret Neural Networks by Identifying Critical Data Routing Paths
他是在VGG-net,Alex-net和Resnet上都进行了解释性研究,它是怎么来做的呢,就是首先以VGG-net为例,它直接把训练好的VGG-net原封不动先拿过来。然后在每一层的每个通道上增加了一个control gate,就是一个lambda的标量参数,每一个通道的输出都需要乘上这个lambda参数才能作为最终的结果。所以可以看出,所有的lambda都等于1的时候,就是原始的网络。
【paper reading】Interpret Neural Networks by Identifying Critical Data Routing Paths
我们最终的目标就是要学习一组lambda权重。这个权重要满足什么要求呢,就是首先它要非负,从控制门的功能定义来看,λ只应抑制或放大输出通道**。λ的负值会否定网络中的原始输出**。第二个要求就是λ应该是稀疏的并且大部分接近于零,他认为整个神经网络中其实只有极少数的结点真正的在整个预测过程中起到比较重要的作用。

最后工作的优化目标就是:
minΛL(fθ(x),fθ(x;Λ))+γkλk1\min_\Lambda \mathcal L(f_\theta(x),f_\theta(x;\Lambda))+\gamma\sum_k|\lambda_k|_1
s.t.   λk0,k=1,2,...,Ks.t.\ \ \ \lambda_k \succeq 0, k = 1,2,...,K
其中Λ={λ1,λ2,...,λK}\Lambda=\{\lambda_1,\lambda_2,...,\lambda_K\},K是网络的层数。
前面这个loss就是一个交叉熵的loss,前面这个就是原始的网络,后面这个加入λ\lambda参数的就是加了控制门的网络。然后我们希望整个λ\lambda是受L1L_1范数约束的,是稀疏的。

整个求解过程就是一个梯度下降的求解过程,迭代求解。
【paper reading】Interpret Neural Networks by Identifying Critical Data Routing Paths

3.实验部分

最后是一个是实验结果,主要关注于卷积层的输出
【paper reading】Interpret Neural Networks by Identifying Critical Data Routing Paths
可以看到在VGGnet上,在保证网络精度不下降的情况下,将一些nueral去除。第一行是原始网络,100%。看最后一行可以发现将网络去除掉大部分,只保留13.5%的部分,就能达到原始的精度。这还是比较惊人的。就是大部分的结构其实并没有起到很大的作用。

刚才选出了13.5%的node。下面这个实验结果就证明了这个node确实是很critical的node。
【paper reading】Interpret Neural Networks by Identifying Critical Data Routing Paths
可以看到左边的这个top mode。就是把选出来的节点按阈值降序排序,一开始仅仅去掉阈值最大的1%的节点,可以看到整个网络的精度直接有一个很大的降低。而去除那些阈值低的节点变化却十分平缓。

下一步,进一步把网络进行了一个可视化的操作,在最初的基层,数据基本都混合在一起,随着层数越来越深,不同标签的数据开始逐渐分离开来。
【paper reading】Interpret Neural Networks by Identifying Critical Data Routing Paths
【paper reading】Interpret Neural Networks by Identifying Critical Data Routing Paths