Hulu机器学习问题与解答系列 | 二十九:WGANs:抓住低维的幽灵

Hulu机器学习问题与解答系列 | 二十九:WGANs:抓住低维的幽灵

认识了生成式对抗网络(见本系列第25题)后,来看看它的变种WGAN吧~

今天的内容是

【WGANs:抓住低维的幽灵】

场景描述

看过《三体3》的朋友,一定听说过“降维打击”这个概念,像拍苍蝇一样把敌人拍扁。其实,低维不见得一点好处都没有。想象猫和老鼠这部动画的一个镜头,老鼠Jerry被它的劲敌Tom猫一路追赶,突然Jerry发现墙上挂了很多照片,其中一张的背景是海边浴场,沙滩上有密密麻麻的很多人,Jerry一下子跳了进去,混在人群中消失了,Tom怎么也找不到Jerry。 三维的Jerry变成了一个二维的Jerry,躲过了Tom。一个新的问题是:Jerry对于原三维世界来说是否还存在? 极限情况下,如果这张照片足够薄,没有厚度,那么它就在一个二维平面里,不占任何体积,体积为零的东西不就等于没有吗!拓展到高维空间中,这个体积叫测度,无论N维空间的N有多大,在N+1维空间中测度就是零,就像二维平面在三维空间中一样。因此,一个低维空间的物体,在高维空间中忽略不计。对生活在高维世界的人来说,低维空间是那么无足轻重,像一层纱,如一个幽灵,似有似无,是一个隐去的世界。

Hulu机器学习问题与解答系列 | 二十九:WGANs:抓住低维的幽灵

2017年,一个训练生成对抗网络的新方法——WGAN被提出。在此之前,GANs已提出三年,吸引了很多研究者来使用它。原理上,大家都觉得GANs的思路实在太巧妙,理解起来一点都不复杂,很符合人们的直觉,万物不都是在相互制约和对抗中逐渐演化升级吗。理论上,Goodfellow在2014年提出GANs时,已经给出GANs的最优性证明,证明GANs本质上是在最小化生成分布与真实数据分布的Jensen-Shannon Divergence,当算法收敛时生成器刻画的分布就是真实数据的分布。但是,实际使用中发现很多解释不清的问题,生成器的训练会很不稳定。生成器这只Tom猫,很难抓住真实数据分布这只老鼠Jerry。

问题描述

请思考:原GANs中存在哪些问题,会成为制约模型训练效果的瓶颈;WGAN针对这些问题做了哪些改进; WGAN算法的具体步骤;并写出WGAN的伪代码。

知识点:JS距离、坍缩模式、Wasserstein距离、

1-Lipschitz函数

解答与分析

1. GANs的陷阱:请回答原GANs中存在哪些问题,成为了制约模型训练效果的瓶颈。

难度:★★★

GANs的判别器试图区分真实样本和生成的模拟样本。Goodfellow在论文中指出,训练判别器,实际是在度量生成器分布和真实数据分布的Jensen-Shannon Divergence,也称JS距离; 训练生成器,是在减小这个JS距离。这是我们想要的,即使我们不清楚形成真实数据的背后机制,还是可以用一个模拟生成过程去替代之,只要它们的数据分布一致。

但是实验中发现,训练好生成器是一件很困难的事,生成器很不稳定,常出现坍缩模式(Collapse Mode)。什么是坍缩模式?拿图片举例,反复生成一些相近或相同的图片,多样化太差。生成器似乎将图片记下,没有更高级的泛化,更没有造新图的能力,好比一个笨小孩被填鸭灌输了知识,只会死记硬背,没有真正理解,不会活学活用,更无创新能力。

为什么会这样?既然训练生成器基于JS距离,猜测问题根源可能与JS距离有关。高维空间中不是每点都能表达一个样本(如一张图片),空间大部分是多余的,真实数据蜷缩在低维子空间的流形(即高维曲面)上,因为维度低,所占空间体积几乎为零,就像一张极其薄的纸飘在三维空间,不仔细看很难发现。考虑生成器分布与真实数据分布的JS距离,即两个Kullback-Leibler (KL)距离的平均

