Tensorflow saver.restore()不恢复网络
我完全失去了对的tensorflow保护方法。Tensorflow saver.restore()不恢复网络
我试图遵循的基本tensorflow深层神经网络模型的教程。我想弄清楚如何训练网络几次迭代,然后在另一个会话中加载模型。
with tf.Session() as sess:
graph = tf.Graph()
x = tf.placeholder(tf.float32,shape=[None,784])
y_ = tf.placeholder(tf.float32, shape=[None,10])
sess.run(global_variables_initializer())
#Define the Network
#(This part is all copied from the tutorial - not copied for brevity)
#See here: https://www.tensorflow.org/versions/r0.12/tutorials/mnist/pros/
跳过培训。
#Train the Network
train_step = tf.train.AdamOptimizer(1e-4).minimize(
cross_entropy,global_step=global_step)
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
saver = tf.train.Saver()
for i in range(101):
batch = mnist.train.next_batch(50)
if i%100 == 0:
train_accuracy = accuracy.eval(feed_dict=
{x:batch[0],y_:batch[1]})
print 'Step %d, training accuracy %g'%(i,train_accuracy)
train_step.run(feed_dict={x:batch[0], y_: batch[1]})
if i%100 == 0:
print 'Test accuracy %g'%accuracy.eval(feed_dict={x:
mnist.test.images, y_: mnist.test.labels})
saver.save(sess,'./mnist_model')
控制台打印出:
步骤0,训练精度0.16
测试精度0.0719
步骤100,训练精度0.88
测试精度0.8734
接下来,我要加载模型
with tf.Session() as sess:
saver = tf.train.import_meta_graph('mnist_model.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
sess.run(tf.global_variables_initializer())
现在我想重新测试,看看模型加载
print 'Test accuracy %g'%accuracy.eval(feed_dict={x:
mnist.test.images, y_: mnist.test.labels})
控制台打印出:
测试精度0.1151
它似乎没有显示模型正在保存任何数据?我究竟做错了什么?
当您保存您的模型,一般而局部变量是不是所有的全局变量保存在外部文件。您可以查看此answer以了解其差异。
您的恢复代码中的错误正在调用tf.global_variable_initializer()
后saver.restore()
。该saver.restore
文档提到,
变量恢复没有被初始化,如恢复本身就是一种方式来初始化变量。
因此,尝试删除线,
sess.run(tf.global_variables_initializer())
理论上,应该将其替换为,
sess.run(tf.local_variables_initializer())
谢谢,这似乎已经解决了我的问题!如果文档声明'saver.restore()'是一个初始化过程,那么'sess.run(tf。local_variables_initializer())用于任何目的? 这似乎也表明,教程,如[一个快速完整的教程来保存和恢复Tensorflow模型](http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-教程/)显示不正确的用法,不是吗? –
你应该检查['tf.local_variables()'](https://www.tensorflow.org/versions/r1.0/api_docs/python/tf/local_variables)。如果这个列表非空,则需要它 – martianwars
你不应该运行'sess.run(tf.global_variables_initializer())'恢复权重后。这将重置您的所有权重 – martianwars