如何在Tensorflow中为CTC丢失生成/读取稀疏序列标签?

问题描述:

从他们抄写单词的图像列表,我想使用tf.train.slice_input_producer创建和读取稀疏序列标签(用于tf.nn.ctc_loss),避免如何在Tensorflow中为CTC丢失生成/读取稀疏序列标签?

  1. TFRecord格式序列化预包装的训练数据到磁盘

  2. tf.py_func表观局限性,

  3. 任何不必要或过早填充和

  4. 将整个数据集读取到RAM中。

的主要问题似乎是一个字符串被转换为标签的需要tf.nn.ctc_loss的序列(SparseTensor)。

例如,在字符集(有序)范围[A-Z]中,我想将文本标签字符串"BAD"转换为序列标签类列表[1,0,3]

每个示例图像我想读包含文本的文件名的一部分,所以它是直接的提取和做转化率直线上升蟒蛇。 (如果有一种方法来计算TensorFlow内做到这一点,我还没有发现它。)

以前的几个问题,瞄了一眼这些问题,但我一直没能对他们成功整合。例如,

有没有办法整合这些方法?

另一个例子(SO问题#38012743)显示了我如何推迟从字符串到列表的转换,直到解除文件名出队权之后,但它依赖于tf.py_func,它有一些注意事项。 (我应该担心它们吗?)

我认识到“SparseTensors不能很好地处理队列”(每个tf文档),所以在批处理之前可能需要对结果做一些voodoo(序列化?) ,甚至在计算发生的地方返工;我对此表示欢迎。

按照MarvMind的提纲,这是一个基本框架,包含我想要的计算(遍历包含示例文件名的行,提取每个标签字符串并转换为序列),但是我没有成功确定“Tensorflow” 。

谢谢你正确的“调整”,对我的目标来说是一个更合适的策略,或者指示tf.py_func不会破坏培训效率或下游的其他东西(例如,,加载训练有素的模型以供将来使用)。

编辑(+7小时)我找到了缺少的操作来修补东西了。虽然仍然需要验证这与CTC_Loss下游连接,但我已检查以下编辑的版本是否正确批量并读取图像和稀疏张量。

out_charset="ABCDEFGHIJKLMNOPQRSTUVWXYZ" 

def input_pipeline(data_filename): 
    filenames,seq_labels = _get_image_filenames_labels(data_filename) 
    data_queue = tf.train.slice_input_producer([filenames, seq_labels]) 
    image,label = _read_data_format(data_queue) 
    image,label = tf.train.batch([image,label],batch_size=2,dynamic_pad=True) 
    label = tf.deserialize_many_sparse(label,tf.int32) 
    return image,label 

def _get_image_filenames_labels(data_filename): 
    filenames = [] 
    labels = [] 
    with open(data_filename)) as f: 
     for line in f: 
      # Carve out the ground truth string and file path from 
      # lines formatted like: 
      # ./241/7/158_NETWORK_51375.jpg 51375 
      filename = line.split(' ',1)[0][2:] # split off "./" and number 
      # Extract label string embedded within image filename 
      # between underscores, e.g. NETWORK 
      text = os.path.basename(filename).split('_',2)[1] 
      # Transform string text to sequence of indices using charset, e.g., 
      # NETWORK -> [13, 4, 19, 22, 14, 17, 10] 
      indices = [[i] for i in range(0,len(text))] 
      values = [out_charset.index(c) for c in list(text)] 
      shape = [len(text)] 
      label = tf.SparseTensorValue(indices,values,shape) 
      label = tf.convert_to_tensor_or_sparse_tensor(label) 
      label = tf.serialize_sparse(label) # needed for batching 
      # Add data to lists for conversion 
      filenames.append(filename) 
      labels.append(label) 
    filenames = tf.convert_to_tensor(filenames) 
    labels = tf.convert_to_tensor_or_sparse_tensor(labels) 
    return filenames, labels 

def _read_data_format(data_queue): 
    label = data_queue[1] 
    raw_image = tf.read_file(data_queue[0]) 
    image = tf.image.decode_jpeg(raw_image,channels=1) 
    return image,label 

主要观点似乎是创建从数据需要一个SparseTensorValue,传递给tf.convert_to_tensor_or_sparse_tensor,然后(如果你想批量的数据)tf.serialize_sparse序列化。批处理后,您可以使用tf.deserialize_many_sparse恢复值。

下面是大纲。创建稀疏值,转换为张量,和序列化:

indices = [[i] for i in range(0,len(text))] 
values = [out_charset.index(c) for c in list(text)] 
shape = [len(text)] 
label = tf.SparseTensorValue(indices,values,shape) 
label = tf.convert_to_tensor_or_sparse_tensor(label) 
label = tf.serialize_sparse(label) # needed for batching 

然后,你可以做配料和反序列化:

image,label = tf.train.batch([image,label],dynamic_pad=True) 
label = tf.deserialize_many_sparse(label,tf.int32)