forked from XPixelGroup/BasicSR
-
Notifications
You must be signed in to change notification settings - Fork 0
/
sr_model.py
207 lines (174 loc) · 7.63 KB
/
sr_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import importlib
import torch
from collections import OrderedDict
from copy import deepcopy
from os import path as osp
from tqdm import tqdm
from basicsr.models.archs import define_network
from basicsr.models.base_model import BaseModel
from basicsr.utils import get_root_logger, imwrite, tensor2img
# Import the loss and metric packages once at module load. Loss classes and
# metric functions are later looked up in these modules by name with
# ``getattr``, using the ``type`` strings found in the options.
loss_module = importlib.import_module('basicsr.models.losses')
metric_module = importlib.import_module('basicsr.metrics')
class SRModel(BaseModel):
    """Base SR model for single image super-resolution.

    Builds the generator network ``net_g`` from ``opt['network_g']``,
    optionally loads pretrained weights, and (when training) sets up the
    losses, optimizer and schedulers. Attributes such as ``self.opt``,
    ``self.device``, ``self.is_train`` and ``self.optimizers`` are expected
    to be provided by ``BaseModel`` (not visible in this file).
    """

    def __init__(self, opt):
        """Create the network and, if training, the training machinery.

        Args:
            opt (dict): Options. Keys read here: ``network_g`` and
                ``path`` (``pretrain_network_g``, ``strict_load_g``).
        """
        super().__init__(opt)

        # define network (deep copy so network construction cannot alter opt)
        self.net_g = define_network(deepcopy(opt['network_g']))
        self.net_g = self.model_to_device(self.net_g)
        self.print_network(self.net_g)

        # load pretrained models
        load_path = self.opt['path'].get('pretrain_network_g', None)
        if load_path is not None:
            self.load_network(self.net_g, load_path,
                              self.opt['path'].get('strict_load_g', True))

        if self.is_train:
            self.init_training_settings()

    def _build_loss(self, loss_opt):
        """Instantiate a loss from ``loss_module`` described by ``loss_opt``.

        Args:
            loss_opt (dict | None): Loss options. ``type`` names the class in
                ``loss_module``; the remaining items are passed as keyword
                arguments to its constructor.

        Returns:
            The loss moved to ``self.device``, or ``None`` when ``loss_opt``
            is missing or empty.
        """
        if not loss_opt:
            return None
        # Work on a copy: popping 'type' from the shared options would make a
        # second model built from the same opt fail with KeyError.
        loss_opt = deepcopy(loss_opt)
        loss_cls = getattr(loss_module, loss_opt.pop('type'))
        return loss_cls(**loss_opt).to(self.device)

    def init_training_settings(self):
        """Put ``net_g`` in train mode and build losses/optimizer/schedulers.

        Reads ``pixel_opt`` and ``perceptual_opt`` from ``opt['train']``.

        Raises:
            ValueError: If neither a pixel nor a perceptual loss is
                configured.
        """
        self.net_g.train()
        train_opt = self.opt['train']

        # define losses
        self.cri_pix = self._build_loss(train_opt.get('pixel_opt'))
        self.cri_perceptual = self._build_loss(
            train_opt.get('perceptual_opt'))

        if self.cri_pix is None and self.cri_perceptual is None:
            raise ValueError('Both pixel and perceptual losses are None.')

        # set up optimizers and schedulers
        self.setup_optimizers()
        self.setup_schedulers()

    def setup_optimizers(self):
        """Create the generator optimizer from ``opt['train']['optim_g']``.

        Only parameters with ``requires_grad`` set are optimized; the others
        are reported with a warning. The optimizer is stored in
        ``self.optimizer_g`` and appended to ``self.optimizers``.

        Raises:
            NotImplementedError: If the optimizer type is not ``'Adam'``.
        """
        train_opt = self.opt['train']
        optim_params = []
        for k, v in self.net_g.named_parameters():
            if v.requires_grad:
                optim_params.append(v)
            else:
                logger = get_root_logger()
                logger.warning(f'Params {k} will not be optimized.')

        # Copy before popping so the shared options keep their 'type' entry.
        optim_opt = deepcopy(train_opt['optim_g'])
        optim_type = optim_opt.pop('type')
        if optim_type == 'Adam':
            self.optimizer_g = torch.optim.Adam(optim_params, **optim_opt)
        else:
            raise NotImplementedError(
                f'optimizer {optim_type} is not supported yet.')
        self.optimizers.append(self.optimizer_g)

    def feed_data(self, data):
        """Move one batch onto ``self.device``.

        Args:
            data (dict): Must contain ``'lq'``; ``'gt'`` is optional. When
                ``'gt'`` is absent, ``self.gt`` is left untouched.
        """
        self.lq = data['lq'].to(self.device)
        if 'gt' in data:
            self.gt = data['gt'].to(self.device)

    def optimize_parameters(self, current_iter):
        """Run one optimization step on the current batch.

        Computes the configured losses on ``net_g(self.lq)`` versus
        ``self.gt``, back-propagates their sum, steps the optimizer and
        stores the reduced losses in ``self.log_dict``.

        Args:
            current_iter (int): Current iteration (unused here).
        """
        self.optimizer_g.zero_grad()
        self.output = self.net_g(self.lq)

        l_total = 0
        loss_dict = OrderedDict()
        # pixel loss
        if self.cri_pix is not None:
            l_pix = self.cri_pix(self.output, self.gt)
            l_total += l_pix
            loss_dict['l_pix'] = l_pix
        # perceptual loss (returns a perceptual and a style term; either
        # may be None)
        if self.cri_perceptual is not None:
            l_percep, l_style = self.cri_perceptual(self.output, self.gt)
            if l_percep is not None:
                l_total += l_percep
                loss_dict['l_percep'] = l_percep
            if l_style is not None:
                l_total += l_style
                loss_dict['l_style'] = l_style

        l_total.backward()
        self.optimizer_g.step()

        self.log_dict = self.reduce_loss_dict(loss_dict)

    def test(self):
        """Run inference on ``self.lq`` and store it in ``self.output``.

        The network is switched to eval mode for the forward pass and put
        back in train mode afterwards, even if the forward pass raises.
        """
        self.net_g.eval()
        try:
            with torch.no_grad():
                # Keep only indices 0..2 along dim 1 (presumably channels).
                # Per the original author's note, the first stage is the
                # real output.
                self.output = self.net_g(self.lq)[:, 0:3, :, :]
        finally:
            self.net_g.train()

    def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
        """Validation entry point for distributed runs.

        Logs that only single-GPU validation is supported and delegates to
        :meth:`nondist_validation` with the same arguments.
        """
        logger = get_root_logger()
        logger.info('Only support single GPU validation.')
        self.nondist_validation(dataloader, current_iter, tb_logger, save_img)

    def nondist_validation(self, dataloader, current_iter, tb_logger,
                           save_img):
        """Validate on ``dataloader``, optionally saving images and metrics.

        Args:
            dataloader: Iterable of batches; each batch must contain
                ``'lq'`` and ``'lq_path'``. ``dataloader.dataset.opt['name']``
                is used as the dataset name.
            current_iter (int): Iteration used in file names and logging.
            tb_logger: Tensorboard-like logger or a falsy value.
            save_img (bool): Whether to write the result images to
                ``opt['path']['visualization']``.

        Raises:
            ValueError: If metrics are configured but a batch has no ``gt``.
        """
        dataset_name = dataloader.dataset.opt['name']
        metrics_opt = self.opt['val'].get('metrics')
        with_metrics = metrics_opt is not None

        # Resolve metric functions once instead of re-copying the config and
        # re-looking them up for every image.
        metric_calls = []
        if with_metrics:
            self.metric_results = {metric: 0 for metric in metrics_opt}
            for name, opt_ in deepcopy(metrics_opt).items():
                metric_type = opt_.pop('type')
                metric_calls.append(
                    (name, getattr(metric_module, metric_type), opt_))

        num_images = 0  # stays 0 for an empty dataloader
        with tqdm(total=len(dataloader), unit='image') as pbar:
            for num_images, val_data in enumerate(dataloader, start=1):
                img_name = osp.splitext(
                    osp.basename(val_data['lq_path'][0]))[0]
                self.feed_data(val_data)
                self.test()

                visuals = self.get_current_visuals()
                sr_img = tensor2img([visuals['result']])
                # Reset per image so a stale ground truth from a previous
                # batch can never be used for the current one.
                gt_img = None
                if 'gt' in visuals:
                    gt_img = tensor2img([visuals['gt']])
                    del self.gt

                # tentative for out of GPU memory
                del self.lq
                del self.output
                torch.cuda.empty_cache()

                if save_img:
                    vis_root = self.opt['path']['visualization']
                    if self.opt['is_train']:
                        save_img_path = osp.join(
                            vis_root, img_name,
                            f'{img_name}_{current_iter}.png')
                    else:
                        # Fall back to the experiment name when no suffix
                        # is configured (missing key or empty value).
                        suffix = (self.opt['val'].get('suffix')
                                  or self.opt['name'])
                        save_img_path = osp.join(
                            vis_root, dataset_name,
                            f'{img_name}_{suffix}.png')
                    imwrite(sr_img, save_img_path)

                if with_metrics:
                    if gt_img is None:
                        raise ValueError(
                            f'Metrics are requested but no gt is available '
                            f'for {img_name}.')
                    # calculate metrics
                    for name, metric_fn, metric_kwargs in metric_calls:
                        self.metric_results[name] += metric_fn(
                            sr_img, gt_img, **metric_kwargs)
                pbar.update(1)
                pbar.set_description(f'Test {img_name}')

        if with_metrics:
            if num_images > 0:
                for metric in self.metric_results:
                    self.metric_results[metric] /= num_images
            else:
                get_root_logger().warning(
                    f'Validation dataloader {dataset_name} is empty; '
                    'metrics are left at 0.')
            self._log_validation_metric_values(current_iter, dataset_name,
                                               tb_logger)

    def _log_validation_metric_values(self, current_iter, dataset_name,
                                      tb_logger):
        """Log ``self.metric_results`` to the root logger and ``tb_logger``.

        Args:
            current_iter (int): Step passed to ``tb_logger.add_scalar``.
            dataset_name (str): Name shown in the log header.
            tb_logger: Object with ``add_scalar`` or a falsy value to skip.
        """
        parts = [f'Validation {dataset_name}\n']
        parts.extend(f'\t # {metric}: {value:.4f}\n'
                     for metric, value in self.metric_results.items())
        logger = get_root_logger()
        logger.info(''.join(parts))
        if tb_logger:
            for metric, value in self.metric_results.items():
                tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)

    def get_current_visuals(self):
        """Return detached CPU copies of the current tensors.

        Returns:
            OrderedDict: ``'lq'`` and ``'result'``, plus ``'gt'`` when
            ``self.gt`` exists.
        """
        out_dict = OrderedDict()
        out_dict['lq'] = self.lq.detach().cpu()
        out_dict['result'] = self.output.detach().cpu()
        if hasattr(self, 'gt'):
            out_dict['gt'] = self.gt.detach().cpu()
        return out_dict

    def save(self, epoch, current_iter):
        """Save ``net_g`` (label ``'net_g'``) and the training state.

        Args:
            epoch (int): Current epoch, forwarded to the training state.
            current_iter (int): Current iteration.
        """
        self.save_network(self.net_g, 'net_g', current_iter)
        self.save_training_state(epoch, current_iter)