如何在使用Keras的数据生成器时检索标签信息?

问题描述:

而使用以下Keras Python代码:如何在使用Keras的数据生成器时检索标签信息?

for x_batch,y_batch in datagen.flow_from_directory(
    directory = os.path.join(dataset_root_path,dataset_train_path), 
    target_size = (520,520), 
    class_mode = 'binary', 
    batch_size = 1 
): 

我得到了x_batch和y_batch numpy的阵列,所述y_batch numpy的阵列被编码成数0.0或1.0,因为我使用的“二进制” class_mode,但是,通过这种方式,我失去了关于该样本的真实标签的信息,例如“猫”或“狗”。如何根据输出'1.0'和'0.0'检索标签信息?

我建议来实例化datagenerator和列车fit_generator

train_gen = datagen.flow_from_directory(...) 
model.fit_generator(train_gen, ...) 

然后,您可以访问(以及其他属性)的类指数与train_gen.class_indices