数据增强操作 (Data augmentation)

将图片分成小块 (Splitting an image into small patches)

# -*- coding: utf-8 -*-
"""
@Author  :   zhwzhong
@License :   (C) Copyright 2013-2018, hit
@Contact :   [email protected]
@Software:   PyCharm
@File    :   ImgToPatch.py
@Time    :   2019/1/16 15:32
@Desc    :
"""
import cv2
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt


def _patch_starts(length, patch_size, stride):
    """Return the start offsets of the patches taken along one image axis.

    Offsets step by ``stride`` from 0 and never let a patch run past ``length``
    (the last valid start, ``length - patch_size``, is included when the grid
    lands on it). When ``stride != patch_size`` an extra start flush with the
    far border is appended, if not already present, so the tail of the axis
    is always covered; that last patch may overlap its neighbour.
    Assumes ``length >= patch_size`` and ``stride > 0`` (checked by the caller).
    """
    last = length - patch_size
    starts = list(range(0, last + 1, stride))
    if patch_size != stride and starts[-1] != last:
        starts.append(last)
    return starts


def img_to_patch(degrade_img, gt_img, patch_size, stride, scale):
    """Cut a degraded image and its ground truth into aligned patches.

    Patches are taken from ``degrade_img`` on a grid with step ``stride``; the
    matching ground-truth patch is the region at ``scale`` times the same
    coordinates in ``gt_img``. When ``patch_size == stride`` the patches tile
    the image without overlap (a remainder narrower than ``patch_size`` at the
    bottom/right edge is dropped). Otherwise an extra row/column of patches
    flush with the bottom/right border is added so the whole image is covered.

    :param degrade_img: degraded image, indexed as (H, W) or (H, W, C)
    :param gt_img: ground-truth image, indexed as (H*scale, W*scale) or
        (H*scale, W*scale, C)
    :param patch_size: side length of a patch in the degraded image
    :param stride: step between neighbouring patch origins
    :param scale: 1 for same-size tasks (denoising, dehazing, ...);
        e.g. 4 for 4x super-resolution
    :return: ``(degrade_patches, gt_patches)`` with shapes
        ``(N, patch_size, patch_size[, C])`` and
        ``(N, patch_size*scale, patch_size*scale[, C])``, ordered row-major
        and keeping each input's dtype
    :raises ValueError: if ``patch_size`` or ``stride`` is not positive, the
        patch is larger than ``degrade_img``, or ``gt_img`` is too small for
        the requested ``scale``
    """
    if patch_size <= 0 or stride <= 0:
        raise ValueError('patch_size and stride must be positive, got {} and {}'.format(
            patch_size, stride))

    height, width = degrade_img.shape[0], degrade_img.shape[1]
    if height < patch_size or width < patch_size:
        raise ValueError('patch_size {} exceeds image size {}x{}'.format(
            patch_size, height, width))

    row_starts = _patch_starts(height, patch_size, stride)
    col_starts = _patch_starts(width, patch_size, stride)

    hr_size = patch_size * scale
    # Fail early with a clear message instead of an opaque broadcast error
    # when the ground truth does not reach the furthest patch.
    need_h = (row_starts[-1] + patch_size) * scale
    need_w = (col_starts[-1] + patch_size) * scale
    if gt_img.shape[0] < need_h or gt_img.shape[1] < need_w:
        raise ValueError('gt_img of size {}x{} is too small; need at least {}x{} for scale={}'.format(
            gt_img.shape[0], gt_img.shape[1], need_h, need_w, scale))

    num_patches = len(row_starts) * len(col_starts)
    # shape[2:] keeps any trailing channel axis, so 2-D grayscale input works too.
    degrade_patches = np.zeros(
        shape=(num_patches, patch_size, patch_size) + tuple(degrade_img.shape[2:]),
        dtype=degrade_img.dtype)
    gt_patches = np.zeros(
        shape=(num_patches, hr_size, hr_size) + tuple(gt_img.shape[2:]),
        dtype=gt_img.dtype)

    grid = ((up, left) for up in row_starts for left in col_starts)
    for num, (up, left) in enumerate(grid):
        degrade_patches[num] = degrade_img[up:up + patch_size, left:left + patch_size]
        hr_up, hr_left = up * scale, left * scale
        gt_patches[num] = gt_img[hr_up:hr_up + hr_size, hr_left:hr_left + hr_size]

    return degrade_patches, gt_patches


if __name__ == '__main__':
    # Demo: simulate a degraded copy of the ground truth by bicubic 4x down-
    # then up-sampling, cut both into 32x32 patches with stride 16, and save
    # up to 16 randomly chosen aligned patch pairs as lr.png / hr.png.
    hr_img = np.array(sio.loadmat('lenna.mat')['img'])
    hr_h, hr_w = hr_img.shape[:2]
    # cv2.resize takes dsize as (width, height). Restore the original size
    # rather than assuming a 512x512 input so lr_img stays aligned with hr_img.
    lr_img = cv2.resize(
        cv2.resize(hr_img, (hr_w // 4, hr_h // 4), interpolation=cv2.INTER_CUBIC),
        (hr_w, hr_h),
        interpolation=cv2.INTER_CUBIC)

    lr_patches, hr_patches = img_to_patch(lr_img, hr_img, 32, 16, 1)
    # Draw distinct indices so the grid never shows the same patch twice.
    num_shown = min(16, lr_patches.shape[0])
    test_num = np.random.choice(lr_patches.shape[0], size=num_shown, replace=False)

    # The same indices are used for both figures so LR/HR tiles correspond.
    for patches, out_path in ((lr_patches, 'lr.png'), (hr_patches, 'hr.png')):
        plt.figure()
        for k, idx in enumerate(test_num):
            plt.subplot(4, 4, k + 1)
            plt.imshow(patches[idx])
            plt.axis('off')  # hide axes
        plt.subplots_adjust(bottom=0, top=1, left=0.1, right=0.9, hspace=0.1, wspace=0)  # tighten subplot spacing
        plt.savefig(out_path)
    plt.show()

数据增强操作 (Data augmentation)