GAN网络入门(一)

**

从一个形象的例子讲起

**
GAN网络入门(一)
在上图中,坏人能把普通的一张白纸制成假钞,警察可以分辨出假钞和真钞。在一轮一轮的较量中,坏人制假钞的水平不断提高,而警察面对越来越难以辨认的假钞也不得不提高自己的辨别能力。
套用GAN网络相关术语,我们可以这样来讲这个故事:生成模型(Generative Model,坏人)可以将一个输入噪音(白纸)生成和真实数据差不多的数据(假钞),判别模型(Discriminative Model,警察)能够判断出真实数据(真钱)和类真实数据(假钞)。在一轮又一轮的博弈中,生成模型(Generative Model,坏人)能够输出非常接近真实数据的数据。
GAN网络的目标是使得生成的数据和真实数据更接近。为了达到这个目标,一方面,我们要求G(x)(生成模型网络)能够学习到一组很好的模型参数,使得D(x)(判别模型网络)判别不出来真实数据和类真实数据的区别,另一方面,我们要求D(x)(判别模型网络)的判别能力很强,能够完成对数据的真实性做出很好的二分类任务。

数学解释

GAN网络入门(一)
如上图所示,x代表真实数据,z代表噪音,G(z)代表一个输入噪音通过生成网络后的输出。一方面,我们希望判别网络能够准确判断出数据的真实性,即D(x)尽可能接近1,D(G(z))尽可能接近于0;另一方面,我们希望生成网络产生的数据非常接近真实数据,即D(G(z))尽可能接近于1。
损失函数
GAN网络入门(一)
可以这样理解:损失函数做的是最大化D的区分度,最小化G输出和真实数据的区别。
损失函数可以拆分为两部分:
判别模型:log(D1(x))+log(1-D2(G(z)))…(1)
生成模型:log(D2(G(z)))
…(2)
当判别模型能力强时,D1(x)->1, D2(G(z))->0,(1)式趋近于0
当生成模型能力强时,D2(G(z))->1,(2)式趋近于0

一个简单的GAN案例
GAN网络入门(一)
实现功能:上图中,蓝色线代表真实数据,绿色线代表生成网络输出数据,最终我们希望绿色线能和蓝色线能够比较接近
详细代码
一 定义模型

def main(args):
    model = GAN(
        DataDistribution(),                            //真实数据真实分布
        GeneratorDistribution(range=8),      //生成数据分布
        args.num_steps,                               //迭代次数,取1200
        args.batch_size,                               //一次迭代更新12个点的数据
        args.log_every,                                 //每隔多少次迭代打印一次loss值
    )
    model.train()

二 参数初始化

def __init__(self, data, gen, num_steps, batch_size, log_every):
        self.data = data
        self.gen = gen
        self.num_steps = num_steps
        self.batch_size = batch_size
        self.log_every = log_every
        self.mlp_hidden_size = 4                    //神经网络中隐藏层个数设置为4
        self.learning_rate = 0.03                           //学习率
        self._create_model()
 

三 D_pre网络
判别网络D的参数不能随机初始化,必须具有一定的判别能力,所以预先使用D_pre网络对D的参数进行训练。

    with tf.variable_scope('D_pre'):
        self.pre_input = tf.placeholder(tf.float32, shape=(self.batch_size, 1))
        self.pre_labels = tf.placeholder(tf.float32, shape=(self.batch_size, 1))
        D_pre = discriminator(self.pre_input, self.mlp_hidden_size)
        self.pre_loss = tf.reduce_mean(tf.square(D_pre - self.pre_labels))
        self.pre_opt = optimizer(self.pre_loss, None, self.learning_rate)

四 定义G,D网络
G网络

        with tf.variable_scope('Gen'):
            self.z = tf.placeholder(tf.float32, shape=(self.batch_size, 1))    //输入,为一个随机输入
            self.G = generator(self.z, self.mlp_hidden_size)            //产生输出数据

**D网络** 

D网络有两个输入,一个是真实数据x,另一个是生成网络的输出数据G(z),所以定义两个:D1和D2。

        with tf.variable_scope('Disc') as scope:
            self.x = tf.placeholder(tf.float32, shape=(self.batch_size, 1))
            self.D1 = discriminator(self.x, self.mlp_hidden_size)
            scope.reuse_variables()
            self.D2 = discriminator(self.G, self.mlp_hidden_size)

五 损失函数的定义(非常重要)

self.loss_d = tf.reduce_mean(-tf.log(self.D1) - tf.log(1 - self.D2))    //见(1)式,对(1)式取反
self.loss_g = tf.reduce_mean(-tf.log(self.D2))                               //见(2)式
self.opt_d = optimizer(self.loss_d, self.d_params, self.learning_rate)    //使用优化器对两者的损失函数进行优化
self.opt_g = optimizer(self.loss_g, self.g_params, self.learning_rate)

六 训练模型

 def train(self):
        with tf.Session() as session:
            tf.global_variables_initializer().run()

            # pretraining discriminator
            num_pretrain_steps = 1000
            for step in range(num_pretrain_steps):
                d = (np.random.random(self.batch_size) - 0.5) * 10.0
                labels = norm.pdf(d, loc=self.data.mu, scale=self.data.sigma)
                pretrain_loss, _ = session.run([self.pre_loss, self.pre_opt], {
                    self.pre_input: np.reshape(d, (self.batch_size, 1)),
                    self.pre_labels: np.reshape(labels, (self.batch_size, 1))
                })
            self.weightsD = session.run(self.d_pre_params)     //将d_pre网络的训练结果赋值
            # copy weights from pre-training over to new D network
            for i, v in enumerate(self.d_params):         //将d_pre网络的训练结果拷贝给self.d_params
                session.run(v.assign(self.weightsD[i]))

            for step in range(self.num_steps):
                # update discriminator
                x = self.data.sample(self.batch_size)                 //定义x 真实数据
                z = self.gen.sample(self.batch_size)                  //定义z 噪音输入
                loss_d, _ = session.run([self.loss_d, self.opt_d], {
                    self.x: np.reshape(x, (self.batch_size, 1)),
                    self.z: np.reshape(z, (self.batch_size, 1))
                })

                # update generator
                z = self.gen.sample(self.batch_size)
                loss_g, _ = session.run([self.loss_g, self.opt_g], {
                    self.z: np.reshape(z, (self.batch_size, 1))
                })

                if step % self.log_every == 0:
                    print('{}: {}\t{}'.format(step, loss_d, loss_g))                
                if step % 100 == 0 or step==0 or step == self.num_steps -1 :
                    self._plot_distributions(session)

训练结果

经过1200次迭代,可以看出结果是非常好的。
GAN网络入门(一)