-
Notifications
You must be signed in to change notification settings - Fork 2
/
test_dnd_nam.py
92 lines (74 loc) · 2.78 KB
/
test_dnd_nam.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#!/usr/bin/env python
import os
import glob
import cv2
import argparse
import numpy as np
from models import *
from torchvision.utils import save_image, make_grid
parser = argparse.ArgumentParser()
parser.add_argument('--nGPU', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--inp', type=str, default='testsets/DND_20_rand_patches', help='input folder')
parser.add_argument('--out', type=str, default='results', help='output folder')
parser.add_argument('--JPEG', action='store_true', help="for JPEG images")
opt = parser.parse_args()
print(opt)

# NOTE(review): `torch` and `nn` are not imported explicitly in this file; they are
# presumably re-exported by `from models import *` -- confirm.

# Number of GPUs available. Use 0 for CPU mode.
ngpu = opt.nGPU
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")


def _load_weights(model, path, map_device, strict=True):
    """Load the state dict stored at *path* into *model*.

    Tensors are mapped onto *map_device* so checkpoints saved on a GPU can
    still be loaded in CPU mode. Models are wrapped in ``nn.DataParallel``
    only when running on CUDA, and that wrapper prefixes every state-dict key
    with ``'module.'``. The checkpoint keys are reconciled with whichever form
    *model* currently has, so the same file loads in both modes. Without this,
    a prefix mismatch makes ``strict=False`` silently load nothing.

    Args:
        model: the network (possibly wrapped in ``nn.DataParallel``).
        path: checkpoint file passed to ``torch.load``.
        map_device: device the loaded tensors are mapped to.
        strict: forwarded to ``load_state_dict``.
    """
    state = torch.load(path, map_location=map_device)
    prefix = 'module.'
    prefixed = [key.startswith(prefix) for key in state.keys()]
    if isinstance(model, nn.DataParallel):
        if prefixed and not any(prefixed):
            # Checkpoint was saved from the bare network: load into the wrapped module.
            model = model.module
    elif prefixed and all(prefixed):
        # Checkpoint was saved from a DataParallel wrapper: drop the prefix.
        state = {key[len(prefix):]: value for key, value in state.items()}
    model.load_state_dict(state, strict=strict)


def _read_image(path, map_device):
    """Read *path* with OpenCV and return a 1x3xHxW float32 tensor on *map_device*.

    OpenCV loads colour images in BGR order; the channels are reversed to RGB
    and the values are scaled by 1/255. Returns ``None`` when OpenCV cannot
    read the file (``cv2.imread`` signals failure by returning ``None``).
    """
    bgr = cv2.imread(path)
    if bgr is None:
        return None
    rgb = (bgr[:, :, ::-1] / 255.0).astype('float32')
    # HWC -> 1xCxHxW, copied into a contiguous buffer for torch.from_numpy.
    chw = np.ascontiguousarray(np.expand_dims(rgb.transpose(2, 0, 1), 0))
    return torch.from_numpy(chw).to(map_device)


# Create models
unetnd = UNet_ND().to(device)
unetd = UNet_D().to(device)
higan = HI_GAN().to(device)
# NOTE(review): the models are never switched to eval() mode. Confirm this is
# intended (e.g. dropout kept on at test time) and not an oversight.

# Load models
if opt.JPEG:
    unetnd_path = 'models/unet_nd_jpeg.pth'
    unetd_path = 'models/unet_d_jpeg.pth'
    higan_path = 'models/hi_gan_jpeg.pth'
else:
    unetnd_path = 'models/unet_nd.pth'
    unetd_path = 'models/unet_d.pth'
    higan_path = 'models/hi_gan.pth'

if (device.type == 'cuda') and (ngpu >= 1):
    # Never hand DataParallel more device ids than actually exist.
    gpu_ids = list(range(min(ngpu, torch.cuda.device_count())))
    unetnd = nn.DataParallel(unetnd, gpu_ids)
    unetd = nn.DataParallel(unetd, gpu_ids)
    higan = nn.DataParallel(higan, gpu_ids)

_load_weights(unetnd, unetnd_path, device, strict=False)
_load_weights(unetd, unetd_path, device, strict=False)
_load_weights(higan, higan_path, device)

# Denoise
print('\n> Test set', opt.inp)
types = ('*.bmp', '*.png', '*.jpg', '*.JPEG', '*.tif')
files = sorted(f for tp in types for f in glob.glob(os.path.join(opt.inp, tp)))
out_suffix = '_HIGAN_JPEG_denoi.png' if opt.JPEG else '_HIGAN_denoi.png'

for item in files:
    if device.type == 'cuda':
        torch.cuda.empty_cache()
    print("\tfile: %s" % item)
    img_folder = os.path.basename(os.path.dirname(item))
    img_name = os.path.splitext(os.path.basename(item))[0]

    # Read img
    imorig = _read_image(item, device)
    if imorig is None:
        print("\twarning: could not read %s, skipping" % item)
        continue

    with torch.no_grad():
        gt_dn = unetnd(imorig)
        gf_dn = unetd(imorig)
        higan_dn = higan(gf_dn, gt_dn)

    # Save into <out>/<input folder name>/, creating it on first use.
    save_img_dir = os.path.join(opt.out, img_folder)
    os.makedirs(save_img_dir, exist_ok=True)
    save_image(make_grid(higan_dn.clamp(0., 1.), nrow=8, normalize=False, scale_each=False),
               os.path.join(save_img_dir, img_name + out_suffix))