Hulu机器学习问题与解答系列 | 二十九:WGANs:抓住低维的幽灵

初始的生成器,由于参数随机初始化,与其说是一个样本生成器,不如说是高维空间点的生成器,点广泛分布在高维空间中。打个比方,生成器将一张大网布满整个空间,“兵力”有限,网布得越大,每个点附近的兵力就越少。想象一下,当这张网穿过低维子空间时,所剩的“兵”几乎为零,成了一个“盲区”,如果真实数据全都分布在这,就对生成器“隐身”了,成了“漏网之鱼”。

Hulu机器学习问题与解答系列 | 二十九:WGANs:抓住低维的幽灵

回到公式,看第一个KL距离:

Hulu机器学习问题与解答系列 | 二十九:WGANs:抓住低维的幽灵

高维空间绝大部分见不到真实数据,处处为零,对KL距离的贡献为零;即使在真实数据蜷缩的低维空间,高维空间会忽略低维空间的体积,概率上讲测度为零。KL距离就成了:∫ ㏒2·pr(x)dμ(x)=㏒2

再看第二个KL距离:

Hulu机器学习问题与解答系列 | 二十九:WGANs:抓住低维的幽灵

同理KL距离也为:∫ ㏒2·pg(x)dμ(x)=㏒2。因此,JS距离为㏒2,一个常量。无论生成器怎么“布网”,怎么训练,JS距离不变,对生成器的梯度为零。训练神经网络是基于梯度下降的,用梯度一次次更新模型参数,如果梯度总是零,训练还怎么进行。

2. **陷阱的武器:请回答WGAN针对前面问题做了哪些改进,以及什么是Wasserstein距离。

难度:★★★★

直觉告诉我们:不要让生成器傻傻地在高维空间布网,让它直接到低维空间“抓”真实数据。道理是这样,但是高维空间中藏着无数的低维子空间,怎么找到目标子空间呢?站在大厦顶层,环眺四周,你可以迅速定位远处的山峦和高塔,却很难知晓一个个楼宇间办公室里的事情。你需要线索,而不是简单撒网。处在高维空间,对抗隐秘的低维空间,不能再用粗暴简陋的方法,需要有特殊武器,这就是Wasserstein距离,也称推土机距离(Earth Mover distance):

Hulu机器学习问题与解答系列 | 二十九:WGANs:抓住低维的幽灵

Hulu机器学习问题与解答系列 | 二十九:WGANs:抓住低维的幽灵

怎么理解这个公式?想象你有一个很大的院子,院子里几处坑坑洼洼需要填平,四个墙角都有一堆沙子, 沙子总量正好填平所有坑。搬运沙子很费力,你想知道有没有一种方案,使得花的力气最少。直觉上,每个坑都选择最近的沙堆,搬运的距离最短,但是这里面有个问题,如果最近的沙堆用完了怎么办,或者填完坑后近处还剩好多沙子,或者坑到几个沙堆的距离一样,我们需要设计一个系统的方案,通盘考虑这些问题。最佳方案是上面目标函数的最优解。可以看到,沙子分布给定,坑分布给定,我们关心搬运沙子的整体损耗,而不关心每粒沙子的具体摆放,在损耗不变的情况下,沙子摆放可能有很多选择。对应上面的公式,当你选择一对(x, y)时,表示把x处的一些沙子搬到y处的坑,可能搬部分沙子,也可能搬全部沙子,可能只把坑填一部分,也可能都填满了。

Hulu机器学习问题与解答系列 | 二十九:WGANs:抓住低维的幽灵

为什么Wasserstein距离能克服JS距离解决不了的问题?理论上的解释很复杂,要证明当生成器分布随参数θ变化而连续变化时,生成器分布与真实分布的Wasserstein距离,也随θ变化而连续变化,并且几乎处处可导,而JS距离不保证随θ变化而连续变化。

