tensorflow中的batch数据打印

之前一直困惑怎么打印batch返回值,后来发现一个迭代就可以了,因为这个batch返回的就是迭代器。
首先 batch 就是把数据分成几组,前面己组的数据大小都是batch的大小,但是最后一组可以能数据量不多,就只有一点

如下,先创建一个20行3列的数据
temp = tf.Variable(tf.random.normal(shape=(20, 3), mean=0, stddev=0.01, dtype=tf.float32))
print("temp")
print(temp)

tensorflow中的batch数据打印

然后按照batch为3来分
temp_iter = tf.data.Dataset.from_tensor_slices(temp).batch(3)
for element in temp_iter:
    print(element)
 

tensorflow中的batch数据打印

tensorflow中的batch数据打印

从最后一行可以看出,不够3个,所以只有两个。