tensorflow中的batch数据打印
之前一直困惑怎么打印batch返回值,后来发现一个迭代就可以了,因为这个batch返回的就是迭代器。
首先 batch 就是把数据分成几组,前面己组的数据大小都是batch的大小,但是最后一组可以能数据量不多,就只有一点
如下,先创建一个20行3列的数据
temp = tf.Variable(tf.random.normal(shape=(20, 3), mean=0, stddev=0.01, dtype=tf.float32))
print("temp")
print(temp)
然后按照batch为3来分
temp_iter = tf.data.Dataset.from_tensor_slices(temp).batch(3)
for element in temp_iter:
print(element)
从最后一行可以看出,不够3个,所以只有两个。