-
Notifications
You must be signed in to change notification settings - Fork 2
/
qat_evaluate.py
50 lines (32 loc) · 1.28 KB
/
qat_evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import util.model_util as model_util
import util.dataset_util as dataset_util
import torch
from model import quantizable_resnet18
from hyperparameters import DEVICE
# Location of the saved quantized checkpoint under the models/ directory.
# A forward slash is used so the path resolves on Windows and POSIX alike
# (a backslash is an ordinary filename character on POSIX).
# NOTE(review): this name is not referenced in this file, and the model below is
# loaded from "quantized_resnet18.pth" instead — confirm which path is intended.
model_path = 'models/quantized_resnet18.pth'
def load_model(model_class, path, map_location="cpu"):
    """Build a model, load saved weights into it, and switch it to eval mode.

    Args:
        model_class: zero-argument callable returning an ``nn.Module`` (a class
            or a factory such as a lambda).
        path: path to a saved state dict, e.g. one written with
            ``torch.save(model.state_dict(), path)``.
        map_location: forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint saved from a CUDA device can still be read on a CPU-only
            machine. ``load_state_dict`` copies the values into the model's own
            parameters, so move the returned model with ``.to(device)`` as needed.

    Returns:
        The freshly constructed model with the loaded weights, in eval mode.
    """
    model = model_class()
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(path, map_location=map_location)
    model.load_state_dict(state_dict)
    model.eval()
    return model
# Rebuild the quantizable ResNet-18 architecture and load the saved quantized
# weights into it.
# NOTE(review): the positional args (10, True) are passed straight through to
# quantizable_resnet18; 10 is presumably the number of classes — confirm.
quantized_model_path = "quantized_resnet18.pth"


def _make_quantized_resnet18():
    """Return a new ``quantizable_resnet18(10, True)`` for load_model to populate."""
    return quantizable_resnet18(10, True)


loaded_model = load_model(_make_quantized_resnet18, quantized_model_path)
# Evaluate a model's top-1 accuracy on a test data loader.
def evaluate_model(model, data_loader, device=None):
    """Compute, print, and return the top-1 accuracy of *model* on *data_loader*.

    Args:
        model: module whose forward pass returns a pair; the first element is
            used as per-class scores and the second is ignored.
            NOTE(review): presumably (logits, features) — confirm against
            quantizable_resnet18.
        data_loader: iterable of ``(images, labels)`` batches.
        device: device each batch is moved to. Defaults to the module-level
            ``DEVICE`` when ``None``.

    Returns:
        Accuracy as a percentage (``100 * correct / total``).

    Raises:
        ValueError: if *data_loader* yields no samples. Previously this
            surfaced as an unexplained ZeroDivisionError.
    """
    if device is None:
        device = DEVICE
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            outputs, _ = model(images)
            # Index of the highest score per row; replaces torch.max(outputs.data, 1)[1].
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("data_loader yielded no samples; cannot compute accuracy")
    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')
    return accuracy
# Place the model on the configured device; nn.Module.to moves it in place and
# returns the same module, so rebinding keeps the identical object.
loaded_model = loaded_model.to(DEVICE)
# dataset_load returns four objects; only the test loader is used below.
(train_dataset, test_dataset,
 train_loader, test_loader) = dataset_util.dataset_load()
# Report the model's accuracy on the test loader.
evaluate_model(loaded_model, test_loader)