-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval_single.py
97 lines (73 loc) · 2.74 KB
/
eval_single.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import matplotlib
matplotlib.rcParams["pdf.fonttype"] = 42
matplotlib.rcParams["ps.fonttype"] = 42
from functools import partial
from pathlib import Path
import matplotlib.pyplot as plt
import torch
from data.images import CIFAR10_NAME, TINY_IMAGENET_NAME, ImagesDataset
from data.nets import NetsDataset
from models.decoder import Decoder
from models.encoder import Encoder
from models.resnet_fusedbn import ResNetFusedBN
from trainers.classification import ClassificationTrainer
from trainers.utils import progress_bar
# Evaluation device. Falls back to CPU so the script can still run on a
# machine without CUDA (previously hard-coded to "cuda").
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset selection. The commented-out lines here and at prep_size below are
# an alternative configuration; presumably CIFAR-10 pairs with
# prep_size = (8, 10000) — confirm before switching.
# dataset_name = CIFAR10_NAME
dataset_name = TINY_IMAGENET_NAME
dataset = ImagesDataset(dataset_name, batch_size=128)
_, _, test_loader = dataset.get_loaders()

# Accuracy evaluator bound to the test split and device; it is called later
# as eval_func(net, net_prep=...) to score a network under a given prep.
eval_func = partial(ClassificationTrainer.eval_accuracy, images_loader=test_loader, device=device)

# Placeholder paths: replace with real locations before running.
test_list = "/path/to/input/list"
ckpt_file = "/path/to/netspace/ckpt"

# Network architecture handed to the decoder.
# NOTE(review): the meaning of the positional args (0, 2, 8) is defined by
# ResNetFusedBN — confirm there.
out_net = ResNetFusedBN(0, 2, 8, dataset_name)
# prep_size = (8, 10000)
prep_size = (16, 10000)
emb_size = 4096

save_path = f"images/{dataset_name}/single.pdf"
Path(save_path).parent.mkdir(parents=True, exist_ok=True)

nets_test_dataset = NetsDataset(test_list, device, eval_func, prep_size)

# The checkpoint holds the encoder weights under key "0" and the decoder
# weights under key "1". map_location loads the tensors onto the selected
# device even if the checkpoint was saved on a different one.
ckpt = torch.load(ckpt_file, map_location=device)

enc = Encoder(emb_size=emb_size)
enc.load_state_dict(ckpt["0"])
enc.to(device)
enc.eval()

dec = Decoder([out_net], emb_size, prep_size)
dec.load_state_dict(ckpt["1"])
dec.to(device)
dec.eval()
# For every test network, record its reference score and the score obtained
# after encoding its prep and decoding it back (no gradients needed).
target_scores = []
predicted_scores = []
with torch.no_grad():
    for sample_idx in progress_bar(range(len(nets_test_dataset))):
        net, prep = nets_test_dataset[sample_idx]
        target_scores.append(net.score)

        # Add a leading batch dimension for the encoder, then take element 0
        # of the decoder output (presumably the single batch entry — confirm
        # against Decoder).
        batched_prep = prep.unsqueeze(0)
        reconstructed_prep = dec(enc(batched_prep))[0]

        predicted_scores.append(eval_func(net, net_prep=reconstructed_prep))
# Shared y-range covering both target and predicted scores.
# Taking min()/max() over the combined values replaces the former
# 1000.0 / 0.0 sentinels. Those gave a wrong range when every score was above
# 1000 or below 0, and an inverted range (997, 3) when there were no scores.
all_scores = [*target_scores, *predicted_scores]
min_score = min(all_scores, default=0.0)
max_score = max(all_scores, default=0.0)
# Per-instance comparison plot: target score (red circle) vs. predicted score
# (blue plus) for each test instance, joined by a faint dotted connector.
fig, ax = plt.subplots(figsize=(6, 3))
ax.set_xlabel("target instance id", fontsize=24)
ax.set_ylabel("accuracy", fontsize=24)
# NOTE(review): the fixed +/-3 padding assumes percentage-scale scores — confirm.
ax.set_ylim(min_score - 3, max_score + 3)
ax.grid(alpha=0.2)
ax.tick_params(axis="both", which="major", labelsize=15)

idx = list(range(len(target_scores)))
# zorder=2 keeps the markers drawn above the connector lines (zorder=1).
ax.scatter(idx, target_scores, c="r", marker="o", s=50, label="target", zorder=2)
ax.scatter(idx, predicted_scores, c="b", marker="+", s=50, label="predicted", zorder=2)
ax.legend(fontsize=18, loc="lower right", handletextpad=0.1)

for i, (target, predicted) in enumerate(zip(target_scores, predicted_scores)):
    ax.plot(
        [i, i],
        [target, predicted],
        linestyle=":",
        c="black",
        alpha=0.3,
        zorder=1,
    )

fig.savefig(save_path, bbox_inches="tight", dpi=600)
# Release the figure now that it has been written to disk.
plt.close(fig)