-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils.py
63 lines (50 loc) · 2.1 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from sklearn.metrics import mean_squared_error
from skimage.metrics._structural_similarity import structural_similarity as ssim
# from skimage.measure import compare_ssim as ssim
import torch
import numpy as np
import math
import lpips
def MSE(y_input, y_target):
N, C, H, W = y_input.shape
assert(C == 1 or C == 3)
sum_mse_over_batch = 0.
for i in range(N):
sum_mse_over_batch += mean_squared_error(
y_input[i, 0, :, :][~np.isnan(y_target[i, 0, :, :])], y_target[i, 0, :, :][~np.isnan(y_target[i, 0, :, :])])
if C == 3: # color
sum_mse_over_batch += mean_squared_error(
y_input[i, 1, :, :][~np.isnan(y_target[i, 1, :, :])], y_target[i, 1, :, :][~np.isnan(y_target[i, 1, :, :])])
sum_mse_over_batch += mean_squared_error(
y_input[i, 2, :, :][~np.isnan(y_target[i, 2, :, :])], y_target[i, 2, :, :][~np.isnan(y_target[i, 2, :, :])])
mean_mse = sum_mse_over_batch / (float(N))
if C == 3:
mean_mse /= 3.0
return mean_mse
def SSIM(y_input, y_target):
N, C, H, W = y_input.shape
assert(C == 1 or C == 3)
# N x C x H x W -> N x W x H x C -> N x H x W x C
y_input = np.swapaxes(y_input, 1, 3)
y_input = np.swapaxes(y_input, 1, 2)
y_target = np.swapaxes(y_target, 1, 3)
y_target = np.swapaxes(y_target, 1, 2)
sum_structural_similarity_over_batch = 0.
for i in range(N):
if C == 3:
sum_structural_similarity_over_batch += ssim(
y_input[i, :, :, :], y_target[i, :, :, :], multichannel=True)
else:
sum_structural_similarity_over_batch += ssim(
y_input[i, :, :, 0], y_target[i, :, :, 0])
return sum_structural_similarity_over_batch / float(N)
def PSNR(y_input, y_target):
mse_output = np.mean( (y_input - y_target) ** 2 )
if mse_output == 0:
return 100
PIXEL_MAX = 1.0
return 20 * math.log10(PIXEL_MAX / math.sqrt(mse_output))
def LPIPS(y_input, y_target):
loss_fn_alex = lpips.LPIPS(net='alex', verbose=False)
lpips_output = loss_fn_alex(y_input, y_target)
return lpips_output