手写数字识别(mnist)学习记录,tensorflow
主要记录在读入数据所产生的一些问题。
在tensorflow.keras.datasets中,提供mnist数据集的自动下载。下载之后,返回的是两个元组,分别是训练集和测试集,测试集60000张图片和label,训练集是10000张图片和label。图片大小分别是28*28。在fit时候,可以是numpy类型,也可以是dataset类型。对于两种不同的数据类型,对于网络模型的input_shape,也是有需要注意的地方。下面进行记录
- 用dataset有很多好处,可以用tensorflow中很多自带的对数据的处理函数,例如dataset.map(precoess)
- 用numpy格式读入进行训练,可以得到是成功的,每个batch是1.
- 将tuple转成dataset对象,用以上网络进行训练,发现不可行,对输入格式错误
- 因此,我们将其增加维度,使其变成28*28*1,发现在输入时候,格式依旧不对。因此,应该是增加前面的维度。
- 因此,我们将格式变成 none*28*28,训练得以成功,或者,我们可以将上面的输入改成28*28*1
- 经上面,如果dataset类型为输入集,input_shape如果为28*28,那么dataset须为 None*28*28,如果inpu_shape为28*28*1,那么dataset须为 None*28*28*1