如果我想要使用无法使用TensorFlow加载到内存中的大型数据集,该怎么办?
问题描述:
我想使用一个无法加载到内存中的大型数据集来训练带有TensorFlow的模型。但我不知道我应该做什么。如果我想要使用无法使用TensorFlow加载到内存中的大型数据集,该怎么办?
我已阅读了一些关于TFRecords
文件格式和官方文档的好帖子。公交车我仍然无法弄清楚。
TensorFlow是否有完整的解决方案?
答
考虑使用tf.TextLineReader
,它与tf.train.string_input_producer
一起允许您从磁盘上的多个文件(如果您的数据集足够大以至于需要将其分散到多个文件中)加载数据。
见https://www.tensorflow.org/programmers_guide/reading_data#reading_from_files
代码段从上面的链接:
filename_queue = tf.train.string_input_producer(["file0.csv", "file1.csv"])
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
# Default values, in case of empty columns. Also specifies the type of the
# decoded result.
record_defaults = [[1], [1], [1], [1], [1]]
col1, col2, col3, col4, col5 = tf.decode_csv(
value, record_defaults=record_defaults)
features = tf.stack([col1, col2, col3, col4])
with tf.Session() as sess:
# Start populating the filename queue.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for filename_queue = tf.train.string_input_producer(["file0.csv", "file1.csv"])
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
# Default values, in case of empty columns. Also specifies the type of the
# decoded result.
record_defaults = [[1], [1], [1], [1], [1]]
col1, col2, col3, col4, col5 = tf.decode_csv(
value, record_defaults=record_defaults)
features = tf.stack([col1, col2, col3, col4])
with tf.Session() as sess:
# Start populating the filename queue.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(1200):
# Retrieve a single instance:
example, label = sess.run([features, col5])
coord.request_stop()
coord.join(threads)i in range(1200):
# Retrieve a single instance:
example, label = sess.run([features, col5])
coord.request_stop()
coord.join(threads)
答
通常情况下,您无论如何都会使用批处理智能培训,因此您可以即时加载数据。例如,对于图像:
for bid in nrBatches:
batch_x, batch_y = load_data_from_hd(bid)
train_step.run(feed_dict={x: batch_x, y_: batch_y})
因此,您可以实时加载每个批次,只加载需要在任何特定时刻加载的数据。当然你的训练时间会增加,而使用硬盘代替内存来加载数据。
谢谢您的anwser。但是,如果CSV文件中有**列**,该怎么办?我必须写很多col1,col2,col3 ...等等?以及如何从二进制文件读取数据? – secsilm
@secsilm是的,您需要在您的CSV中为每列添加“col1”,“col2”等。记住'col1'只是一个变量名,所以你可以给它一个更多的助记符名称,比如'price'或者其他什么。有关二进制文件,请参阅https://www.tensorflow.org/api_docs/python/tf/FixedLengthRecordReader – Insectatorious