pytorch—使用 torchvision 的 Transform 读取图片数据(一)
运行环境安装 Anaconda | python ==3.6.6
conda install pytorch -c pytorch
pip install config
pip install tqdm #包装迭代器,显示进度条
pip install torchvision
pip install scikit-image
一、torchvision 图像数据读取 [0, 1]
import torchvision.transforms as transforms
transforms 模块提供了一般的图像转换操作类。class torchvision.transforms.ToTensor
功能:
把shape=(H x W x C) 的像素值为 [0, 255] 的 PIL.Image 和 numpy.ndarray
转换成shape=(C x H x W)的像素值范围为[0.0, 1.0]
的 torch.FloatTensor。
class torchvision.transforms.Normalize(mean, std)
功能:
此转换类作用于torch.*Tensor。给定均值(R, G, B)和标准差(R, G, B),用公式channel = (channel - mean) / std进行规范化。
import torchvision
import torchvision.transforms as transforms
import cv2
import numpy as np
from PIL import Image
img_path = "./data/timg.jpg"
# 引入transforms.ToTensor()功能: range [0, 255] -> [0.0,1.0]
transform1 = transforms.Compose([transforms.ToTensor()])
# 直接读取:numpy.ndarray
img = cv2.imread(img_path)
print("img = ", img[0]) #只输出其中一个通道
print("img.shape = ", img.shape)
# 归一化,转化为numpy.ndarray并显示
img1 = transform1(img)
img2 = img1.numpy()*255
img2 = img2.astype('uint8')
img2 = np.transpose(img2 , (1,2,0))
print("img1 = ", img1)
cv2.imshow('img2 ', img2 )
cv2.waitKey()
# PIL 读取图像
img = Image.open(img_path).convert('RGB') # 读取图像
img2 = transform1(img) # 归一化到 [0.0,1.0]
print("img2 = ",img2) #转化为PILImage并显示
img_2 = transforms.ToPILImage()(img2).convert('RGB')
print("img_2 = ",img_2)
img_2.show()
从上到下依次输出:---------------------------------------------
img = [[197 203 202]
[195 203 202]
...
[200 208 207]
[200 208 207]]
img.shape = (362, 434, 3)
img1 = tensor([[[0.7725, 0.7647, 0.7686, ..., 0.7804, 0.7843, 0.7843],
[0.7765, 0.7725, 0.7686, ..., 0.7686, 0.7608, 0.7569],
[0.7843, 0.7725, 0.7686, ..., 0.7725, 0.7686, 0.7569],
...,
img_transform = tensor([[[0.7922, 0.7922, 0.7961, ..., 0.8078, 0.8118, 0.8118],
[0.7961, 0.8000, 0.7961, ..., 0.7922, 0.7882, 0.7843],
[0.8039, 0.8000, 0.7961, ..., 0.8118, 0.8039, 0.7922],
...,
transforms.Compose 归一化到 [-1.0, 1.0 ]
transform2 = transforms.Compose([transforms.ToTensor()])
transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))])
二、torchvision 的 Transform 图片读取类
在深度学习时关于图像的数据读取:由于Tensorflow不支持与numpy的无缝切换,导致难以使用现成的pandas等格式化数据读取工具,造成了很多不必要的麻烦,而pytorch解决了这个问题。
pytorch自定义读取数据和进行Transform的部分请见文档:
http://pytorch.org/tutorials/beginner/data_loading_tutorial.html
但是按照文档中所描述所完成的自定义Dataset只能够使用自定义的Transform步骤,而torchvision包中已经给我们提供了很多图像transform步骤的实现,为了使用这些已经实现的Transform步骤,我们可以使用如下方法定义Dataset:
from __future__ import print_function, division
import os
import torch
import pandas as pd
from PIL import Image
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
class FaceLandmarkDataset(Dataset):
def __len__(self) -> int:
return len(self.landmarks_frame)
def __init__(self, csv_file: str, root_dir: str, transform=None) -> None:
super().__init__()
self.landmarks_frame = pd.read_csv(csv_file)
self.root_dir = root_dir
self.transform = transform
def __getitem__(self, index:int):
img_name = self.landmarks_frame.ix[index, 0]
img_path = os.path.join('./faces', img_name)
with Image.open(img_path) as img:
image = img.convert('RGB')
landmarks = self.landmarks_frame.as_matrix()[index, 1:].astype('float')
landmarks = np.reshape(landmarks,newshape=(-1,2))
if self.transform is not None:
image = self.transform(image)
return image, landmarks
########################以上为数据读取类(返回:image,landmarks)###############################
trans = transforms.Compose(transforms = [transforms.RandomSizedCrop(size=128),
transforms.ToTensor()])
face_dataset = FaceLandmarkDataset(csv_file='faces/face_landmarks.csv',
root_dir='faces', transform= trans)
loader = DataLoader(dataset = face_dataset,
batch_size=4,
shuffle=True,
num_workers=4)
鸣谢
https://www.cnblogs.com/denny402/p/5096001.html
https://blog.csdn.net/VictoriaW/article/details/72822005
https://blog.csdn.net/hao5335156/article/details/80593349