在TensorFlow中使用pipeline加载数据
正文共2028个字,6张图,预计阅读时间6分钟。
前面对TensorFlow的多线程做了测试,接下来就利用多线程和Queue pipeline地加载数据。数据流如下图所示:
首先,A、B、C三个文件通过RandomShuffle进程被随机加载到FilenameQueue里,然后Reader1和Reader2进程同FilenameQueue里取文件名读取文件,读取的内容再被放到ExampleQueue里。最后,计算进程会从ExampleQueue里取数据。各个进程独立操作,互不影响,这样可以加快程序速度。
我们简单地生成3个样本文件。
#生成三个样本文件,每个文件包含5列,假设前4列为特征,最后1列为标签
data = np.zeros([20,5]) np.savetxt('file0.csv', data, fmt='%d', delimiter=',') data += 1np.savetxt('file1.csv', data, fmt='%d', delimiter=',') data += 1np.savetxt('file2.csv', data, fmt='%d', delimiter=',')
然后,创建pipeline数据流。
#定义FilenameQueuefilename_queue = tf.train.string_input_producer(["file%d.csv"%i for i in range(3)])
#定义ExampleQueue
example_queue = tf.RandomShuffleQueue( capacity=1000, min_after_dequeue=0, dtypes=[tf.int32,tf.int32], shapes=[[4],[1]] )
#读取CSV文件,每次读一行
reader = tf.TextLineReader() key, value = reader.read(filename_queue)
#对一行数据进行解码
record_defaults = [[1], [1], [1], [1], [1]] col1, col2, col3, col4, col5 = tf.decode_csv( value, record_defaults=record_defaults) features = tf.stack([col1, col2, col3, col4])
#将特征和标签push进ExampleQueue
enq_op = example_queue.enqueue([features, [col5]])
#使用QueueRunner创建两个进程加载数据到ExampleQueue
qr = tf.train.QueueRunner(example_queue, [enq_op]*2)
#使用此方法方便后面tf.train.start_queue_runner统一开始进程
tf.train.add_queue_runner(qr) xs = example_queue.dequeue()
with tf.Session() as sess: coord = tf.train.Coordinator()
#开始所有进程 threads = tf.train.start_queue_runners(coord=coord)
for i in range(200): x = sess.run(xs) print(x) coord.request_stop() coord.join(threads)
以上我们采用for循环step_num次来控制训练迭代次数。我们也可以通过tf.train.string_input_producer的num_epochs参数来设置FilenameQueue循环次数来控制训练,当达到num_epochs时,TensorFlow会抛出OutOfRangeError异常,通过捕获该异常,停止训练。
filename_queue = tf.train.string_input_producer(["file%d.csv"%i for i in range(3)], num_epochs=6) ...
with tf.Session() as sess: sess.run(tf.initialize_local_variables()) #必须加上这句话,否则报错! coord = tf.train.Coordinator()
#开始所有进程
threads = tf.train.start_queue_runners(coord=coord)
try:
while not coord.should_stop(): x = sess.run(xs) print(x)
except tf.errors.OutOfRangeError: print('Done training -- epch limit reached')
finally: coord.request_stop()
捕获到异常时,请求结束所有进程。
原文: 在TensorFlow中使用pipeline加载数据(https://goo.gl/jbVPjM)
原文链接:https://www.jianshu.com/p/12b52e54a63c
查阅更为简洁方便的分类文章以及最新的课程、产品信息,请移步至全新呈现的“LeadAI学院官网”:
www.leadai.org
请关注人工智能LeadAI公众号,查看更多专业文章
大家都在看
TensorFlow从1到2 | 第三章 深度学习革命的开端:卷积神经网络