深度学习-- > Improved GAN-- > f-GAN

上一篇博文中详细总结和推导了GAN网络的原理,但是如此的GAN网络有他的不足之处,本博文将详细说明其不足之处,以及解决和提高的办法。

original GAN 不足之处

简单回顾GAN网络原理

深度学习-- > Improved GAN-- > f-GAN

蓝色的线表示:Generated distribution
绿色的线表示:Data(target) distribution
红色的线表示:Discriminator

在上图中的左上第一个子图中,generator 生成的数据分布与Data distribution 相差较大,则 Discriminator 也即是D(x)Generated distribution 以较低的概率,而给 Data(target) distribution 以较高的概率,由此得到 D(X) 的曲线走向。在左二子图中,更新后的 generator 可能会因为更新步伐太大,移到了Data distribution 的右边,由此 D(X) 更新如图,GD 如此不断的更新迭代,最终 Generated distributionData(target) distribution 重合,那么此时D(X)就变成了一条水平直线。

存在的问题

我们知道整个GAN网络的目标都是在:
深度学习-- > Improved GAN-- > f-GAN

通过不断的更新DG来得到比较好的 Generator ,也就是上式的G,那么在更新D时:

max V(G,D) = 2log2 + 2JSD(Pdata(X)||PG(X))

我们是不断的通过Minize max V(G,D) 来更新 G,那么问题来了,这个max V(G,D)是否能准确的反映PdataPG之间的差距呢?

深度学习-- > Improved GAN-- > f-GAN

由上图可以看出,当PGPdata无重合时(可能是sample出的样本没有重合),即使两者的 distribution 在改进,其JS(PG||Pdata)始终为log2,那么在更新G参数时,没有改进的动力。很难得到很好的Generator

Unified Framework

f-divergence

之前我们介绍的GAN网络中的Discriminator只是和JensenShannon divergence有关,论文Training Generative Neural Samplers using Variational Divergence Minimization中介绍了fdivergence,其Discriminator不只是仅仅由JensenShannon divergence来定义,其核心的一句话就是you can use any fdivergence

我们假设有两个分布,分别 pq,代入到GAN网络中,就是之前说的 PdataPG,其中 p(x)q(x) 就是sample 出来的样本的概率。由此我们可以这样来定义fdivergence

Df(P||Q)=xq(x) f(p(x)q(x))dx

显然这样定义的Df(P||Q)必须能起到衡量PQ 分布的拟合程度,并且值越小拟合的越好。那么就必须具备以下条件:

  • f函数必须是凸的
  • f(1)=0

那么可以得到,当对于所有的x都有P(x)==Q(x)时:Df(P||Q)=0,这个时候显然拟合的最好,并且是smallest Df(P||Q)

再由凸函数的特性可得:

Df(P||Q)=xq(x) f(p(x)q(x))dxf(xq(x) p(x)q(x)dx)=fxp(x)=f(1)

故可得到 Df(P||Q)f(1)

其实KL divergence就可以理解为一种fdivergence。那么f可以选哪些函数呢?只要符合上面的要求即可:

深度学习-- > Improved GAN-- > f-GAN

Fenchel Conjugate

首先假设f(x)是一个凸函数,定义如下公式:

f(t)=maxxdom(f){xtf(x)}

得到 f(t),这里固定住不同的 x(x1,x2,..),都能得到不同的关于t的线性函数,其图可以如下:

深度学习-- > Improved GAN-- > f-GAN

然后取其max,就得到上图红色的那条线。由此可以得到一个结论:

f(x)f(t)

我们把这样的 f(t) 叫做f(x)conjugate function

举个具体的例子,当 f(x)=xlogx 时,可得f(t)

深度学习-- > Improved GAN-- > f-GAN

那么如何得出当 f(x)=xlogx 时,f(t)的具体数学公式呢?

深度学习-- > Improved GAN-- > f-GAN

那么可得结论,当f(x)=xlogx时,其conjugate function f(t)=exp(t1),也即:

f(x)=xlogxf(t)=exp(t1)

这里需要注意:(f)=f

Connect to GAN

那么上面讲的与GAN有什么关系呢?
假设 f(x)f(1)=0,则由上面的推导我们可以得出:

f(t)=maxxdom(f){xtf(x)}f(x)=maxtdom(f){txf(t)}

f(x)=maxtdom(f){txf(t)} 中,我们可以令 x=p(x)q(x),再由上面已经得出的fdivergence 条件可得:

Df(P||Q)=xq(x) f(p(x)q(x))dx=xq(x)(maxtdom(f){p(x)q(x)tf(t)})dx

这里可以假设存在某一个函数D,其输入为x,输出为t,则有:

深度学习-- > Improved GAN-- > f-GAN

注意:不论函数D为何函数,其符号都为大于等于。

那么我们可以选择到某个函数D,使其上式右边取最大,则可得如下:

Df(P||Q)maxDxP(x)D(x)dxxq(x)f(D(x))dx

Df(P||Q) 表示 fdivergence,上面我们已经说明了 fdivergence 可以用来衡量两种分布的拟合程度。

继续推导可得:

深度学习-- > Improved GAN-- > f-GAN

得出的形式是不是很像上一博文中介绍的V(G,D)函数?

V = ExPdata[logD(x)] + ExPG[log(1D(x))]

继续可得:

深度学习-- > Improved GAN-- > f-GAN

所以我们可以这样理解更新G的过程,实际就是不断的减小fdivergence,而这个时候fdivergence 直接就是用来衡量两种分布的拟合程度。

实际train的不同

深度学习-- > Improved GAN-- > f-GAN

original GAN中,在Inner loop中通过多次循环来更新D,然后再更新G;而在上面介绍的 fGan 中,只需要一步即可更新DG

由此我们可以选中任意一种fdiveragenceminize