Generative Adversarial Nets

Generative Adversarial Nets

摘要

生成器和判别器在训练过程中,相互对抗,共同进化。

我们同时训练两个模型,生成模型G,用来捕获数据分布,一个判别模型D用来评估输入样本来自于训练数据(真实数据)的概率。对G来说训练的过程是最大化D犯错误的概率。这个框架相当于一个极小化极大的双方博弈游戏。对于任意函数G, D,在空间中存在一个唯一解,即为G恢复训练数据,D处处等于1/2。G,D都由多层感知机定义,整个系统可以由反向传播训练,所以不需要马尔科夫链和近似推理网络。实验证明可以定性和定量生成样本。

研究背景

最成功的模型之一是判别式模型,通常他们将高维丰富的输入映射到类标签上。然而生成模型的发展并不乐观,由于存在最大似然估计和相关策略上难以解决的概率计算上的困难, 由于难以利用分段线性单元在生成上下文中的好处,深度生成模型的影响很小。

  • 生成式模型—>联合概率分布—>生成具有和训练样本分布一样的样本
  • 感知型模型—>条件概率分布—>训练分类样本

在提到的对抗网络框架中,生成模型对抗着一个对手:一个学习去判别一个样本是来自模型分布还是数据分布的判别模型。生成模型可以被认为是一个伪造团队,试图产生假货并在不被发现的情况下使用它,而判别模型类似于警察,试图检测假币。在这个游戏中的竞争驱使两个团队改进他们的方法,直到真假难分为止。

Generative Adversarial Nets

对抗网络

作者用MLP作为基本网络层,搭建生成器和判别器:

  • 生成器G: 为了学习生成器关于数据x上的概率分布pgp_g, 我们定义了一个先验噪声变量pz(z)p_z(z),然后使用G(z;θg)G(z;\theta_g)来代表数据空间上的映射,G是一个带有参数θg\theta_g的多层感知机代表的可微函数。
    KaTeX parse error: No such environment: equation at position 8: \begin{̲e̲q̲u̲a̲t̲i̲o̲n̲}̲ G = arg \min_G…
    G的目标:将任意分布的数据输入和生成数据的分布近似。

  • 判别器D: 多层感知机D(x,θd)D(x, \theta_d)输出一个标量。D(x)代表x来自于真实数据分布而不是生成数据概率。
    When G is given,
    KaTeX parse error: No such environment: equation at position 8: \begin{̲e̲q̲u̲a̲t̲i̲o̲n̲}̲ \max_D V(D, G)…

Generative Adversarial Nets

在实践中,方程(1)可能不能为G提供足够的梯度。当训练开始时,由于*生几乎随机数据,D以高置信度拒绝G生成的数据,log(1D(G(z)))log (1- D(G(z)))饱和,所以我们可以训练G去最大化log(D(G(z)))log (D(G(z)))而不是最小化log(1D(G(z)))log (1- D(G(z)))。该目标函数使G和D的动力学稳定点相同,并且在训练初期,该目标函数可以提供更强大的梯度。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-s3xzGOCO-1593677177162)(https://raw.githubusercontent.com/fly-dragon211/image/master/img20200702095220.png)]

图1. 判别分布(蓝色曲线,D)给出的来自pxp_x(黑色点线, 数据生成分布)的样本,和生成分布pgp_{g} (G,绿色实线)的样本。下方两个水平线是x, z的分布,箭头代表映射x=G(z)x=G(z)。(a)为训练初期,(b)为D收敛到D(x)=pdata(x)pdata(x)+pg(x)D^*({x}) = \frac{ p_\text{data}({x}) }{ p_\text{data}({x}) + p_g({x})} , © G更新一次后,D的梯度会把G(z)到引导更可能生成被分类成数据的区域。 (d) 若干次训练后,G, D都有了足够的性能。

下一部分是理论研究,基本上表明基于训练准则可以恢复数据生成分布,因为G和D被给予足够的容量,即在非参数极限。我们必须使用迭代数值方法,在训练的内部循环中优化D到完成是禁止的,并且在有限的数据集里会导致过拟合。相反,我们可以在优化D的k个步骤和优化D的一个步骤之间交替。只要G变化的足够慢,D会保持在最优解附近。

算法1:生成对抗网络的minibatch随机梯度下降训练。判别器的训练步数,k,是一个超参数。在我们的试验中使用k=1,使消耗最小。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-MIX7mi2f-1593677177188)(https://raw.githubusercontent.com/fly-dragon211/image/master/img/20200702113946.png)]

理论结果

本节的结果是在非参数设置下完成的,例如,我们通过研究概率密度函数空间中的收敛来表示具有无限容量的模型。

我们将在4.1节中显示,这个极小化极大问题的全局最优解为pg=pdatapg=pdata。我们将在4.2节中展示使用算法11来优化等式1,从而获得期望的结果。

4.1 全局最优 pg=pdatap_g = p_{data}

首先考虑固定G的情况下最优化判别器D

假设1:G固定时,最优判别器为

$$
\begin{equation}

D^*G({x}) = \frac{p\text{data}({x})}{p_\text{data}({x}) + p_g({x})}
\end{equation}
$$
这个假设通过V(G,D)V(G, D)公式证明

定理1:当且仅当pg=pdatap_g=p_{data}时,C(G)达到全局最小。此时,C(G)的值为−log4

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-uSOcTQMt-1593742989265)(https://raw.githubusercontent.com/fly-dragon211/image/master/img/20200702122153.png)]]

4.2 算法1的收敛性

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-4xW4L5ny-1593743235284)(https://raw.githubusercontent.com/caojx-git/learn/master/notes/images/mysql/mysql-sql-1.png)]

Generative Adversarial Nets
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-QDvxLjVV-1593743373782)(https://raw.githubusercontent.com/caojx-git/learn/master/notes/images/mysql/mysql-sql-1.png)]

Generative Adversarial Nets