迁移学习GAN网络:Generative Adversarial Nets

Generative Adversarial Nets

简述:
目前,较为成功的还是判别模型。在生成模型方面由于概率计算等困难,未获得较大的成功。本文提出的GAN网络不需要马尔科夫链和推断,只需要梯度下降。在GAN 网络中,部分为生成网络 (Generative Network),此部分负责生成尽可能地以假乱真的样本,这部分被成为生成器 (Generator);另一部分为判别网络 (Discriminative Network), 此部分负责判断样本是真实的,还是由生成器生成的,这部分被成为判别器 (Discriminator)。
G在训练过程中的目的是生成尽可能逼真的图片去让判别器判断不了这张图片到底是真实图片还是生成的虚假照片,D在训练过程中的目的就是尽可能取辨别真假图片,所以G是希望是D的犯错率最大化,而D则是希望自己犯错率最小化,二者互为对抗,在竞争*同进步。理论上这种关系可以达到一个平衡点,即所谓的纳什均衡,也就是说G生成的图片D判别它为真实数据的概率是0.5,也即现在判别器已经无法区分生成器所生成的图片的真假,那么生成器的目的也就达到了,以假乱真了。
但是这个网络存在一个缺陷,G不能频繁更新,以保证D能跟上脚步。

问题or相关工作:

  1. GAN网络不需要马尔科夫链和推断,只需要梯度下降,是利用观察法通过生成过程对导数进行反向传播:
    迁移学习GAN网络:Generative Adversarial Nets

  2. 通过把噪音 Pz(z)加入生成器 G(z;θg)从原始数据分布中生成新的数据分布 Pg,判别器D(x)尽力表征原始数据分布,而不是生成器生成的分布Pg。下面是优化目标公式:
    迁移学习GAN网络:Generative Adversarial Nets
    其计算的过程是先计算k步D,再计算一步G。这样可以使得G缓慢改变,由此D就可以跟上步伐一直保持在最优解附近。
    图解如下,其中,生蓝色为判别器的数据分布,黑色为原始数据的分布,绿色为生成器的分布。下面的水平线是采样z的区域,在这种情况下是一致的。上面的水平线是x域的一部分。向上的箭头表示映射x = G(z)如何将非均匀分布pg施加到变换后的样本上。下面的水平线是采样z的区域,在这种情况下是一致的。上面的水平线是x的定义域的一部分。向上的箭头表示映射x = G(z)如何将非均匀分布pg施加到变换后的样本上。G在高密度区域收缩,在Pg低密度区域扩张。
    迁移学习GAN网络:Generative Adversarial Nets
    上图的步骤为:
    2.1.蓝色的判别器每次去寻找到黑色的原始数据分布和绿色的生成器分布的最大区分界面,将它们分开。
    2.2.绿色的生成器每次去逼近黑色的原始分布,以迷惑蓝色的判别器。
    2.3.随着多次的迭代,绿色的生成器分布最终与黑色的原始数据分布重合,生成的样本再也无法被蓝色的判别器给区分开来。
    2.4.迭代停止,此时的生成器已经具有很高的迷惑性。

3. 算法:
迁移学习GAN网络:Generative Adversarial Nets
解析:
3.1在判别器的每一轮迭代中,生成器采样一个minibatch的噪音和一个minibatch的样本数据,然后通过梯度下降来更新判别器的权重。
3.2重复第一步k次。
3.3在生成器的每一轮迭代中,采样一个minibatch的噪音,然后通过梯度下降更新生成器的权重。
3.4不断重复1、2、3步,直到最大迭代次数或权重的更小已小于停止迭代的阈值。

Pg=Pdata
对于固定的G,最优鉴别器D为
迁移学习GAN网络:Generative Adversarial Nets
鉴别器D的训练准则,对于任意发生器G,是使V(G,D)的数量最大化
迁移学习GAN网络:Generative Adversarial Nets
对于任意(a, b)∈R2 {0,0},函数y→a log(y) +b log(1−y)在[0,1]点a\(a+b)处达到最大值。
得出定理一:当且仅当Pg = Pdata时,得到虚拟训练准则C(G)的全局最小值。此时,C(G)的值为- log 4。
定理二:如果G和D有足够的容量,则在算法1的每一步,鉴别器都可以达到给定G的最优值,并更新Pg以改进判据,然后Pg收敛到Pdata。

迁移学习GAN网络:Generative Adversarial Nets
成果:

在手写体识别MNIST 数据集及TFD比赛中不同模型的测试结果。
迁移学习GAN网络:Generative Adversarial Nets
模型生成样本的可视化:最右边的一列显示了相邻样本的最近的训练实例,以证明模型没有记住训练集。样本是公平随机抽取的,不是精选的。
迁移学习GAN网络:Generative Adversarial Nets
在整个模型的z空间坐标之间进行线性插值得到的数字。
迁移学习GAN网络:Generative Adversarial Nets