tf.data迭代器问题

import tensorflow as tf 
import numpy as np

def __a(a):
    b=a+1
    b=np.squeeze(b)
    return b 

a=np.array(range(5))
b=tf.constant(a)
dataset =  tf.data.Dataset.from_tensor_slices(a)
dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=len(a), count=-1))
dataset= dataset.apply(tf.contrib.data.map_and_batch(
                map_func=lambda c: tf.py_func(__a, [c], [tf.int64]),
                batch_size=1))

print (a)
print ('wwwwwwwwwwwww')
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
    for q in range(5):
        for i in range(5):
            value = sess.run(next_element)
            print (value)
            print ('xxx')
        print ('qqq')

在 for i in range(5):的时候,没问题,每一次大迭代都会遍历a中的元素,也就是0~4。

tf.data迭代器问题

但是把这句话改为for i in range(2):的时候,就会变成如下图

tf.data迭代器问题 

也就是说前5次迭代还是会遍历0~4。