Tensorflow恢复模型
问题描述:
我想恢复我的Tensorflow模型 - 这是一个线性回归网络。我确信我做错了事,因为我的预测不好。当我训练时,我有一套测试装置。我的测试集预测看起来不错,但是当我尝试恢复相同的模型时,预测看起来很差。Tensorflow恢复模型
这是我如何保存模型:
with tf.Session() as sess:
saver = tf.train.Saver()
init = tf.global_variables_initializer()
sess.run(init)
training_data, ground_truth = d.get_training_data()
testing_data, testing_ground_truth = d.get_testing_data()
for iteration in range(config["training_iterations"]):
start_pos = np.random.randint(len(training_data) - config["batch_size"])
batch_x = training_data[start_pos:start_pos+config["batch_size"],:,:]
batch_y = ground_truth[start_pos:start_pos+config["batch_size"]]
sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
train_acc, train_loss = sess.run([accuracy, cost], feed_dict={x: batch_x, y: batch_y})
sess.run(optimizer, feed_dict={x: testing_data, y: testing_ground_truth})
test_acc, test_loss = sess.run([accuracy, cost], feed_dict={x: testing_data, y: testing_ground_truth})
samples = sess.run(pred, feed_dict={x: testing_data})
# print samples
data.compute_acc(samples, testing_ground_truth)
print("Training\tAcc: {}\tLoss: {}".format(train_acc, train_loss))
print("Testing\t\tAcc: {}\tLoss: {}".format(test_acc, test_loss))
print("Iteration: {}".format(iteration))
if iteration % config["save_step"] == 0:
saver.save(sess, config["save_model_path"]+str(iteration)+".ckpt")
下面是我的测试集的一些例子。你会发现My prediction
比较接近Actual
:
My prediction: -12.705 Actual : -10.0
My prediction: 0.000 Actual : 8.0
My prediction: -14.313 Actual : -23.0
My prediction: 17.879 Actual : 13.0
My prediction: 17.452 Actual : 24.0
My prediction: 22.886 Actual : 29.0
Custom accuracy: 5.0159861487
Training Acc: 5.63836860657 Loss: 25.6545143127
Testing Acc: 4.238052845 Loss: 22.2736053467
Iteration: 6297
那么这里我如何恢复模式:
with tf.Session() as sess:
saver = tf.train.Saver()
saver.restore(sess, config["retore_model_path"]+"3000.ckpt")
init = tf.global_variables_initializer()
sess.run(init)
pred = sess.run(pred, feed_dict={x: predict_data})[0]
print("Prediction: {:.3f}\tGround truth: {:.3f}".format(pred, ground_truth))
但这里的预测是什么样子。你会发现,Prediction
总是对周围0:
Prediction: 0.355 Ground truth: -22.000
Prediction: -0.035 Ground truth: 3.000
Prediction: -1.005 Ground truth: -3.000
Prediction: -0.184 Ground truth: 1.000
Prediction: 1.300 Ground truth: 5.000
Prediction: 0.133 Ground truth: -5.000
这里是我的tensorflow版本(是的,我需要更新):
Python 2.7.6 (default, Oct 26 2016, 20:30:19)
[GCC 4.8.4] on linux2
Type "help", "copyright", "credits" or "license" for more information.
>>> import tensorflow as tf
>>> print(tf.__version__)
0.12.0-rc1
不知道这是否会有所帮助,但我想放置saver.restore()
在sess.run(init)
之后致电,我得到的预测完全相同。我认为这是因为sess.run(init)
初始化变量。
更改排序是这样的:
sess.run(init)
saver.restore(sess, config["retore_model_path"]+"6000.ckpt")
但随后的预测是这样的:
Prediction: -15.840 Ground truth: 2.000
Prediction: -15.840 Ground truth: -7.000
Prediction: -0.000 Ground truth: 12.000
Prediction: -15.840 Ground truth: -9.000
Prediction: -15.175 Ground truth: -27.000
答
当您从检查点恢复,你不初始化变量。正如你在问题结尾处注意到的那样。
init = tf.global_variables_initializer()
sess.run(init)
覆盖刚刚恢复的变量。哎呀! :)
评论这两条线,我怀疑你会很好去。