PyTorch入门(五):数据加载和处理

数据加载和处理
PyTorch提供了许多工具加载数据,使代码更具有可读性。

  • scikit-image:用于图像io和transform
  • pandas:更容易解析csv

我们要处理一个面部姿态的数据集。每张图片有68个不同的标记点。如下图注释:
PyTorch入门(五):数据加载和处理
快速读取csv文件并且从一个(N,2)的数组得到标记,其中N是标记点的数量。

landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv')

n = 65
#提取第n行第0列的值,即照片名
img_name = landmarks_frame.iloc[n,0]
#将第n行第1列以后所有的列以矩阵形式显示
landmarks = landmarks_frame.iloc[n,1:].as_matrix()
landmarks = landmarks.astype('float').reshape(-1, 2)

print('Image name:{}'.format(img_name))
print('Landmarks shape:{}'.format(landmarks.shape))
print('First 4 Landmarks:{}'.format(landmarks[:4]))

输出为
PyTorch入门(五):数据加载和处理
通过一下代码可以显示图像及标记,用它来显示样本。

def show_landmarks(image,landmarks):
    plt.imshow(image)
    plt.scatter(landmarks[:,0],landmarks[:,1],s=10,marker='.',c='r')
    plt.pause(0.001)

plt.figure()
show_landmarks(io.imread(os.path.join('data/faces/',img_name)),landmarks)
plt.show()

PyTorch入门(五):数据加载和处理
DataSet class

torch.utils.data.DataSet是一个表示数据集的抽象类。自定义的数据集应该继承Dataset类并且重载一下方法:

  • len:通过len(dataset)返回数据集的大小
  • getitem:支持整数索引,范围从0到len(self),用法:通过dataset[i]得到索引为i的样本和标签

定制自己的DataSet。首先继承DataSet类,在__init__函数中实现csv数据读入,但读图是在__getitem__中实现,这是一种高效的方法,因为不是所有的数据都要在一开始读入内存中,可以在需要的时候再读取。
我们的数据集是字典形式{'image': image, 'landmarks':landmarks}

class FaceLandmarksDataset(Dataset):
    def __init__(self,csv_file,root_dir,transform=None):
        """
        :param csv_file: 带注释的csv文件路径
        :param root_dir: 所有图像目录
        :param transform: 一个样本要应用的可选变换
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, item):
        img_name = os.path.join(self.root_dir,self.landmarks_frame.iloc[item,0])
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.iloc[item,1:].as_matrix()
        landmarks = landmarks.astype('float').reshape(-1,2)
        sample = {'image':image,'landmarks':landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample

将该类实例化,并且显示前4个样本及他们的标记点。

face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',root_dir='data/faces/')

fig = plt.figure()

for i in range(len(face_dataset)):
    sample = face_dataset[i]
    print(i,sample['image'].shape,sample['landmarks'].shape)

    ax = plt.subplot(1,4,i+1)
    plt.tight_layout()
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    show_landmarks(**sample)

    if i==3:
        plt.show()
        break

PyTorch入门(五):数据加载和处理
PyTorch入门(五):数据加载和处理