TensorFlow 从训练到使用
一、TensorFlow 模型文件格式转换（二进制 .pb ↔ 文本 .pbtxt）
pb_convert.py文件如下:
import tensorflow as tf
from tensorflow.python.platform import gfile
from google.protobuf import text_format
def convert_pb_to_pbtxt(filename, output_dir='./', output_name='protobuf.pbtxt'):
    """Convert a binary GraphDef (.pb) file into a human-readable .pbtxt file.

    Args:
        filename: Path of the binary `tf.GraphDef` protocol buffer to read.
        output_dir: Directory the text file is written to (default: current directory).
        output_name: File name of the written text file (default: 'protobuf.pbtxt').

    Returns:
        The path of the written .pbtxt file, as returned by `tf.train.write_graph`.
    """
    # GFile replaces the deprecated FastGFile; 'rb' because the input is a binary protobuf.
    with gfile.GFile(filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # Kept from the original: importing raises early if the GraphDef cannot be
    # imported. Note that it also adds the nodes to the current default graph.
    tf.import_graph_def(graph_def, name='')
    return tf.train.write_graph(graph_def, output_dir, output_name, as_text=True)
def convert_pbtxt_to_pb(filename, output_dir='./', output_name='protobuf.pb'):
    """Convert a text-format GraphDef (.pbtxt) file into a binary .pb file.

    Args:
        filename: Path of a file containing a GraphDef in protobuf text format
            (a text-formatted `tf.GraphDef` protocol buffer).
        output_dir: Directory the binary file is written to (default: current directory).
        output_name: File name of the written binary file (default: 'protobuf.pb').

    Returns:
        The path of the written .pb file, as returned by `tf.train.write_graph`.
    """
    # GFile replaces the deprecated FastGFile; text mode because the input is pbtxt.
    with tf.gfile.GFile(filename, 'r') as f:
        file_content = f.read()
    graph_def = tf.GraphDef()
    # Merge the human-readable text into the empty GraphDef message.
    text_format.Merge(file_content, graph_def)
    return tf.train.write_graph(graph_def, output_dir, output_name, as_text=False)
调用方式如下:
# Convert the binary graph classify_image_graph_def.pb into protobuf.pbtxt
# (written to the current directory by default).
import pb_convert

pb_convert.convert_pb_to_pbtxt(filename='classify_image_graph_def.pb')
二、在 TensorBoard 中展示 .pb 模型文件的结构
# Load a binary GraphDef (.pb) and write it to log/ so TensorBoard can display its structure.
import tensorflow as tf

model = 'model.pb'  # Change this to the path of your own .pb file.

graph = tf.get_default_graph()
graph_def = tf.GraphDef()
# The with-block closes the file handle (the original one-liner leaked it).
with tf.gfile.GFile(model, 'rb') as f:
    graph_def.ParseFromString(f.read())
# Nodes are imported with the 'graph' name prefix, i.e. as graph/<op_name>.
tf.import_graph_def(graph_def, name='graph')
summaryWriter = tf.summary.FileWriter('log/', graph)
# close() flushes pending events to disk; without it the graph event may not
# be written before the process exits.
summaryWriter.close()
三、模型持久化
1.save模型
2.使用模型预测
四、使用 MNIST HelloWorld 示例进行训练和预测
训练与预测代码如下所示:
# MNIST "Hello World": a softmax-regression model that either trains and
# checkpoints itself (training_flag == 1) or restores the checkpoint and evaluates it.
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
print(mnist.train.images.shape, mnist.train.labels.shape)
print(mnist.test.images.shape, mnist.test.labels.shape)
print(mnist.validation.images.shape, mnist.validation.labels.shape)

sess = tf.InteractiveSession()

# Model: y = softmax(x @ W + b); x has 784 features per row, labels are 10-way one-hot.
with tf.name_scope('input'):
    x = tf.placeholder(tf.float32, [None, 784], name='x_input')
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
logits = tf.matmul(x, W) + b
y = tf.nn.softmax(logits)
y_ = tf.placeholder(tf.float32, [None, 10])

# Compute the loss from the logits instead of -sum(y_ * log(y)): log(softmax)
# becomes log(0) = -inf (NaN loss and gradients) as soon as a probability
# underflows, whereas the fused op is numerically stable. Same math otherwise.
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

#########################################
saver = tf.train.Saver()
# 1 = train and write ./cmodel.ckpt; any other value = restore ./cmodel.ckpt.
# Run once with 1 before using prediction mode, otherwise restore() fails.
training_flag = 0
if training_flag == 1:
    print('training')
    tf.global_variables_initializer().run()
    for _ in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        train_step.run({x: batch_xs, y_: batch_ys})
    saver.save(sess, "./cmodel.ckpt")
else:
    print('prediction')
    # restore() assigns every saved variable, so no initializer is needed here.
    saver.restore(sess, "./cmodel.ckpt")
#########################################

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
with tf.name_scope('accuracy'):
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

graph = tf.get_default_graph()
summaryWriter = tf.summary.FileWriter('log/', graph)
# close() flushes the graph event to log/ so TensorBoard can read it.
summaryWriter.close()
# Writes only the graph structure; the trained weights live in the checkpoint.
tf.train.write_graph(graph, './', 'test.pb', as_text=False)
print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))
四（续）、基于 MNIST 的 HelloWorld 学习
1.tensorflow训练model
2.保存model,使用tensorboard查看graph
保存计算图（graph），供 TensorBoard 查看：
# Write the current default graph to log/ for TensorBoard.
graph = tf.get_default_graph()
summaryWriter = tf.summary.FileWriter('log/', graph)
# Flush and close so the graph event is on disk before TensorBoard reads log/.
summaryWriter.close()
使用tensorboard进行查看:
# Run TensorBoard through the current interpreter instead of a hard-coded
# ~/.local/lib/python3.5/site-packages path, which breaks with other installs.
python -m tensorboard.main --logdir=log
3.加载model然后对测试用例进行predict
1)使用pb文件
加载
# Load the binary GraphDef into the default graph.
# NOTE(review): FLAGS and os are assumed to be defined/imported by the
# surrounding script (as in classify_image.py) — confirm.
with tf.gfile.GFile(os.path.join(FLAGS.model_dir, 'classify_image_graph_def.pb'), 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
# GFile replaces the deprecated FastGFile; the file is closed before importing.
_ = tf.import_graph_def(graph_def, name='')
预测
# Inference: feed image_data into the graph's 'DecodeJpeg/contents:0' input and
# fetch the 'softmax:0' output.
# NOTE(review): image_data is presumably the raw bytes of a JPEG file — confirm.
with tf.Session() as sess:
    softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
    predictions = sess.run(
        fetches=softmax_tensor,
        feed_dict={'DecodeJpeg/contents:0': image_data},
    )
    # Drop size-1 dimensions so the scores form a flat vector.
    predictions = np.squeeze(predictions)
五、classify_image学习
1.运行例程
TensorFlow 的 models 仓库中自带经典的 Google Inception Net 模型，因此可以直接运行 tensorflow/models/tutorials/image/imagenet 目录下的 classify_image.py。该脚本会自动下载训练好的模型，并识别指定的图片。其运行方式为：
进入 imagenet 目录后运行 classify_image.py 脚本，具体命令为：
# --model_dir: directory the model is downloaded to; --image_file: image to classify.
python classify_image.py --model_dir ~/Image --image_file ~/Image/a.jpg
（其中 --model_dir 表示模型的下载保存目录，--image_file 表示待识别的图片。）
结果为:
2.使用tf.train.server保存model