run_LR.py
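# run_LR.py: feed images through the pre-trained upscaling network (USN) of
# the Content Adaptive Resampler and save the input/reconstruction pairs.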
import os
import argparse
from glob import glob

import numpy as np
from tqdm import tqdm
from PIL import Image
from skimage.color import rgb2ycbcr

import torch
import torch.nn as nn

import utils
from EDSR.edsr import EDSR
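# For the optional benchmark step further down; assumes scikit-image >= 0.16.
from skimage.metrics import peak_signal_noise_ratio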

parser = argparse.ArgumentParser(description='Content Adaptive Resampler for Image downscaling')
parser.add_argument('--model_dir', type=str, default='./models', help='path to the pre-trained model')
parser.add_argument('--img_dir', type=str, required=True, help='path to the LR images to be upscaled')
parser.add_argument('--scale', type=int, required=True, help='upscale factor the model was trained for')
parser.add_argument('--output_dir', type=str, required=True, help='path to store results')
# argparse's type=bool treats any non-empty string as True, so use paired flags instead.
parser.add_argument('--benchmark', dest='benchmark', action='store_true', help='report benchmark results (default)')
parser.add_argument('--no-benchmark', dest='benchmark', action='store_false', help='skip benchmark results')
parser.set_defaults(benchmark=True)
args = parser.parse_args()

SCALE = args.scale
BENCHMARK = args.benchmark

# EDSR-based upscaling network (32 residual blocks, 256 feature channels).
upscale_net = EDSR(32, 256, scale=SCALE).cuda()
# DataParallel wrapping matches the 'module.' prefix in the saved state dict.
upscale_net = nn.DataParallel(upscale_net, [0])
upscale_net.load_state_dict(torch.load(os.path.join(args.model_dir, '{0}x'.format(SCALE), 'usn.pth')))
torch.set_grad_enabled(False)

def validation(img, name, save_imgs=False, save_dir=None):
    upscale_net.eval()

    # Quantize the input to 8-bit levels before feeding the network, so it
    # sees the same values it would get from a saved PNG.
    quantized = torch.round(torch.clamp(img, 0, 1) * 255)
    reconstructed_img = upscale_net(quantized / 255.0)

    # Convert both tensors from NCHW floats to NHWC uint8 images.
    img = img * 255
    img = img.data.cpu().numpy().transpose(0, 2, 3, 1)
    img = np.uint8(img)

    reconstructed_img = torch.clamp(reconstructed_img, 0, 1) * 255
    reconstructed_img = reconstructed_img.data.cpu().numpy().transpose(0, 2, 3, 1)
    reconstructed_img = np.uint8(reconstructed_img)

    orig_img = img[0, ...].squeeze()
    recon_img = reconstructed_img[0, ...].squeeze()

    if save_imgs and save_dir:
        Image.fromarray(orig_img).save(os.path.join(save_dir, name + '_orig.png'))
        Image.fromarray(recon_img).save(os.path.join(save_dir, name + '_recon.png'))

    # Y-channel comparison with a SCALE-pixel border crop, the usual
    # convention for super-resolution benchmarks.
    orig_img_y = rgb2ycbcr(orig_img)[:, :, 0][SCALE:-SCALE, SCALE:-SCALE]
    recon_img_y = rgb2ycbcr(recon_img)[:, :, 0][SCALE:-SCALE, SCALE:-SCALE]
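    # A minimal benchmark step, assuming Y-channel PSNR is the intended
    # metric. The upscaled output is SCALE x the size of its input, so the
    # score is only computable when the two images happen to share a shape
    # (e.g. when img_dir holds matching references); otherwise return None.
    if BENCHMARK and orig_img_y.shape == recon_img_y.shape:
        return peak_signal_noise_ratio(orig_img_y, recon_img_y, data_range=255)
    return None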

if __name__ == '__main__':
    img_list = glob(os.path.join(args.img_dir, '**', '*.png'), recursive=True)
    assert len(img_list) > 0, 'no PNG images found under {0}'.format(args.img_dir)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    psnr_list = []
    for img_file in tqdm(img_list):
        name = os.path.splitext(os.path.basename(img_file))[0]
        img_test = utils.load_img(img_file)
        psnr = validation(img_test, name, save_imgs=True, save_dir=args.output_dir)
        if psnr is not None:
            psnr_list.append(psnr)
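    # Summarize the optional benchmark; printed only when at least one PSNR
    # was computable (see the shape check in validation()).
    if psnr_list:
        print('Mean Y-PSNR: {0:.2f} dB over {1} image(s)'.format(np.mean(psnr_list), len(psnr_list)))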
    print('Done!')
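
# Example invocation (paths are illustrative):
#   python run_LR.py --scale 4 --img_dir ./images_LR --output_dir ./results_LR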