莫烦pytorch(14)——GAN网络
1.画出大师作品(构造目标)
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
torch.manual_seed(1) # 设置种子
np.random.seed(1)
BATCH_SIZE = 64
LR_G = 0.0001 # (生成网络)
LR_D = 0.0001 # (判别网络)
N_IDEAS = 5 # 认为生成网络有五个灵感构成
ART_COMPONENTS = 15 # 15个部分
PAINT_POINTS=np.vstack([np.linspace(-1,1,ART_COMPONENTS) for _ in range(BATCH_SIZE)]) #水平拼接
print(PAINT_POINTS.shape) #(64,15)
plt.plot(PAINT_POINTS[0],2*np.power(PAINT_POINTS[0],2)+1,c='#74BCFF', lw=3, label='upper bound')
plt.plot(PAINT_POINTS[0],2*np.power(PAINT_POINTS[0],2)+0,c='#FF9359', lw=3, label='lower bound')
plt.legend(loc="upper right")
plt.show()
PAINT_POINTS=np.vstack([np.linspace(-1,1,ART_COMPONENTS) for _ in range(BATCH_SIZE)])
。其中np.vstack(a,b)是水平拼接。并且用了列表生成式,详情请看廖雪峰添加链接描述
2.定义大师作品的函数
def artist_workers():
a=np.random.uniform(1,2,size=BATCH_SIZE)[:,np.newaxis]
paintings=a*np.power(PAINT_POINTS,2)+(a-1)
paintings=torch.from_numpy(paintings).float()
return paintings
3.生成对抗网络的构建
G=nn.Sequential(
nn.Linear(N_IDEAS, 128), # random ideas (could from normal distribution)
nn.ReLU(),
nn.Linear(128, ART_COMPONENTS), # making a painting from these random ideas
)
D=nn.Sequential(
nn.Linear(ART_COMPONENTS, 128), # receive art work either from the famous artist or a newbie like G
nn.ReLU(),
nn.Linear(128, 1),
nn.Sigmoid(), # tell the probability that the art work is made by artist
)
opt_D=torch.optim.Adam(D.parameters(),lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)
plt.ion()
4.交叉训练
for step in range(10000):
artist_paintings=artist_workers()
G_ideas=torch.rand(BATCH_SIZE,N_IDEAS)
G_paintings=G(G_ideas)
prob_artist0=D(artist_paintings)
prob_artist1=D(G_paintings)
D_loss=-torch.mean(torch.log(prob_artist0)+torch.log(1-prob_artist1))
G_loss=torch.mean(torch.log(1-prob_artist1))
opt_D.zero_grad()
D_loss.backward(retain_graph=True)
opt_D.step()
opt_G.zero_grad()
G_loss.backward()
opt_G.step()
if step % 50 == 0: # plotting
plt.cla()
plt.plot(PAINT_POINTS[0],G_paintings.data.numpy()[0], c='#4AD631', lw=3, label='Generated painting',)
plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='upper bound')
plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, c='#FF9359', lw=3, label='lower bound')
plt.text(-.5, 2.3, 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(),
fontdict={'size': 13})
plt.text(-.5, 2, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={'size': 13})
plt.ylim((0, 3));
plt.legend(loc='upper right', fontsize=10);
plt.draw();
plt.pause(0.01)
plt.ioff()
plt.show()
下面简述一下交叉训练的过程:先调用函数生成大师的作品,在调用G生成伪造的画,分别对这两个产生的画进行判别,然后用两个计算loss的公式进行各自求解loss,其中D_loss=-(log(D(x)) + log(1-D(G(z)))
,因为我们希望仿造的画更小,真实的画更大,但是深度学习中只能计算最小值,所以加一个负号。