不使用ai库的gan网络实现手写mnist数字生成

一、 目的
利用gan神经网络,实现了mnist数据的生成任务。二、 网络结构设计:
我所作的gan网络结构如下:
生成器:
不使用ai库的gan网络实现手写mnist数字生成
判别器:
不使用ai库的gan网络实现手写mnist数字生成

三、 原理介绍
生成对抗网络(以下简称GAN)是通过让两个神经网络相互博弈的方式进行学习,可以根据原有的数据集生成以假乱真的新的数据,举个不是很恰当的例子,类似于造假鞋,莆田艺术家通过观察真鞋,模仿真鞋的特点造出假鞋并卖给消费者,消费者收到鞋子后将它与网上的真鞋信息进行对比找瑕疵,并给出反馈,比如标不正,气垫弹性不好,莆田艺术家根据消费者给出的反馈积极地改进工艺,经过不懈努力后最终造出了可以忽悠消费者的假鞋。在上述情景中,莆田艺术家相当于生成器,消费者相当于辨别器,在造假的过程中,生成器和判别器一直处于对抗状态。我们把上述情景抽象为神经网络。首先,通过对生成器输入一个分布的数据,生成器通过神经网络模仿生成出一个输出(假鞋),将假鞋与真鞋的信息共同输入到判别器中。然后,判别器通过神经网络学着分辨两者的差异,做一个分类判断出这双鞋是真鞋还是假鞋。这样,生成器不断训练为了以假乱真,判别器不断训练为了区分二者。最终,生成器真能完全模拟出与真实的数据一模一样的输出,判别器已经无力判断。 算法流程如下
不使用ai库的gan网络实现手写mnist数字生成
以上是gan论文中的算法流程,说的直白一些就是,在每个epoch中先用生成器生成一组假数据,与真数据一块放入判别器中进行识别,训练好判别器后保持参数不变,将生成器与判别器相连,去训练生成器的参数,如此反复。
四、结果与分析
看一下迭代10次的效果:
不使用ai库的gan网络实现手写mnist数字生成
不使用ai库的gan网络实现手写mnist数字生成

感觉整体都很像8,可能所有的数字和8都有重合的原因,导致训练的结果总是往8上靠拢。下面介绍一下我的代码框架: 如果只是想直接运行的话只要先运行xunlian.py在里面我会有注释,最后会生成10张图片保存在jieguo.npy文件中(当前的是我训练过的,如果要训练得换一个名字保存),查看图片运行xianshi.py里面用的是plt图片显示指令,会展示预测的图片。
五、总结
这次实验没有使用任何ai的工具包,只用了numpy和一些显示图片的辅助库,卷积的关于mnist识别率在95%左右,可能相比与tensorflow和pytorch还是略显得差一些,所以整体的实验结果也没有那么的完美,但能把数字的一个模糊轮廓实现还是说明实验是成功的,经过这次实验收获还是挺多的,模型的构建上确实下了一番功夫,比如一开始忘记交替训练了,生成了一堆二维码,然后痛定思痛,又看了一些博客和论文才发现犯了没有交替训练的低级错误,纠正以后的实验结果还是符合预期的。
备注:如果想要源码和数据,可以通过链接//download.csdn.net/download/m0_37922163/12107799获取。
如果对时间没有要求的,可以将邮箱联系方式写在评论区,我可以邮箱发送。