tensorflow 在cifar10上训练resnet50
训练环境:windows10+python3.6.5+cuda-v9.0+tensorflow_gpu-1.10.0
目录
1、CIFAR-10介绍
CIFAR-10和CIFAR-100是8000万个微小图像数据集的标记子集。他们是由Alex Krizhevsky,Vinod Nair和Geoffrey Hinton收集的。
CIFAR-10数据集由10个类中的60000个32x32彩色图像组成,每个类有6000个图像。有50000个训练图像和10000个测试图像。
数据集分为五个训练批次和一个测试批次,每个批次有10000个图像。测试批次包含来自每个类别的1000个随机选择的图像。训练批次以随机顺序包含剩余图像,但是一些训练批次可能包含来自一个类别的更多图像而不是另一个类别。在它们之间,训练批次包含来自每个类别的5000个图像。
以下是数据集中的类,以及每个中的10个随机图像:
airplane
automobile
bird
cat
deer
dog
frog
horse
ship
truck
2、resnet50网络介绍
根据何凯明在论文 Deep Residual Learning for Image Recognition 4.1单节中的描述,50层resnet是将34层resnet中的两层瓶颈块替换成三层瓶颈块,瓶颈块的结构如下面图表中所示。
34层resnet见下图最右的结构
50层resnet结构见下图
3、使用TensorFlow Slim微调模型训练
TF-slim是Google公司公布的一种新的轻量级的、用于定义、训练和评估复杂模型的TensorFlow高级API。它提供的接口可以帮助我们从头开始训练模型,或者从预先训练的网络权值中对它们进行微调。提供的模型包括VGG16,VGG19,Inception V1~4,ResNet50,Resnt101,MobileNet等。
3.1、下载TF-slim源码
下载命令:
git clone https://github.com/tensorflow/models/
从research目录下将slim目录复制出来。
3.2、下载cifar10数据并转换格式为trconf
下载命令:
python slim/download_and_convert_data.py --dataset_name=cifar10 --dataset_dir=data
3.3、下载resnet50模型
下载地址:
http://download.tensorflow.org/models/resnet_v2_50_2017_04_14.tar.gz
解压之后放在pretrained目录,新建一个训练用的目录train_dir,一个测试用的目录eval_dir,最终形成的目录如下:
---slim
---data
------ cifar10_test.tfrecord
------ cifar10_train.tfrecord
------ labels.txt
---train-dir
-- eval-dir
-- pretrained
------ resnet_v2_50.ckpt
3.4、训练resnet50
训练全部层的命令如下:
python slim/train_image_classifier.py
--train_dir=train_dir \
--dataset_name=cifar10 \
--dataset_split_name=train \
--dataset_dir=data \
--model_name=resnet_v2_50 \
--checkpoint_path=pretrained/resnet_v2_50.ckpt \
--checkpoint_exclude_scopes=resnet_v2_50/logits \
--max_number_of_steps=50000 \
--batch_size=16 \
--learning_rate=0.001 \
--log_every_n_steps=100 \
--optimizer=adma
优化器选择adma,初始学习率为0.001,一共训练50000步,其他参数都为默认,详情可查看slim/train_image_classifier.py。经过50000训练后,loss值为0.18,如下图所示。
3.5、验证模型准确率
使用如下命令验证模型准确率,准确率达到了95.17%。
python slim/eval_image_classifier.py \
--checkpoint_path=train_dir \
--eval_dir=eval_dir \
--dataset_name=cifar10 \
--dataset_split_name=test \
--dataset_dir=data \
--model_name=resnet_v2_50
验证结果如下图所示:
项目地址:https://github.com/tryrus/Cifar10-Classification-useResnet50