WGAN-GP方法介绍

原文标题:Improved Training of Wasserstein GANs

原文链接[1704.00028] Improved Training of Wasserstein GANs

背景介绍

训练不稳定是GAN常见的一个问题。虽然WGAN在稳定训练方面有了比较好的进步,但是有时也只能生成较差的样本,并且有时候也比较难收敛。原因在于:WGAN采用了权重修剪(weight clipping)策略来强行满足critic上的Lipschitz约束,这将导致训练过程产生一些不希望的行为。本文提出了另一种截断修剪的策略-gradient penalty,即惩罚critic相对于其输入(由随机噪声z生成的图片,即fake image)的梯度的norm。就是这么一个简单的改进,能使WGAN的训练变得更加稳定,并且取得更高质量的生成效果。

注意:GAN之前的D网络都叫discriminator,但是由于这里不是做分类任务,WGAN作者觉得叫discriminator不太合适,于是将其叫为critic。

方法介绍

介绍WGAN-GP方法前,先简单介绍一下WGAN,WGAN的损失函数如下:

WGAN-GP方法介绍

公式1

这里需要注意的是,WGAN的提出是作者分析了一堆统计度量(KL散度,JS散度,TV距离,W距离等)后,得出Wasserstein距离(下简称W距离)最适合GAN的训练。按理说WGAN的损失函数就是一个分布到另一个分布的W距离,如公式2所示:

WGAN-GP方法介绍

公式2

但是大家可以看到,公式2中有个下确界符号inf,让人看着有点懵。不过没关系,Kantorovich-Rubinstein duality理论(该理论太复杂,这里不介绍)告诉我们:当critic满足Lipschitz连续条件时,公式2可以转化为公式1的形式。直观上,公式2跟神经网络半毛钱关系没有,给出这个公式我们也不会优化,但是公式1一看就是一个标准的神经网络的损失函数(把x和z当做输入,G和W当做网络的两部分)。公式2转化为公式1的形式后,就可以用神经网络中常用的梯度下降法去优化了。

注意这里有个名词:“Lipschitz连续”,大家不要被这个牛逼的名字吓到,其概念其实很简单,意思就是定义域内每点的梯度恒定不超过某个常数(常数是多少无所谓,不是无穷就行)。那么怎么来保证critic的Lipschiz连续呢?作者用的方法极其简单,就是weight clip策略。weight clip策略的意思是:限制神经网络 WGAN-GP方法介绍 的所有参数w不超过某个范围[-c, c](比如[-0.01, 0.01]),即大于c的置为c,小于-c的置为-c。为什么这样做能保证Lipschiz连续(定义域内每点的梯度不超过某个常数)呢?因为critic相对于其输入的导数是个含w的表达式,w不超过某个范围,那critic相对于其输入的梯度一定也不会超过某个范围,Lipschiz连续条件得以满足。

这么粗暴的做法WGAN作者也是觉得不妥的,但是暂时没有想到更好的办法,只能用这个简单的方法了。WGAN-GP就是从这点入手做文章。

其实,WGAN-GP方法的作者也是普通人,一开始想到的也是很普通的方法,比如把weight clipping这么粗暴的方法改为L2 norm clip,做权重的归一化等。然并卵,这些方法的效果跟带weight clipping的WGAN效果没啥区别。作者也尝试了batch normalization的方法,但是发现当critic太深时,WGAN难以收敛。于是,才有了WGAN-GP方法。WGAN-GP的目标函数如下所示:

WGAN-GP方法介绍

公式3

可以看到,WGAN-GP相对于WGAN的改进很小,除了增加了一个正则项,其他部分都和WGAN一样。 这个正则项就是WGAN-GP中GP(gradient penalty),即梯度约束。这个约束的意思是:critic相对于原始输入的梯度的L2范数要约束在1附近(双边约束)。为什么这个约束是合理的,这里作者给了一个命题,并且在文章补充材料中给出了证明,这个证明大家有兴趣可以自己去看,这里只想简单介绍一下这个命题。这个命题说的是在最优的优化路径上(把生成分布推向真实分布的“道路”上),critic函数对其输入的梯度值恒定为1。有了这个知识后,我们可以像搞传统机器学习一样,将这个知识加入到目标函数中,以学习到更好的模型。

这里需要说明一下,WGAN-GP作者加的这个约束能保证critic也是一个Lipschiz连续函数。因为critic对任意输入x的梯度都是一个含参数w的表达式,而这个梯度的L2 norm大小约束在1附近,那w也不超过某个常数。因而从保证Lipschiz连续的条件上,GP的作用跟weight clip是一样的。

WGAN-GP具体算法步骤如下:

WGAN-GP方法介绍

可以看出跟WGAN不同的主要有几处:1)用gradient penalty取代weight clipping;2)在生成图像上增加高斯噪声;3)优化器用Adam取代RMSProp。

这里需要注意的是,这个GP的引入,跟一般GAN、WGAN中通常需要加的Batch Normalization会起冲突。因为这个GP要求critic的一个输入对应一个输出,但是BN会将一个批次中的样本进行归一化,BN是一批输入对应一批输出,因而用BN后无法正确求出critic对于每个输入样本的梯度。