Spaces:
Build error
Build error
| import itertools | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import torch | |
| from torch.nn import functional as F | |
| # import cv2 | |
| import distutils.util | |
def show_result(num_epoch, G_net, imgs_lr, imgs_hr, save_dir="results/train_out"):
    """Save a side-by-side image of the generator output and the reference.

    Runs ``G_net(imgs_lr)`` without gradient tracking, then plots the first
    sample of the generator output (left) next to the first sample of
    ``imgs_hr`` (right). The figure is written to
    ``<save_dir>/epoch_<num_epoch>_results.png``.

    Args:
        num_epoch: Epoch number; used in the caption and in the file name.
        G_net: Generator, called as ``G_net(imgs_lr)``.
        imgs_lr: Input batch for the generator (tensor).
        imgs_hr: Reference batch shown on the right (tensor).
        save_dir: Output directory, created if missing. Defaults to the
            previously hard-coded ``"results/train_out"``.
    """
    import os

    with torch.no_grad():
        test_images = G_net(imgs_lr)

    def to_hwc(batch):
        # Take sample 0, map [-1, 1] -> [0, 1] and CHW -> HWC for imshow.
        # NOTE(review): assumes (B, C, H, W) tensors normalised to [-1, 1]
        # (implied by the "* 0.5 + 0.5") — confirm against the data loader.
        # Clipping does what imshow does to out-of-range floats, without the warning.
        image = batch.detach().cpu().numpy()[0] * 0.5 + 0.5
        return np.clip(np.transpose(image, [1, 2, 0]), 0, 1)

    fig, ax = plt.subplots(1, 2)
    try:
        for axis in ax:
            axis.get_xaxis().set_visible(False)
            axis.get_yaxis().set_visible(False)
        ax[0].imshow(to_hwc(test_images))
        ax[1].imshow(to_hwc(imgs_hr))
        fig.text(0.5, 0.04, 'Epoch {0}'.format(num_epoch), ha='center')
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(os.path.join(save_dir, "epoch_" + str(num_epoch) + "_results.png"))
    finally:
        # Always close this figure (and only this one) so repeated calls during
        # training do not accumulate open figures and leak memory.
        plt.close(fig)
#---------------------------------------------------------#
#   Convert images to RGB so that grayscale (or other
#   non-RGB) inputs do not break prediction. Only RGB images
#   are supported; everything else is converted to RGB.
#---------------------------------------------------------#
def cvtColor(image):
    """Return ``image`` as an RGB image.

    Objects exposing a ``mode`` attribute (PIL-style images) are returned
    unchanged when ``mode == 'RGB'`` and converted with ``.convert('RGB')``
    otherwise. Objects without ``mode`` (e.g. NumPy arrays) are returned
    unchanged when they are 3-D with 3 channels in the last axis; anything
    else is passed to ``.convert('RGB')``.
    """
    mode = getattr(image, 'mode', None)
    if mode is not None:
        # Trust the declared mode: 3-channel modes such as YCbCr/LAB/HSV pass a
        # channel-count check but are not RGB, and np.shape() on a PIL image
        # would copy the whole pixel buffer just to read its shape.
        return image if mode == 'RGB' else image.convert('RGB')
    if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
        return image
    return image.convert('RGB')
def preprocess_input(image, mean, std):
    """Normalise ``image``: scale by 1/255, subtract ``mean``, divide by ``std``.

    ``mean`` and ``std`` may be scalars or anything that broadcasts against
    ``image`` (e.g. per-channel arrays).
    """
    normalized = image / 255
    normalized = normalized - mean
    return normalized / std
def get_lr(optimizer):
    """Return the ``'lr'`` of the optimizer's first parameter group.

    Returns ``None`` when ``optimizer.param_groups`` is empty.
    """
    groups = iter(optimizer.param_groups)
    try:
        first_group = next(groups)
    except StopIteration:
        return None
    return first_group['lr']
def print_arguments(args):
    """Print every attribute of ``args`` as ``name: value``, sorted by name.

    The listing is framed by a header banner and a footer rule.
    """
    header = "----------- Configuration Arguments -----------"
    footer = "------------------------------------------------"
    print(header)
    options = vars(args)
    for name, val in sorted(options.items()):
        print(f"{name!s}: {val!s}")
    print(footer)