通俗的解释,接着“布网”的比喻,现在生成器不再“布网”,改成“定位追踪”了,不管真实分布藏在哪个低维子空间里,生成器都能感知它在哪,因为生成器只要将自身分布稍作变化,就会改变它到真实分布的推土机距离,而JS距离是不敏感的,无论生成器怎么变化,JS距离都是一个常数。因此,使用推土机距离,能有效锁定低维子空间中的真实数据分布。

3. WGAN之道:请回答怎样具体应用Wasserstein距离实现WGAN算法。

难度:★★★★★

一群大小老鼠开会,得出结论:如果在猫脖上系一铃铛,每次它靠近时都能被及时发现,那多好!唯一的问题是:谁来系这个铃铛?现在,我们知道了推土机距离这款武器,那么怎么计算这个距离?推土机距离的公式太难求解。幸运的是,它有一个孪生兄弟,和它有相同的值,这就是Wasserstein距离的对偶式:

Hulu机器学习问题与解答系列 | 二十九:WGANs:抓住低维的幽灵

细心的你会发现,这里的fD不同,前者要满足||f||L≤1,即1-Lipschitz函数,后者是一个Sigmoid函数。要求在寻找最优函数时,一定要考虑个“界”,如果没有限制,函数值会无限大或无限小。Sigmoid函数的值有天然的界,而1-Lipschitz不是限制函数值的界,而是限制函数导数的界,使得函数在每点上的变化率不能无限大。神经网络里如何体现1-Lipschitz或K-Lipschitz呢?WGAN的作者思路很巧妙,在一个前向神经网络里,输入经过多次线性变换和非线性**函数得到输出,输出对输入的梯度,绝大部分都是由线性操作所乘的权重矩阵贡献的,因此约束每个权重矩阵的大小,可以约束网络输出对输入的梯度大小。

Hulu机器学习问题与解答系列 | 二十九:WGANs:抓住低维的幽灵

判别器在这里换了一个名字,叫评分器(Critic),目标函数由区分样本来源,变成为样本打分,越像真实样本分数越高,否则越低,有点类似SVM里margin的概念。打个龟兔赛跑的比方,评分器是兔子,生成器是乌龟,评分器的目标是甩掉乌龟,让二者的距离(或margin)越来越大,生成器的目标是追上兔子。严肃一点讲,训练评分器就是计算生成器分布与真实分布的Wasserstein距离;给定评分器,训练生成器就是在缩小这个距离。因此,算法中要计算Wasserstein距离对生成器参数θ的梯度,

Hulu机器学习问题与解答系列 | 二十九:WGANs:抓住低维的幽灵

再通过梯度下降法更新参数,让Wasserstein距离变小。

扩展阅读:

1. Martin Arjovsky, Soumith Chintala, Léon Bottou, Wasserstein GAN, 2017

2. Martin Arjovsky, Léon Bottou, Towards Principled Methods for Training Generative Adversarial Networks, 2017


下一题

【常见的采样方法】

场景描述

对于一个随机变量,我们通常用概率密度函数来刻画该变量的概率分布特性。具体来说,给定随机变量的一个取值,我们可以根据概率密度函数来计算该值对应的概率(密度);反过来,也可以根据概率密度函数提供的概率分布信息来生成随机变量的一个取值,这就是采样。因此,从某种意义上来说,采样是概率密度函数的逆向应用。与根据概率密度函数计算样本点对应的概率值不同,采样过程(即根据概率分布来选取样本点)往往没有那么直接,通常需要依据待采样分布的具体特点来选择合适的采样策略 。 在“采样”章节的前两个小问题中,我们展示了采样的一个具体应用(不均衡样本集的处理),以及针对特定分布(高斯分布)而特别设计的采样方法;接下来,我们来关注一些通用的采样方法和采样策略。

问题描述

抛开那些针对特定分布而精心设计的采样方法外,说一些你所知道的通用采样方法或采样策略,简单描述它们的主要思想以及具体操作步骤。


欢迎留言提问或探讨

关注“Hulu”微信公众号

点击菜单栏“机器学习”获得更多系列文章

Hulu机器学习问题与解答系列 | 二十九:WGANs:抓住低维的幽灵