30天吃掉那只tensorflow之(2):使用 cifar10 数据集来训练网络并测试

写在前头

本文与 30天 吃掉那只 tensorflow 的原文有较大的出入,只是借鉴他使用的数据集并且整个构建网络的过程;数据处理部分和训练网络部分都是自己设计的,测试部分我也使用了自己上网找到图片并进行了处理;仅供大家参考

1. Cifar10数据集的介绍、获取

30天吃掉那只tensorflow之(2):使用 cifar10 数据集来训练网络并测试

cifar10 数据集使用了 10 种分类的训练数据;标签从 0-9
获取方式: 如下:
使用强大的 jupyterlab 工具来查看一下 load_data 的用法;这个函数会自动帮你下载数据集(这种通过命令行直接从 tensorflow 官网中下载大约需要2个小时,可以通过别的方式获取,请自行查阅相关文章)
30天吃掉那只tensorflow之(2):使用 cifar10 数据集来训练网络并测试

2. 训练集数据可视化

load_data 函数最终返回 2 个 tuple;所以我们用以下代码来接受返回值
30天吃掉那只tensorflow之(2):使用 cifar10 数据集来训练网络并测试

可视化一下前三个训练数据

  • 我们可以看到训练数据的前三个的标签为 6,9,9
  • 也可以看到我们可视化出来的结果前三张图分别是:青蛙、卡车、卡车
    30天吃掉那只tensorflow之(2):使用 cifar10 数据集来训练网络并测试

3. 简单数据处理:将标签进行 one-hot 编码转换

对于多分类问题,我们一般会把标签转换成 one-hot 编码的形式,为了以后更容易计算;所以我们在这里用到了 to_categorical 函数来转换
30天吃掉那只tensorflow之(2):使用 cifar10 数据集来训练网络并测试
还记得么 y_train 的前三个标签是 6,9,9,现在变成了对应位置为 1 其他位置为0 的 one-hot 编码,这是为了在后面的计算中对应的使用 softmax 函数计算出概率分布,并通过相应位置的概率分布计算损失;

4. 构建网络模型

30天吃掉那只tensorflow之(2):使用 cifar10 数据集来训练网络并测试

可以用 model.summary() 来查看你已经建立的网络结构还有需要训练的参数
30天吃掉那只tensorflow之(2):使用 cifar10 数据集来训练网络并测试

5. 模型训练

30天吃掉那只tensorflow之(2):使用 cifar10 数据集来训练网络并测试

6. 训练数据可视化

先来看看 history 里有哪些可以可视化的数据:
发现有 ['loss', 'acc', 'val_loss', 'val_acc'] 我们使用索引把他们的值分别拿出来并展示以下,他们每一组数据都存在一个列表里,我们就用这些数据可视化每一个 epoch 的训练过程
30天吃掉那只tensorflow之(2):使用 cifar10 数据集来训练网络并测试
把 loss 和 val_loss 呈现在一张图中
30天吃掉那只tensorflow之(2):使用 cifar10 数据集来训练网络并测试
把 acc 和 val_acc 也放在同一个图中进行对比
30天吃掉那只tensorflow之(2):使用 cifar10 数据集来训练网络并测试
两个图像表明,数据被训练的很好,也没有存在过拟合现象。

7. 数据评估

数据评估的准确率甚至略高于训练集,这也是很好的结果
30天吃掉那只tensorflow之(2):使用 cifar10 数据集来训练网络并测试
下面,我从网上随便照一张图片,用我们训练好的模型来检测一下训练成果。下面是我找的网图;下面的代码中我将演示如何处理这张网图,然后用模型来进行预测。
30天吃掉那只tensorflow之(2):使用 cifar10 数据集来训练网络并测试

7.1 数据处理

  • 可以看到,刚读出来的图片是 3 个通道的彩图;我们上面训练的也使用的 3 通道彩图;
  • 所以我们要对这个图片进行 resize;但是 resize 操作不能直接对 3 通道的图片做;所以:
  • 我们按照 opencv 读图片的通道顺序 b, g, r (注意不是 rgb) 使用 cv2.split() 函数对数据解包;得到了每个通道之后我们分别做 resize 操作,最后再用 cv2.merge() 将三个通道叠加起来;这样我们就可以得到我们想要的结果了

30天吃掉那只tensorflow之(2):使用 cifar10 数据集来训练网络并测试
30天吃掉那只tensorflow之(2):使用 cifar10 数据集来训练网络并测试

7.2 将数据送到模型中测试

但是直接这样送入模型会报错;因为你用测试集进行测试的时候你的数据是 4 维的,(10000, 32, 32, 3) 这个 10000 代表的是10000张测试图片;所以我们要进行测试,要把这个测试的数据升高一个维度,或者把 n 张图片绑到一个数组里面送去测试,即把矩阵变成 (n, 32, 32, 3) 这种模式;
30天吃掉那只tensorflow之(2):使用 cifar10 数据集来训练网络并测试
30天吃掉那只tensorflow之(2):使用 cifar10 数据集来训练网络并测试
为了简单起见,我直接用两张它自己作为测试集;
注意!!! 把这两张图片放在一个列表中之后,千万不要忘记将这个列表用 numpy 转换成一个矩阵;因为模型的输入只能是矩阵而不能是列表
30天吃掉那只tensorflow之(2):使用 cifar10 数据集来训练网络并测试
按照概率分布中的结果,我们筛选出最大的值来看一看是不是我们想要的标签
30天吃掉那只tensorflow之(2):使用 cifar10 数据集来训练网络并测试
使用 .argmax() 函数返回数组中最大的值的索引;返回的位置是 1;我们回去看 cifar10 数据集中的 1 代表的是 “汽车”
30天吃掉那只tensorflow之(2):使用 cifar10 数据集来训练网络并测试
所以可以看出来,模型的训练效果还是不错的。模型的保存部分,大家可以翻看我的上篇文章。

写在后面

如有错误,敬请指正;欢迎交流
个人的微信号:

30天吃掉那只tensorflow之(2):使用 cifar10 数据集来训练网络并测试