-
Notifications
You must be signed in to change notification settings - Fork 1
/
filewise_export_stats.py
85 lines (63 loc) · 2.51 KB
/
filewise_export_stats.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
import json
import torch
import sys
from skimage import io
from losses.psnr import psnr
from pytorch_msssim import ms_ssim, ssim
from serialization import deserialize_state_dict
from utils import calculate_state_dict_size, load_device
def ms_ssim_reshape(tensor):
    """Turn a channels-last image tensor into a channels-first batch of one.

    An ``(H, W, C)`` tensor becomes ``(1, C, H, W)``, which is the
    ``(N, C, H, W)`` layout the ``pytorch_msssim`` metrics take.
    """
    # Move the trailing (channel) axis to the front...
    channels_first = tensor.movedim(-1, 0)
    # ...then add a leading batch axis of size one.
    return channels_first[None]
def _load_image_tensor(path, device):
    """Read an image file with skimage and return it as a float32 tensor on *device*."""
    return torch.from_numpy(io.imread(path)).to(device=device, dtype=torch.float32)


def _similarity_metrics(original, reconstructed):
    """Return ``(ssim_value, ms_ssim_value)`` for two image tensors.

    Both tensors are passed through ``ms_ssim_reshape`` first, so they are
    expected channels-last. A metric that raises is reported on stdout and
    returned as ``None`` so the remaining stats can still be written.
    """
    # NOTE(review): pytorch_msssim's default data_range (255) is used, which
    # presumes 8-bit pixel values — confirm for other bit depths.
    try:
        ssim_value = ssim(ms_ssim_reshape(original), ms_ssim_reshape(reconstructed)).item()
    except Exception as e:
        print(f"Cannot calculate SSIM: {e}")
        ssim_value = None
    try:
        ms_ssim_value = ms_ssim(ms_ssim_reshape(original), ms_ssim_reshape(reconstructed)).item()
    except Exception as e:
        print(f"Cannot calculate MS-SSIM: {e}")
        ms_ssim_value = None
    return ssim_value, ms_ssim_value


def _compressed_state_size(compressed_file_path):
    """Return ``calculate_state_dict_size`` of the compressed file's state dict, or ``None``.

    The file is read with the project's ``deserialize_state_dict``; if that
    raises or yields ``None`` it is retried with ``torch.load``. Failures are
    reported on stdout instead of being raised.
    """
    try:
        state_dict = deserialize_state_dict(compressed_file_path)
    except Exception as e:
        print(f"WARNING: Cannot deserialize compressed state dict: {e}")
        state_dict = None
    try:
        if state_dict is None:
            # NOTE(review): depending on the torch version's weights_only default,
            # torch.load may unpickle arbitrary objects — only pass trusted files.
            state_dict = torch.load(compressed_file_path)
        return calculate_state_dict_size(state_dict)
    except Exception as e:
        print(f"WARNING: Cannot calculate state-only bpp: {e}")
        return None


def main(argv=None):
    """Compute quality/size statistics for one compressed image and write them as JSON.

    Positional arguments (``argv`` defaults to ``sys.argv[1:]``):
        1. path of the original image
        2. path of the reconstructed image
        3. path of the JSON file the stats are written to
        4. path of the compressed file

    Written keys: ``psnr``; ``ms-ssim`` and ``ssim`` (``None`` when they cannot
    be computed); ``bpp`` (compressed file size in bits per pixel);
    ``state_bpp`` (state-dict size in bits per pixel, ``0`` when the size
    cannot be computed); and ``state_size`` when the size could be computed.
    """
    if argv is None:
        argv = sys.argv[1:]
    if len(argv) < 4:
        # Fail with a readable message instead of an IndexError on argv.
        sys.exit(
            f"usage: {os.path.basename(sys.argv[0])} "
            "ORIGINAL_IMAGE RECONSTRUCTED_IMAGE STATS_JSON COMPRESSED_FILE"
        )

    print("Loading device...")
    device = load_device(True)
    print("Loading parameters...")
    original_file_path, reconstructed_file_path, stats_path, compressed_file_path = argv[:4]

    print("Calculating compressed state size...")
    compressed_file_size = os.stat(compressed_file_path).st_size

    print("Loading images...")
    original_image_tensor = _load_image_tensor(original_file_path, device)
    reconstructed_image_tensor = _load_image_tensor(reconstructed_file_path, device)
    # Count pixels as height * width so the channel axis (present or not, 3 or
    # 4 channels) does not skew bits-per-pixel. Equals nelement() / 3 for RGB.
    pixels = original_image_tensor.shape[0] * original_image_tensor.shape[1]

    print("Calculating stats...")
    ssim_value, ms_ssim_value = _similarity_metrics(original_image_tensor, reconstructed_image_tensor)
    compressed_state_size = _compressed_state_size(compressed_file_path)
    # A failed size calculation is reported as 0 bits per pixel.
    state_size_for_bpp = compressed_state_size if compressed_state_size is not None else 0

    stats = {
        "psnr": psnr(original_image_tensor, reconstructed_image_tensor).item(),
        "ms-ssim": ms_ssim_value,
        "ssim": ssim_value,
        "bpp": (compressed_file_size * 8) / pixels,
        "state_bpp": (state_size_for_bpp * 8) / pixels,
    }
    # Reuse the size computed above rather than loading the file a second time.
    if compressed_state_size is not None:
        stats["state_size"] = compressed_state_size

    print(stats)
    # Context manager guarantees the JSON file is flushed and closed.
    with open(stats_path, "w", encoding="utf-8") as stats_file:
        json.dump(stats, stats_file)


if __name__ == "__main__":
    main()