使用tf.estimator初始化tf.contrib.data.Iterator
问题描述:
如果tf.estimator.Estimator
也被使用,应该如何初始化tf.contrib.data.Iterator
?使用tf.estimator初始化tf.contrib.data.Iterator
其中一个问题是,输入图形(TF图表处理输入的一部分)应该在intput_fn()
被定义 - 东阳tf.estimator创建seprate曲线图。
这个要求使得很难访问迭代器init ops
并通过它们to tf.estimator
(以钩子的形式调用train/evaluate/predict
时可以完成操作)。
答
使用SessionManager
作为钩子可以解决相同的问题。
sm = tf.train.SessionManager(local_init_op=iterator_init_op)
...
estimator = tf.train.Estimator(...)
estimator.train(input_fn, hooks=[sm], steps=None, max_steps=None)
答
一种选择是包装你input_fn
,设置了一个简单的SessionRunHook init_hook
另一个函数内。所有操作都在input_fn
之内定义,该操作在与您的模型的其余部分相同的图形中调用,但是您可以将iterator_init_op
设置为init_hook
上的一个属性。
def get_input_fn(mode="train"):
init_hook = IteratorInitHook()
def input_fn():
...
iterator = dataset.make_initializable_iterator()
init_hook.iterator_init_op = iterator.initializer
return input_fn, init_hook
class IteratorInitHook(tf.train.SessionRunHook):
def after_create_session(self, session, coord):
session.run(self.iterator_init_op)
现在构建Experiment
的时候,你可以得到这些输入功能,和init钩子,创建火车/ EVAL会话时被调用。它应该等效于estimator.train
。
train_input_fn, train_init_hook = get_input_fn("train")
test_input_fn, test_init_hook = get_input_fn("test")
return tf.contrib.learn.Experiment(
estimator=estimator,
train_input_fn=train_input_fn,
eval_input_fn=test_input_fn,
train_monitors=[train_init_hook],
eval_hooks=[test_init_hook],
)
问题是将init操作传递给estimator.train。 你不应该在你的钩子代码中访问'init_op'(因为它应该在init_fn中定义) 因为有1个图需求,所以init_fn不能被调用(!) 我所做的是保存init_op按预设名称汇入藏品中,然后在您告诉我的勾中读取藏品,我会更多地思考它,扩大您的答案以给出完整答案。 – Pietrko