故障恢复检查点TensorFlow网

问题描述:

我已经建立了一个自动编码器,将VGG19.relu4_1的激活“转换”为像素。我使用tensorflow.contrib.layers中的新便利功能(如在TF 0.10rc0中)。该代码与TensorFlow的CIFAR10教程具有相似的布局,其中train.py将训练和检查点设置为磁盘模型,一个eval.py轮询新检查点文件并对它们运行推断。故障恢复检查点TensorFlow网

我的问题是,评估从来没有像培训一样好,既不是在损失函数的价值方面,也不是当我看输出图像时(即使在与培训相同的图像上运行时)。这让我觉得它与恢复过程有关。

当我看着TensorBoard培训的输出时,它看起来不错(最终),所以我不认为我的网本身有什么问题。

我的网看起来像这样:

import tensorflow.contrib.layers as contrib 
bn_params = {                    
    "is_training": is_training, 
    "center": True, 
    "scale": True 
}                                      

tensor = contrib.convolution2d_transpose(vgg_output, 64*4, 4,        
    stride=2, 
    normalizer_fn=contrib.batch_norm, 
    normalizer_params=bn_params, 
    scope="deconv1")               
tensor = contrib.convolution2d_transpose(tensor, 64*2, 4,        
    stride=2, 
    normalizer_fn=contrib.batch_norm, 
    normalizer_params=bn_params, 
    scope="deconv2") 
. 
. 
. 

而在train.py我这样做是为了保存检查点:

variable_averages = tf.train.ExponentialMovingAverage(mynet.MOVING_AVERAGE_DECAY) 
variables_averages_op = variable_averages.apply(tf.trainable_variables()) 

with tf.control_dependencies([apply_gradient_op, variables_averages_op]): 
    train_op = tf.no_op(name='train') 

while training: 
    # train (with batch normalization's is_training = True) 
    if time_to_checkpoint: 
     saver.save(sess, checkpoint_path, global_step=step) 

eval.py我这样做:

# run code that creates the net 

variable_averages = tf.train.ExponentialMovingAverage(
        mynet.MOVING_AVERAGE_DECAY) 
saver = tf.train.Saver(variable_averages.variables_to_restore()) 

while polling: 
    # sleep and check for new checkpoint files 
    with tf.Session() as sess: 
     init = tf.initialize_all_variables() 
     init_local = tf.initialize_local_variables() 
     sess.run([init, init_local]) 
     saver.restore(sess, checkpoint_path) 

     # run inference (with batch normalization's is_training = False) 

The loss function

蓝色是训练损失,橙色是eval损失。

问题是我直接使用tf.train.AdamOptimizer()。在优化过程中,没有调用contrib.batch_norm中定义的操作来计算输入的运行平均值/方差,因此平均值/方差总是为0.0/1.0。

解决方法是向GraphKeys.UPDATE_OPS集合添加依赖项。在contrib模块中已经定义了一个功能(optimize_loss()

+0

感谢您的解决。我是唯一一个认为这应该被充分记录/修复的人。我认为'optimize_loss()'函数只是optimizer.minimize(损失,步骤)的快捷方式,而不是其他contrib.layers像宣传的那样工作所必需的。 – DomJack