在Tensorflow中只计算一次子图
问题描述:
我正在用Tensorflow构建一个深度学习模型。在训练之前,我会做一些计算,如反向传播。但只需要计算一次。下面是我的伪代码:在Tensorflow中只计算一次子图
class residual_net()
def pseudo_bp(self):
# do something...
self.bp = ...
def build_net(self):
# build a residual_network....
# utilize the variable in pseudo_bp
rn.output = func(self.bp)
def run():
rn = residual_net()
rn.pseudo_bp()
rn.deep_residual_network()
sess = tf.InteractiveSession()
sess.run(tf.initialize_all_variables())
for i in range(1000):
err = tf.reduce_mean(rn.output, labels)
train = tf.train.GradientDescentOptimizer(learning_rate).minimize(err)
sess.run(train, feed_dict=train_feed_dict)
不知pseudo_bp
将在每次迭代运行?如果是的话,我怎么才能让它运行一次?提前致谢!
编辑: 最新的错误:
Traceback (most recent call last):
File "run.py", line 124, in <module>
sess.run(pseudo_bp, feed_dict=feed_dict)
File "/Users/yobichi/bigdata/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 717, in run
run_metadata_ptr)
File "/Users/yobichi/bigdata/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 902, in _run
fetch_handler = _FetchHandler(self._graph, fetches, feed_dict_string)
File "/Users/yobichi/bigdata/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 358, in __init__
self._fetch_mapper = _FetchMapper.for_fetch(fetches)
File "/Users/yobichi/bigdata/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 178, in for_fetch
(fetch, type(fetch)))
TypeError: Fetch argument None has invalid type <class 'NoneType'>
你有什么想法?
答
在TensorFlow中,您从构建tf.Graph
开始。该图由变量,操作和占位符组成。然后开始tf.Session()
,您可以在其中执行操作并更新变量。
在这种情况下,我认为psuedo_bp
最终需要您计算一些操作(如tf.matmul
)。 sess
就像一个指针,只要你运行sess.run(op)
就会执行各种tf.Operation
。您提供一些输入来填充占位符(feed_dict
)。
因此,您只会执行sess.run(op)
for for循环的第一次迭代。这里是结果代码 -
class residual_net()
def pseudo_bp(self):
# do something...
return op
def build_net(self):
# build a residual_network....
rn.output = sth
def run():
rn = residual_net()
operation = rn.pseudo_bp()
rn.build_net()
err = tf.reduce_mean(rn.output, labels)
train = tf.train.GradientDescentOptimizer(learning_rate).minimize(err)
# Graph has been built completely. Begin tf.Session()
sess = tf.Session()
sess.run(tf.initialize_all_variables())
for i in range(1000):
# Carry out the training in each iteration
# Note that train is an operation here
sess.run(train, feed_dict=feed_dict)
if i == 0:
# Execute `operation` for the first iteration
result = sess.run(operation, feed_dict=feed_dict)
感谢您的回答。对不起,我错过了我原来的问题中的一个非常重要的信息。 'pseudo_bp'中计算的变量将在'build_net'中使用。 'build_net'也被链接到'train'操作。所以我想知道'pseudo_bp'是否仍然会在下面的迭代中运行? – southdoor
我做了你的建议,我得到了一个错误,我更新了这个问题,以获得更多的日志信息。你可以看一下吗?谢谢! – southdoor
你给'pseudo_bp'分配了什么?您可能需要将它分配给'rn.bp' – martianwars