# 将图片分成小块 — split an image into small patches.
"""
@Author : zhwzhong
@License : (C) Copyright 2013-2018, hit
@Contact : [email protected]
@Software: PyCharm
@File : ImgToPatch.py
@Time : 2019/1/16 15:32
@Desc : Split an aligned (degraded, ground-truth) image pair into patch pairs.
"""
import cv2
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
def _patch_starts(length, patch_size, stride):
    """Return the patch start offsets along one axis of size ``length``.

    Offsets advance by ``stride`` from 0. The final offset
    ``length - patch_size`` is always included, so the trailing edge of the
    axis is covered even when ``length - patch_size`` is not a multiple of
    ``stride`` (and also when ``stride == patch_size``).
    """
    last = length - patch_size
    starts = list(range(0, last + 1, stride))
    # range() already ends on `last` when it is a multiple of stride; avoid a duplicate.
    if starts[-1] != last:
        starts.append(last)
    return starts


def img_to_patch(degrade_img, gt_img, patch_size, stride, scale):
    """Cut an aligned (degraded, ground-truth) image pair into patch pairs.

    Patches of ``patch_size`` x ``patch_size`` are taken from ``degrade_img``
    on a grid with step ``stride``; the last row/column of patches is aligned
    to the image border so the whole image is covered. For every degraded
    patch the co-located ``scale``-times larger patch is taken from ``gt_img``.

    :param degrade_img: degraded image, array of shape (H, W) or (H, W, C...).
    :param gt_img: ground-truth image, shape (H * scale, W * scale, ...) or larger.
    :param patch_size: side length of the degraded patches (positive int).
    :param stride: step between neighbouring patch origins (positive int).
    :param scale: size ratio between gt_img and degrade_img; e.g. scale=1 for
        denoising / dehazing, scale=4 for 4x super-resolution.
    :return: (degrade_patches, gt_patches) with shapes
        (patch_num, patch_size, patch_size, C...) and
        (patch_num, patch_size * scale, patch_size * scale, C...).
    :raises ValueError: if patch_size/stride are not positive, the patch does
        not fit in degrade_img, or gt_img is too small for the given scale.
    """
    height, width = degrade_img.shape[0], degrade_img.shape[1]
    if patch_size <= 0 or stride <= 0:
        raise ValueError('patch_size and stride must be positive')
    if patch_size > height or patch_size > width:
        raise ValueError('patch_size %d exceeds image size %dx%d' % (patch_size, height, width))
    if gt_img.shape[0] < height * scale or gt_img.shape[1] < width * scale:
        raise ValueError('gt_img is smaller than degrade_img * scale')

    row_starts = _patch_starts(height, patch_size, stride)
    col_starts = _patch_starts(width, patch_size, stride)
    patch_num = len(row_starts) * len(col_starts)
    hr_size = patch_size * scale

    # shape[2:] keeps any trailing channel axes, so 2-D (grayscale) images work too.
    degrade_patches = np.zeros(
        shape=(patch_num, patch_size, patch_size) + tuple(degrade_img.shape[2:]),
        dtype=degrade_img.dtype)
    gt_patches = np.zeros(
        shape=(patch_num, hr_size, hr_size) + tuple(gt_img.shape[2:]),
        dtype=gt_img.dtype)

    num = 0
    for up in row_starts:
        for left in col_starts:
            degrade_patches[num] = degrade_img[up:up + patch_size, left:left + patch_size]
            # The ground-truth patch sits at the same location in the scaled-up grid.
            hr_up, hr_left = up * scale, left * scale
            gt_patches[num] = gt_img[hr_up:hr_up + hr_size, hr_left:hr_left + hr_size]
            num += 1
    return degrade_patches, gt_patches
def _save_patch_grid(patches, indices, filename):
    """Plot ``patches[indices]`` on a 4x4 grid without axes and save it to ``filename``."""
    plt.figure()
    for pos, idx in enumerate(indices):
        plt.subplot(4, 4, pos + 1)
        plt.imshow(patches[idx])
        plt.axis('off')
    plt.subplots_adjust(bottom=0, top=1, left=0.1, right=0.9, hspace=0.1, wspace=0)
    plt.savefig(filename)


if __name__ == '__main__':
    # Demo: synthesize a blurred "degraded" image by bicubic down/up-sampling,
    # cut both images into 32x32 patches with stride 16, and save 16 random
    # patch pairs side by side as lr.png / hr.png.
    hr_img = np.array(sio.loadmat('lenna.mat')['img'])
    hr_h, hr_w = hr_img.shape[:2]
    down_factor = 4
    # cv2.resize takes dsize as (width, height); derive it from the image
    # instead of assuming a 512x512 input.
    lr_img = cv2.resize(
        cv2.resize(hr_img, (hr_w // down_factor, hr_h // down_factor),
                   interpolation=cv2.INTER_CUBIC),
        (hr_w, hr_h),
        interpolation=cv2.INTER_CUBIC)
    lr_patches, hr_patches = img_to_patch(lr_img, hr_img, 32, 16, 1)

    num_show = 16  # fills the 4x4 grid drawn by _save_patch_grid
    total = lr_patches.shape[0]
    # Sample distinct patches when there are enough of them.
    test_num = np.random.choice(total, num_show, replace=total < num_show)

    _save_patch_grid(lr_patches, test_num, 'lr.png')
    _save_patch_grid(hr_patches, test_num, 'hr.png')
    plt.show()