def add_arguments(argname, type, default, help, argparser, **kwargs):
    """Register a ``--<argname>`` option on ``argparser``.

    Args:
        argname: Option name without the leading ``--``.
        type: Conversion callable for argparse. ``bool`` is replaced by a
            string parser accepting y/yes/t/true/on/1 and n/no/f/false/off/0
            (case-insensitive), because ``bool("False")`` would be ``True``.
        default: Default value.
        help: Help text; a " 默认: <default>." suffix is appended.
        argparser: ``argparse.ArgumentParser`` to add the option to.
        **kwargs: Passed through to ``argparser.add_argument``.
    """
    def strtobool(value):
        # Local replacement for distutils.util.strtobool (distutils was removed
        # in Python 3.12). Same accepted spellings and ValueError on bad input;
        # returns a real bool (a subclass of int) instead of 1/0. The function
        # name is kept so argparse still reports "invalid strtobool value".
        lowered = value.lower()
        if lowered in ('y', 'yes', 't', 'true', 'on', '1'):
            return True
        if lowered in ('n', 'no', 'f', 'false', 'off', '0'):
            return False
        raise ValueError("invalid truth value %r" % (value,))

    if type is bool:
        type = strtobool
    argparser.add_argument("--" + argname,
                           default=default,
                           type=type,
                           help=help + ' 默认: %(default)s.',
                           **kwargs)
def filter2D(img, kernel):
    """PyTorch version of cv2.filter2D (cross-correlation with reflect padding).

    The image is padded by ``k // 2`` on every side with ``mode='reflect'``
    and then filtered without further padding, so the output keeps the
    input's spatial size.

    Args:
        img (Tensor): (b, c, h, w) batch.
        kernel (Tensor): (b, k, k) per-sample kernels, (1, k, k) to apply one
            kernel to the whole batch, or a bare (k, k) kernel (treated as
            (1, k, k)). ``k`` must be odd. Non-contiguous kernels are accepted.

    Returns:
        Tensor: (b, c, h, w) filtered batch.

    Raises:
        ValueError: if the kernel size ``k`` is even.
    """
    if kernel.dim() == 2:
        kernel = kernel.unsqueeze(0)
    k = kernel.size(-1)
    b, c, h, w = img.size()
    if k % 2 != 1:
        raise ValueError('Wrong kernel size')
    img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect')
    ph, pw = img.size()[-2:]

    if kernel.size(0) == 1:
        # Shared kernel: fold batch and channels into the conv batch axis and
        # run a single-channel convolution. reshape (not view) tolerates
        # non-contiguous kernels.
        img = img.reshape(b * c, 1, ph, pw)
        kernel = kernel.reshape(1, 1, k, k)
        return F.conv2d(img, kernel, padding=0).view(b, c, h, w)

    # Per-sample kernels: one conv group per (sample, channel) pair, with each
    # sample's kernel repeated across its c channels.
    img = img.reshape(1, b * c, ph, pw)
    kernel = kernel.reshape(b, 1, k, k).repeat(1, c, 1, 1).reshape(b * c, 1, k, k)
    return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w)
def _usm_gaussian_kernel_1d(ksize, sigma=0):
    """Return a normalised (ksize, 1) float64 Gaussian kernel.

    Follows the recipe of cv2.getGaussianKernel: with ``sigma <= 0`` the fixed
    tables are used for ksize 1/3/5/7, otherwise sigma is derived from ksize
    as ``0.3 * ((ksize - 1) * 0.5 - 1) + 0.8``.
    """
    fixed_tables = {
        1: [1.0],
        3: [0.25, 0.5, 0.25],
        5: [0.0625, 0.25, 0.375, 0.25, 0.0625],
        7: [0.03125, 0.109375, 0.21875, 0.28125, 0.21875, 0.109375, 0.03125],
    }
    if sigma <= 0 and ksize in fixed_tables:
        taps = np.asarray(fixed_tables[ksize], dtype=np.float64)
    else:
        if sigma <= 0:
            sigma = 0.3 * ((ksize - 1) * 0.5 - 1) + 0.8
        offsets = np.arange(ksize, dtype=np.float64) - (ksize - 1) * 0.5
        taps = np.exp(-0.5 * (offsets / sigma) ** 2)
    return (taps / taps.sum()).reshape(-1, 1)


