如何在使用Keras的数据生成器时检索标签信息?
问题描述:
而使用以下Keras Python代码:如何在使用Keras的数据生成器时检索标签信息?
for x_batch,y_batch in datagen.flow_from_directory(
directory = os.path.join(dataset_root_path,dataset_train_path),
target_size = (520,520),
class_mode = 'binary',
batch_size = 1
):
我得到了x_batch和y_batch numpy的阵列,所述y_batch numpy的阵列被编码成数0.0或1.0,因为我使用的“二进制” class_mode,但是,通过这种方式,我失去了关于该样本的真实标签的信息,例如“猫”或“狗”。如何根据输出'1.0'和'0.0'检索标签信息?
答
我建议来实例化datagenerator和列车fit_generator
:
train_gen = datagen.flow_from_directory(...)
model.fit_generator(train_gen, ...)
然后,您可以访问(以及其他属性)的类指数与train_gen.class_indices
。