深度学习【40】Improved Techniques for Training GANs
该论文提出了一些关于训练GAN的技巧,在mnist上生成的样本人类无法分辨真假,在CIFAR-10上生成的样本人类分辨的错误率为21.3%。
优化
feature matching
为G网络加了一个损失函数:
函数f表示D网络最后输出层的前一层特征图。f(x)由真实数据抽取而来,f(G(z))为G网络生成的图片抽取而来。
Minibatch discrimination
GAN训练过程中经常会出现G网络生成的图片为了能够欺骗D网络,而生成仅仅能够让D网络认为是真实的图片。也就是G网络生成的图片都太相似了,没有多样性。这是因为D网络没有一个能够告诉G网络,应该生成不相似的图片。为此作者提出了一个minibatch discrimination来解决这个问题。
minibatch disrimination通过计算一个minibath中样本D网络中某一层特征图之间的差异信息,作为D网络中下一层的额外输出,达到每个样本之间的信息交互目的。具体的,假设样本
上述过程,如图所示:
接着,将
Historical averaging
加入了一个惩罚项,找来找去不知道具体怎么实现的。就不多说了。
One-sided label smoothing
标签平滑,比较操作起来比较简单。训练D网络的时候,生成真实图片的label时将1改成0.9就可以了。
Virtual batch normalization
DCGAN使用了BN,取得了不错的效果。但是BN有个缺点,即BN会时G网络生成一个batch的图片中,每张图片都有相关联(如,一个batch中的图片都有比较多的绿色)。
为了解决这个问题可以使用Reference batch normalization。Reference batch normalization(包含运行网络两次: 第一次是对一个minibatch的参考样本, 这里的参考样本是在训练开始以前被采样并且是保持不变的; 另一个是对当前的minibatch的样本进行训练。 特征的平均值和标准差使用参考样本的batch进行计算。 然后,使用这些统计的信息对两个batch的特征进行标准化处理。 此方法的一个缺点是模型容易对参考batch的样本过拟合。 为了稍微缓解此问题, 作者提出了virtual batch normalization, 对一个样本标准化时使用的统计信息是通过此样本与参考batch的联合来进行计算的。
Semi-supervised learning
作者还加入了半监督学习机制。说起来也很简单,就是在D网络中加入一个图片类别预测(比如imageNet的1000个类别)。损失函数变为:
其中K表示K个类别,
实验结果
在imageNet上与DCGAN的对比
左边的是DCGAN,右边是论文的结果。明显会比DCGAN更好一些。