莫烦pytorch(14)——GAN网络

1.画出大师作品(构造目标)

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(1)    # 设置种子
np.random.seed(1)

BATCH_SIZE = 64
LR_G = 0.0001           # (生成网络)
LR_D = 0.0001           # (判别网络)
N_IDEAS = 5             # 认为生成网络有五个灵感构成
ART_COMPONENTS = 15     # 15个部分

PAINT_POINTS=np.vstack([np.linspace(-1,1,ART_COMPONENTS) for _ in range(BATCH_SIZE)])  #水平拼接
print(PAINT_POINTS.shape)      #(64,15)
plt.plot(PAINT_POINTS[0],2*np.power(PAINT_POINTS[0],2)+1,c='#74BCFF', lw=3, label='upper bound')
plt.plot(PAINT_POINTS[0],2*np.power(PAINT_POINTS[0],2)+0,c='#FF9359', lw=3, label='lower bound')
plt.legend(loc="upper right")
plt.show()

莫烦pytorch(14)——GAN网络
PAINT_POINTS=np.vstack([np.linspace(-1,1,ART_COMPONENTS) for _ in range(BATCH_SIZE)])。其中np.vstack(a,b)是水平拼接。并且用了列表生成式,详情请看廖雪峰添加链接描述

2.定义大师作品的函数

def artist_workers():
    a=np.random.uniform(1,2,size=BATCH_SIZE)[:,np.newaxis]
    paintings=a*np.power(PAINT_POINTS,2)+(a-1)
    paintings=torch.from_numpy(paintings).float()
    return paintings

3.生成对抗网络的构建

G=nn.Sequential(
    nn.Linear(N_IDEAS, 128),  # random ideas (could from normal distribution)
    nn.ReLU(),
    nn.Linear(128, ART_COMPONENTS),  # making a painting from these random ideas
)

D=nn.Sequential(
    nn.Linear(ART_COMPONENTS, 128),  # receive art work either from the famous artist or a newbie like G
    nn.ReLU(),
    nn.Linear(128, 1),
    nn.Sigmoid(),  # tell the probability that the art work is made by artist
)

opt_D=torch.optim.Adam(D.parameters(),lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)

plt.ion()

4.交叉训练

for step in range(10000):
    artist_paintings=artist_workers()
    G_ideas=torch.rand(BATCH_SIZE,N_IDEAS)
    G_paintings=G(G_ideas)
    prob_artist0=D(artist_paintings)
    prob_artist1=D(G_paintings)
    D_loss=-torch.mean(torch.log(prob_artist0)+torch.log(1-prob_artist1))
    G_loss=torch.mean(torch.log(1-prob_artist1))
    opt_D.zero_grad()
    D_loss.backward(retain_graph=True)
    opt_D.step()
    opt_G.zero_grad()
    G_loss.backward()
    opt_G.step()

    if step % 50 == 0:  # plotting
        plt.cla()
        plt.plot(PAINT_POINTS[0],G_paintings.data.numpy()[0], c='#4AD631', lw=3, label='Generated painting',)
        plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='upper bound')
        plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, c='#FF9359', lw=3, label='lower bound')
        plt.text(-.5, 2.3, 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(),
                 fontdict={'size': 13})
        plt.text(-.5, 2, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={'size': 13})
        plt.ylim((0, 3));
        plt.legend(loc='upper right', fontsize=10);
        plt.draw();
        plt.pause(0.01)

    plt.ioff()
    plt.show()

莫烦pytorch(14)——GAN网络
下面简述一下交叉训练的过程:先调用函数生成大师的作品,在调用G生成伪造的画,分别对这两个产生的画进行判别,然后用两个计算loss的公式进行各自求解loss,其中D_loss=-(log(D(x)) + log(1-D(G(z))),因为我们希望仿造的画更小,真实的画更大,但是深度学习中只能计算最小值,所以加一个负号。