Spaces:
Build error
Build error
| import cv2 | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms | |
| from esrgan import ESRGAN | |
# Module-level ESRGAN wrapper, built once at import time; main() reads its
# `net` attribute to get the network whose Conv2d inputs are visualized.
esrgan = ESRGAN()
def viz(module, input):
    """Forward pre-hook that plots up to four feature maps of a layer's input.

    Args:
        module: The layer the hook is attached to (unused).
        input: Tuple of positional inputs passed to the layer's forward call.
            Only ``input[0][0]`` is used -- presumably the first sample of an
            (N, C, H, W) batch, so iterating it walks over channels.
    """
    sample = input[0][0]
    # Show at most 4 maps, one per cell of a 1x4 subplot grid.
    shown = min(4, sample.size(0))
    for slot, feature_map in enumerate(sample[:shown], start=1):
        plt.subplot(1, 4, slot)
        plt.imshow(feature_map.cpu())
        plt.xticks([])  # hide x-axis tick values
        plt.yticks([])  # hide y-axis tick values
    plt.show()
def main():
    """Visualize the inputs of every Conv2d layer of ``esrgan.net`` for one image.

    Reads ``image.png`` from the current working directory, preprocesses it
    (resize to 224x224, convert to tensor, per-channel normalization),
    registers ``viz`` as a forward pre-hook on every ``torch.nn.Conv2d`` in
    the network, and runs a single forward pass with gradients disabled so
    that each hook fires once. The hooks are removed again afterwards.

    Raises:
        FileNotFoundError: If ``image.png`` is missing or cannot be decoded.
    """
    image_path = 'image.png'

    # Load the image first so a bad path fails before the model is touched.
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image file: {image_path}")
    # OpenCV decodes to BGR, while ToPILImage and the per-channel
    # normalization constants below assume RGB channel order.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    preprocess = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # NOTE(review): the network is not moved to `device` here; this assumes
    # ESRGAN() already placed `net` on the same device -- confirm.
    model = esrgan.net

    # Only convolution layers are hooked, so only their input feature maps
    # are displayed. Other layer types could be hooked by widening this check.
    handles = [
        m.register_forward_pre_hook(viz)
        for m in model.modules()
        if isinstance(m, torch.nn.Conv2d)
    ]

    try:
        batch = preprocess(img).unsqueeze(0).to(device)
        with torch.no_grad():
            model(batch)
    finally:
        # Detach the hooks so they do not pile up on the shared module-level
        # model when main() is called more than once.
        for handle in handles:
            handle.remove()
# Run the visualization only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()