使用tf.estimator初始化tf.contrib.data.Iterator

问题描述:

如果tf.estimator.Estimator也被使用,应该如何初始化tf.contrib.data.Iterator使用tf.estimator初始化tf.contrib.data.Iterator

其中一个问题是,输入图形(TF图表处理输入的一部分)应该在intput_fn()被定义 - 东阳tf.estimator创建seprate曲线图。

这个要求使得很难访问迭代器init ops并通过它们to tf.estimator(以钩子的形式调用train/evaluate/predict时可以完成操作)。

使用SessionManager作为钩子可以解决相同的问题。

sm = tf.train.SessionManager(local_init_op=iterator_init_op) 
... 
estimator = tf.train.Estimator(...) 
estimator.train(input_fn, hooks=[sm], steps=None, max_steps=None) 
+0

问题是将init操作传递给estimator.train。 你不应该在你的钩子代码中访问'init_op'(因为它应该在init_fn中定义) 因为有1个图需求,所以init_fn不能被调用(!) 我所做的是保存init_op按预设名称汇入藏品中,然后在您告诉我的勾中读取藏品,我会更多地思考它,扩大您的答案以给出完整答案。 – Pietrko

一种选择是包装你input_fn,设置了一个简单的SessionRunHook init_hook另一个函数内。所有操作都在input_fn之内定义,该操作在与您的模型的其余部分相同的图形中调用,但是您可以将iterator_init_op设置为init_hook上的一个属性。

def get_input_fn(mode="train"): 
    init_hook = IteratorInitHook() 

    def input_fn(): 
     ... 
     iterator = dataset.make_initializable_iterator() 
     init_hook.iterator_init_op = iterator.initializer 

    return input_fn, init_hook 

class IteratorInitHook(tf.train.SessionRunHook): 

    def after_create_session(self, session, coord): 
     session.run(self.iterator_init_op) 

现在构建Experiment的时候,你可以得到这些输入功能,和init钩子,创建火车/ EVAL会话时被调用。它应该等效于estimator.train

train_input_fn, train_init_hook = get_input_fn("train") 
test_input_fn, test_init_hook = get_input_fn("test") 

return tf.contrib.learn.Experiment(
    estimator=estimator, 
    train_input_fn=train_input_fn, 
    eval_input_fn=test_input_fn, 
    train_monitors=[train_init_hook], 
    eval_hooks=[test_init_hook], 
)