def _usm_gaussian_blur(img, ksize):
    """Separable Gaussian blur over axes 0 and 1 of an HW / HWC array.

    Borders use reflect-without-edge-repeat (numpy 'reflect', i.e. OpenCV's
    default BORDER_REFLECT_101). Computed in float64 and cast back to the
    input dtype.
    """
    taps = _usm_gaussian_kernel_1d(ksize).ravel()
    pad = ksize // 2
    out = np.asarray(img, dtype=np.float64)
    for axis in (0, 1):
        moved = np.moveaxis(out, axis, 0)
        n = moved.shape[0]
        padded = np.pad(moved, [(pad, pad)] + [(0, 0)] * (moved.ndim - 1), mode='reflect')
        acc = np.zeros(moved.shape, dtype=np.float64)
        for offset, tap in enumerate(taps):
            acc += tap * padded[offset:offset + n]
        out = np.moveaxis(acc, 0, axis)
    return out.astype(np.asarray(img).dtype, copy=False)


def usm_sharp(img, weight=0.5, radius=50, threshold=10):
    """USM (unsharp mask) sharpening.

    Input image: I; Blurry image: B.
    1. sharp = I + weight * (I - B), clipped to [0, 1]
    2. Mask = 1 if abs(I - B) * 255 > threshold, else 0
    3. Blur the mask with the same Gaussian to get a soft mask
    4. Out = Mask * sharp + (1 - Mask) * I

    The Gaussian blur is done in NumPy (cv2 is not imported in this module).

    Args:
        img (Numpy array): Input image, HWC (or HW), float, range [0, 1].
        weight (float): Sharp weight. Default: 0.5.
        radius (int): Gaussian kernel size; an even value is bumped to the
            next odd one. Default: 50.
        threshold (float): Residual magnitude, in 0-255 units, above which a
            pixel is sharpened. Default: 10.

    Returns:
        Numpy array: sharpened image, same shape as ``img``.
    """
    if radius % 2 == 0:
        radius += 1
    blur = _usm_gaussian_blur(img, radius)
    residual = img - blur
    mask = (np.abs(residual) * 255 > threshold).astype('float32')
    soft_mask = _usm_gaussian_blur(mask, radius)
    sharp = np.clip(img + weight * residual, 0, 1)
    return soft_mask * sharp + (1 - soft_mask) * img
class USMSharp(torch.nn.Module):
    """Unsharp-mask (USM) sharpening of image batches as a torch module.

    The 2-D Gaussian kernel is built once in ``__init__`` with NumPy (cv2 is
    not imported in this module) and registered as the ``kernel`` buffer with
    shape (1, radius, radius).
    """

    def __init__(self, radius=50, sigma=0):
        """
        Args:
            radius (int): Gaussian kernel size; an even value is bumped to the
                next odd one.
            sigma (float): Gaussian std; ``<= 0`` derives it from ``radius``
                the way cv2.getGaussianKernel does.
        """
        super().__init__()
        if radius % 2 == 0:
            radius += 1
        self.radius = radius
        kernel_1d = self._gaussian_kernel_1d(radius, sigma)
        kernel = torch.from_numpy(np.dot(kernel_1d, kernel_1d.transpose())).float().unsqueeze(0)
        self.register_buffer('kernel', kernel)

    @staticmethod
    def _gaussian_kernel_1d(ksize, sigma):
        """Return a normalised (ksize, 1) float64 Gaussian kernel (cv2.getGaussianKernel recipe)."""
        # OpenCV uses these fixed tables for small odd sizes when sigma <= 0.
        fixed_tables = {
            1: [1.0],
            3: [0.25, 0.5, 0.25],
            5: [0.0625, 0.25, 0.375, 0.25, 0.0625],
            7: [0.03125, 0.109375, 0.21875, 0.28125, 0.21875, 0.109375, 0.03125],
        }
        if sigma <= 0 and ksize in fixed_tables:
            taps = np.asarray(fixed_tables[ksize], dtype=np.float64)
        else:
            if sigma <= 0:
                sigma = 0.3 * ((ksize - 1) * 0.5 - 1) + 0.8
            offsets = np.arange(ksize, dtype=np.float64) - (ksize - 1) * 0.5
            taps = np.exp(-0.5 * (offsets / sigma) ** 2)
        return (taps / taps.sum()).reshape(-1, 1)

    def forward(self, img, weight=0.5, threshold=10):
        """Sharpen ``img``.

        Args:
            img (Tensor): (b, c, h, w) batch; values presumably in [0, 1]
                (the result is clipped to that range) — TODO confirm.
            weight (float): Strength of the added residual (img - blur).
            threshold (float): Residual magnitude, in 0-255 units, above which
                a pixel is sharpened.

        Returns:
            Tensor: sharpened batch, same shape as ``img``.
        """
        blur = filter2D(img, self.kernel)
        residual = img - blur
        # Binary mask of high-residual pixels, softened with the same blur so
        # sharpening fades in instead of switching on abruptly.
        mask = (torch.abs(residual) * 255 > threshold).float()
        soft_mask = filter2D(mask, self.kernel)
        sharp = torch.clip(img + weight * residual, 0, 1)
        return soft_mask * sharp + (1 - soft_mask) * img
