-
Notifications
You must be signed in to change notification settings - Fork 9
/
metric_counter.py
56 lines (45 loc) · 2.08 KB
/
metric_counter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import logging
from collections import defaultdict
import numpy as np
from tensorboardX import SummaryWriter
WINDOW_SIZE = 100  # number of most recent values averaged by MetricCounter.loss_message


class MetricCounter:
    """Accumulate losses, quality metrics and images for an experiment.

    Values are collected in per-name lists and later averaged for console
    messages (``loss_message``), TensorBoard logging (``write_to_tensorboard``)
    and best-model tracking (``update_best_model``).
    """

    def __init__(self, exp_name):
        """Create the TensorBoard writer and configure file logging.

        Args:
            exp_name: Experiment name. Passed as the first positional argument
                of ``SummaryWriter`` (its log directory) and used to build the
                ``'<exp_name>.log'`` log-file name.
        """
        self.writer = SummaryWriter(exp_name)
        # logging.basicConfig is a no-op when the root logger already has
        # handlers, so only the first configuration in a process takes effect.
        logging.basicConfig(filename='{}.log'.format(exp_name), level=logging.DEBUG)
        self.metrics = defaultdict(list)  # metric name -> list of recorded values
        self.images = defaultdict(list)   # tag -> list of images queued for TensorBoard
        self.best_metric = 0              # best mean PSNR seen by update_best_model

    @staticmethod
    def _mean(values):
        """Return the mean of *values*, or NaN without a warning if it is empty."""
        # np.mean([]) also gives NaN, but emits "Mean of empty slice" RuntimeWarning.
        return np.mean(values) if values else np.nan

    def add_image(self, x: np.ndarray, tag: str):
        """Queue image *x* under *tag* for the next ``write_to_tensorboard`` call."""
        self.images[tag].append(x)

    def clear(self):
        """Discard all recorded metrics and queued images (``best_metric`` is kept)."""
        self.metrics = defaultdict(list)
        self.images = defaultdict(list)

    def add_losses(self, l_G, l_content, l_D=0):
        """Record one step of generator/discriminator losses.

        The adversarial part of the generator loss is derived as
        ``l_G - l_content`` and stored under ``'G_loss_adv'``.
        """
        for name, value in zip(('G_loss', 'G_loss_content', 'G_loss_adv', 'D_loss'),
                               (l_G, l_content, l_G - l_content, l_D)):
            self.metrics[name].append(value)

    def add_metrics(self, psnr, ssim):
        """Record one PSNR value and one SSIM value."""
        for name, value in zip(('PSNR', 'SSIM'), (psnr, ssim)):
            self.metrics[name].append(value)

    def loss_message(self):
        """Return a one-line summary of recent G_loss, PSNR and SSIM.

        Each value is the mean of the last ``WINDOW_SIZE`` recorded entries,
        formatted with four decimals; a metric with no entries shows ``nan``.
        """
        parts = []
        for key in ('G_loss', 'PSNR', 'SSIM'):
            # .get avoids inserting empty lists into the defaultdict.
            recent = self.metrics.get(key, [])[-WINDOW_SIZE:]
            parts.append(f'{key}={self._mean(recent):.4f}')
        return '; '.join(parts)

    def write_to_tensorboard(self, epoch_num, validation=False):
        """Log per-tag scalar means and flush queued images to TensorBoard.

        Scalars are named ``'Train_<tag>'`` or ``'Validation_<tag>'``. Tags with
        no recorded values are skipped rather than logged as NaN. Queued images
        are written and then dropped; recorded metrics are left intact (call
        ``clear`` to reset them).

        Args:
            epoch_num: Value passed as ``global_step`` for every entry written.
            validation: Selects the ``'Validation'`` instead of ``'Train'`` prefix.
        """
        scalar_prefix = 'Validation' if validation else 'Train'
        for tag in ('G_loss', 'D_loss', 'G_loss_adv', 'G_loss_content', 'SSIM', 'PSNR'):
            values = self.metrics.get(tag, [])
            if not values:
                continue
            self.writer.add_scalar(f'{scalar_prefix}_{tag}', np.mean(values), global_step=epoch_num)
        for tag in self.images:
            imgs = self.images[tag]
            if not imgs:
                continue
            batch = np.array(imgs)
            # Reverses the last (channel) axis -- presumably BGR -> RGB, confirm
            # with the caller -- and rescales assuming 0-255 pixel values.
            self.writer.add_images(tag, batch[:, :, :, ::-1].astype('float32') / 255,
                                   dataformats='NHWC', global_step=epoch_num)
            self.images[tag] = []

    def update_best_model(self):
        """Check whether the mean recorded PSNR beats the best value so far.

        Returns:
            True if the mean of all recorded PSNR values is greater than
            ``best_metric`` (which is then updated); False otherwise, including
            when no PSNR has been recorded.
        """
        cur_metric = self._mean(self.metrics.get('PSNR', []))
        if self.best_metric < cur_metric:
            self.best_metric = cur_metric
            return True
        return False