生成对抗网络 | 原理及训练过程
笔者将在本章给大家介绍生成对抗网络(Generative Adversarial Network)[4]。
生成对抗网络在AI界书写了一个以假乱真的剧本。近年来AI换脸等技术火爆全球,离不开这个网络的点滴贡献。生成对抗网络能够学习数据的分布规律,并创造出类似我们真实世界的物件如图像、文本等。从以假乱真的程度上看,它甚至可以被誉为深度学习中的艺术家。好了,闲言少叙,我们这就走进生成对抗网络(GAN)的世界。
5.1 生成对抗网络的原理
相信大家都会画画,不管画得好坏与否吧,但总归会对着图案勾上两笔。当我们临摹的次数越多,我们画的也就越像。最后,临摹到了极致,我们的画就和临摹的那副画一模一样了,以至于专家也无法分清到底哪幅画是赝品。
好了,在这个例子中,我们将主人公换成生成对抗网络,画画这个操作换成训练,其实也是这么一回事。总体来说,就是这个网络学习数据分布的规律,然后弄出一个和原先数据分布规律的数据。这个数据可以是语音、文字和图像等等。
生成对抗网络Gan网络结构拥有两个部分,一个是生成器(generator),另一个是辨别器(discriminator)。现在我们拿手写数字图片来举个例子。我们希望Gan能临摹出和手写数字图片一样的图,达到以假乱真的程度。生成对抗网络结构图如图 5.1所示。
图 5.1 生成对抗网络
那么它整体的流程如下:
(1) 首先定义一个生成器(generator),输入一组随机噪声向量(最好符合常见的分布,一般的数据分布都呈现常见分布规律),输出为一个图片。
(2) 定义一个辨别器(discriminator),用它来判断图片是否为训练集中的图片,是为真,否为假。
(3) 当辨别器无法分辨真假,即判别概率为0.5时,停止训练。
其中,生成器和辨别器就是我们要搭建的神经网络模型,可以是CNN、RNN或者全连接神经网络等,只要能完成任务即可。
5.2 生成对抗网络的训练过程
(1) 初始化生成器G和辨别器D两个网络的参数。
(2) 从训练集抽取n个样本,以及生成器利用定义的噪声分布生成n个样本。固定生成器G,训练辨别器D,使其尽可能区分真假。
(3) 循环更新k次辨别器D之后,更新1次生成器G,使辨别器尽可能区分不了真假。
多次更新迭代后,理想状态下,最终辨别器D无法区分图片到底是来自真实的训练样本集合,还是来自生成器G生成的样本即可,此时辨别的概率为0.5,完成训练。
5.2.1 模型样本的可视化
论文的作者尝试用这个框架分别对MNIST[5], the Toronto Face Database (TFD)[6], and CIFAR-10[7]. 训练了4个生成对抗网络模型。如图 5.2~图 5.5所示。样本是公平的随机抽签,并非精心挑选。最右边的列显示了模型预测生成的示例。
图 5.2 MNIST数据集
图 5.3 the Toronto Face Database数据集
图 5.4 CIFAR-10 (全连接层网络)
图 5.5 CIFAR-10 (卷积辨别器和反卷积生成器)
参考文献
[4] Goodfellow I, Pouget-Abadie J, Mirza M, et al. Generative adversarial nets[C]. Advances in neural information processing systems, 2014: 2672-2680.
下一期,我们将讲授
生成对抗网络实验部分
敬请期待~
关注我的微信公众号~不定期更新相关专业知识~
内容 |阿力阿哩哩
编辑 | 阿璃
点个“在看”,作者高产似那啥~