class USMSharp_npy():
    """NumPy unsharp-mask (USM) sharpening for single HWC (or HW) float images.

    ``self.kernel`` holds the (radius, radius) float32 Gaussian kernel. Filtering
    is done in NumPy (cv2 is not imported in this module) as a separable
    correlation with reflect-101 borders, OpenCV's default border mode.
    """

    def __init__(self, radius=50, sigma=0):
        """
        Args:
            radius (int): Gaussian kernel size; an even value is bumped to the
                next odd one.
            sigma (float): Gaussian std; ``<= 0`` derives it from ``radius``
                the way cv2.getGaussianKernel does.
        """
        super().__init__()
        if radius % 2 == 0:
            radius += 1
        self.radius = radius
        kernel = self._gaussian_kernel_1d(radius, sigma)
        self.kernel = np.dot(kernel, kernel.transpose()).astype(np.float32)

    @staticmethod
    def _gaussian_kernel_1d(ksize, sigma):
        """Return a normalised (ksize, 1) float64 Gaussian kernel (cv2.getGaussianKernel recipe)."""
        # OpenCV uses these fixed tables for small odd sizes when sigma <= 0.
        fixed_tables = {
            1: [1.0],
            3: [0.25, 0.5, 0.25],
            5: [0.0625, 0.25, 0.375, 0.25, 0.0625],
            7: [0.03125, 0.109375, 0.21875, 0.28125, 0.21875, 0.109375, 0.03125],
        }
        if sigma <= 0 and ksize in fixed_tables:
            taps = np.asarray(fixed_tables[ksize], dtype=np.float64)
        else:
            if sigma <= 0:
                sigma = 0.3 * ((ksize - 1) * 0.5 - 1) + 0.8
            offsets = np.arange(ksize, dtype=np.float64) - (ksize - 1) * 0.5
            taps = np.exp(-0.5 * (offsets / sigma) ** 2)
        return (taps / taps.sum()).reshape(-1, 1)

    @staticmethod
    def _correlate_axis(arr, taps, axis):
        """1-D correlation of ``arr`` with ``taps`` (centre anchor) along ``axis``, in float64."""
        pad = taps.size // 2
        moved = np.moveaxis(np.asarray(arr, dtype=np.float64), axis, 0)
        n = moved.shape[0]
        # numpy 'reflect' mirrors without repeating the edge (= BORDER_REFLECT_101).
        padded = np.pad(moved, [(pad, pad)] + [(0, 0)] * (moved.ndim - 1), mode='reflect')
        acc = np.zeros(moved.shape, dtype=np.float64)
        for offset, tap in enumerate(taps):
            acc += tap * padded[offset:offset + n]
        return np.moveaxis(acc, 0, axis)

    def _filter(self, arr):
        """Correlate ``arr`` with ``self.kernel`` over axes 0 and 1, keeping ``arr``'s dtype."""
        # self.kernel is an outer product g g^T, so its row/column sums recover
        # the 1-D factors. NOTE(review): assumes self.kernel stays separable
        # (rank 1) with a non-zero sum.
        k2d = np.asarray(self.kernel, dtype=np.float64)
        col_taps = k2d.sum(axis=1)
        row_taps = k2d.sum(axis=0) / k2d.sum()
        out = self._correlate_axis(arr, col_taps, 0)
        out = self._correlate_axis(out, row_taps, 1)
        return out.astype(np.asarray(arr).dtype, copy=False)

    def filt(self, img, weight=0.5, threshold=10):
        """Sharpen ``img``.

        Args:
            img (Numpy array): HWC (or HW) float image; values presumably in
                [0, 1] (the sharpened term is clipped to that range).
            weight (float): Strength of the added residual (img - blur).
            threshold (float): Residual magnitude, in 0-255 units, above which
                a pixel is sharpened.

        Returns:
            Numpy array: sharpened image, same shape as ``img``.
        """
        blur = self._filter(img)
        residual = img - blur
        # Binary mask of high-residual pixels, softened with the same kernel.
        mask = (np.abs(residual) * 255 > threshold).astype(np.float32)
        soft_mask = self._filter(mask)
        sharp = np.clip(img + weight * residual, 0, 1)
        return soft_mask * sharp + (1 - soft_mask) * img