四张图彻底搞懂CNN反向传播算法(通俗易懂)

作者:陈楠,机器学习算法与自然语言处理
链接:https://zhuanlan.zhihu.com/p/81675803
来源:知乎

阅读本文之前,可以先阅读之前讲述的全连接层的反向传播算法详细推导过程,

已经了解反向传播算法的请自动忽略。

1. 卷积层的反向传播

废话不说,直接上图:

四张图彻底搞懂CNN反向传播算法(通俗易懂)

假设输入为一张单通道图像 四张图彻底搞懂CNN反向传播算法(通俗易懂) ,卷积核大小为 四张图彻底搞懂CNN反向传播算法(通俗易懂) ,输出为 四张图彻底搞懂CNN反向传播算法(通俗易懂) 。为了加速计算,首先将 四张图彻底搞懂CNN反向传播算法(通俗易懂)按卷积核滑动顺序依次展开,如上图所示。其中, 四张图彻底搞懂CNN反向传播算法(通俗易懂) 中的红色框代表 四张图彻底搞懂CNN反向传播算法(通俗易懂) 中的红色框展开后的结果,将 四张图彻底搞懂CNN反向传播算法(通俗易懂) 依次按照此方式展开,可得 四张图彻底搞懂CNN反向传播算法(通俗易懂) 。同理可得 四张图彻底搞懂CNN反向传播算法(通俗易懂) ,然后通过矩阵相乘可得输出 四张图彻底搞懂CNN反向传播算法(通俗易懂) (四张图彻底搞懂CNN反向传播算法(通俗易懂)与 四张图彻底搞懂CNN反向传播算法(通俗易懂) 等价)。此时,已经将CNN转化为FC,与反向传播算法完全一致,这里不再做详细介绍。

当有 N 个样本,做一个batch训练,即channel=N时,前向与反向传播方式如下图所示:

四张图彻底搞懂CNN反向传播算法(通俗易懂)

其中,输入图像channel=3,使用2个 四张图彻底搞懂CNN反向传播算法(通俗易懂) 的卷积核,输出两张图像,如图所示。红色框、黄色框代表的是卷积核以及使用该卷积核得到的输出图像 四张图彻底搞懂CNN反向传播算法(通俗易懂) 。当输入图像为一个batch时, 四张图彻底搞懂CNN反向传播算法(通俗易懂) 的转化方式如上图,首先将输入图像与卷积核分别按单通道图像展开,然后将展开后的矩阵在行方向级联。此时,已经将CNN转化为了FC,与反向传播算法完全一致,这里不再做详细介绍。

2. Average pooling的反向传播

四张图彻底搞懂CNN反向传播算法(通俗易懂)

四张图彻底搞懂CNN反向传播算法(通俗易懂) 不用求,因为 四张图彻底搞懂CNN反向传播算法(通俗易懂) 为常数。 四张图彻底搞懂CNN反向传播算法(通俗易懂)

3. Max-pooling的反向传播

四张图彻底搞懂CNN反向传播算法(通俗易懂)

遍历 四张图彻底搞懂CNN反向传播算法(通俗易懂) 的每一行,找出此行最大值的索引 四张图彻底搞懂CNN反向传播算法(通俗易懂) ,然后将 四张图彻底搞懂CNN反向传播算法(通俗易懂) 中索引为 四张图彻底搞懂CNN反向传播算法(通俗易懂) 的值设为 四张图彻底搞懂CNN反向传播算法(通俗易懂) 对应行的值,将此行其余列的值设为 四张图彻底搞懂CNN反向传播算法(通俗易懂) ,如上图所示红框所示。假设 四张图彻底搞懂CNN反向传播算法(通俗易懂) 中(1,1)处的值是第一行中最大的值,则将 四张图彻底搞懂CNN反向传播算法(通俗易懂) 赋值给 四张图彻底搞懂CNN反向传播算法(通俗易懂) 中索引为 四张图彻底搞懂CNN反向传播算法(通俗易懂) 的位置。最后计算: 四张图彻底搞懂CNN反向传播算法(通俗易懂) 。

欢迎扫码添加小编微信加入机器学习微信群,交流算法,共同进步:
四张图彻底搞懂CNN反向传播算法(通俗易懂)