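"""Evaluate a fine-tuned text classifier: test-set accuracy (--eval_type acc)
or out-of-distribution (OOD) detection (--eval_type ood, the default).

Sketch of an invocation; the dataset names and paths are placeholders, and
only --ood_method flats is implemented below:

    python test.py --dataset IND_DATASET --ood_datasets OOD_A,OOD_B \
        --ood_method flats --pretrained_model ./ckpt/model.pt \
        --input_dir ./features --log_file ./log/eval.log
"""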
import os
import torch
import argparse
import numpy as np
from loguru import logger
from lib.data_loader import get_data_loader
from lib.training.common import test_acc
from lib.models.networks import get_model, get_tokenizer
from lib.inference.base import (get_base_score, get_energy_score, get_d2u_score,
                                get_maha_score, get_iflp_score, get_knn_score,
                                get_km_score, get_maxlogit_score, get_pout_score,
                                get_out_score)
from lib.inference.godin import searchGeneralizedOdinParameters, get_ODIN_score
from lib.inference.lof import get_lof_score
from lib.inference.dropout import get_dropout_score
from lib.metrics import get_metrics
from lib.exp import get_num_labels


def main():
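    """Parse CLI arguments, load the model and data, then run the evaluation."""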
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', help='in-distribution (training) dataset')
    parser.add_argument('--eval_type', default='ood',
                        type=str, choices=['acc', 'ood'])
    parser.add_argument('--ood_method', default='base', type=str)
    parser.add_argument('--ood_datasets', type=str, required=False,
                        help='comma-separated list of OOD datasets')
    parser.add_argument('--batch_size', default=32, type=int,
                        required=False, help='batch size')
    parser.add_argument('--model', default='roberta-base',
                        help='pretrained model type')
    parser.add_argument('--pretrained_model', default=None, type=str,
                        required=False, help='path of the checkpoint to load')
    parser.add_argument('--log_file', type=str, default='./log/default.log')
    parser.add_argument('--input_dir', type=str, default=None)
    args = parser.parse_args()
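    # loguru logs to stderr by default; add the configured file sink as well.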
    log_file_name = args.log_file
    logger.add(log_file_name)
    logger.info('args:\n' + repr(args))

    num_labels = get_num_labels(args.dataset)
    args.num_labels = num_labels

    model = get_model(args)
    logger.info("{} model loaded".format(args.model))
    if args.pretrained_model:
        # map_location lets checkpoints saved on GPU load on CPU-only machines.
        model.load_state_dict(
            torch.load(args.pretrained_model, map_location='cpu'))
        logger.info("model loaded from {}".format(args.pretrained_model))
    tokenizer = get_tokenizer(args.model)
    if tokenizer.pad_token_id is None:
        # Some tokenizers (e.g. GPT-2's) define no pad token; reuse EOS instead.
        tokenizer.pad_token = tokenizer.eos_token
    logger.info("{} tokenizer loaded".format(args.model))
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model.to(device)
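    # Build test loaders for the in-distribution set, each OOD set, and the
    # 'wiki' corpus consumed by get_out_score below.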
    ood_datasets = args.ood_datasets.split(',') if args.ood_datasets else []
    ind_test_loader = get_data_loader(
        args.dataset, 'test', tokenizer, args.batch_size)
    ood_test_loaders = [get_data_loader(
        ood_dataset, 'test', tokenizer, args.batch_size) for ood_dataset in ood_datasets]
    wiki_loader = get_data_loader('wiki', 'test', tokenizer, args.batch_size)

    if args.eval_type == 'acc':  # Classification Validation
        acc = test_acc(model, ind_test_loader)
        logger.info("Test Accuracy on {} test set: {:.4f}".format(
            args.dataset, acc))
    else:  # OOD Detection Validation
        ood_scores_list = []
        ood_metrics_list = []
        if args.ood_method == 'flats':
            ind_scores, ood_scores_list = get_knn_score(
                args.ood_datasets, args.input_dir)
        else:
            raise NotImplementedError
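        # Combine scores: subtract half of the wiki-based get_out_score output
        # from both the in-distribution and the OOD scores.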
        _ind_scores, _ood_scores_list = get_out_score(
            model, wiki_loader, args.ood_datasets, args.input_dir)
        ind_scores = np.array(
            [score1 - score2 * 0.5 for score1, score2 in zip(ind_scores, _ind_scores)])
        for i in range(len(ood_scores_list)):
            ood_scores_list[i] = np.array(
                [score1 - score2 * 0.5
                 for score1, score2 in zip(ood_scores_list[i], _ood_scores_list[i])])
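        # Report detection metrics for each OOD dataset individually.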
        for i, ood_dataset in enumerate(ood_datasets):
            logger.info("OOD: {}".format(ood_dataset))
            metrics = get_metrics(ind_scores, ood_scores_list[i])
            ood_metrics_list.append(metrics)
            logger.info(str(metrics))

        # Average each metric over all OOD datasets (every metrics dict has
        # the same keys, so iterate over the first one).
        mean_metrics = {}
        for k in ood_metrics_list[0]:
            mean_metrics[k] = sum(
                m[k] for m in ood_metrics_list) / len(ood_metrics_list)
        logger.info('mean metrics: {}'.format(mean_metrics))
    logger.info('evaluation finished')


if __name__ == '__main__':
    main()