forked from yhenon/pytorch-retinanet
-
Notifications
You must be signed in to change notification settings - Fork 2
/
coco_eval.py
87 lines (64 loc) · 2.64 KB
/
coco_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from __future__ import print_function
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
import numpy as np
import json
import os
import torch
def evaluate_coco(dataset, model, threshold=0.05):
    """Run COCO-style bbox evaluation of ``model`` over ``dataset``.

    Collects all detections with score >= ``threshold``, writes them to
    ``'<set_name>_bbox_results.json'`` in MS COCO result format, then runs
    pycocotools' COCOeval and prints the standard AP/AR summary.

    Args:
        dataset: COCO-backed dataset exposing ``__len__``/``__getitem__``
            (items are dicts with an HxWxC ``'img'`` tensor and a ``'scale'``
            factor), plus ``.image_ids``, ``.set_name``, ``.coco`` and
            ``.label_to_coco_label()``.
        model: detector returning ``(scores, labels, boxes)`` for a 1xCxHxW
            CUDA float tensor; switched to eval mode for the duration and
            restored to train mode before returning.
        threshold: minimum detection score to keep. Scores are assumed
            sorted descending per image, so iteration stops at the first
            score below it.

    Returns:
        None. Results are printed and written to disk as side effects.
    """
    model.eval()

    with torch.no_grad():
        results = []
        image_ids = []

        for index in range(len(dataset)):
            data = dataset[index]
            scale = data['scale']

            # Run network: HxWxC image -> 1xCxHxW batch on the GPU.
            scores, labels, boxes = model(
                data['img'].permute(2, 0, 1).cuda().float().unsqueeze(dim=0))
            scores = scores.cpu()
            labels = labels.cpu()
            boxes = boxes.cpu()

            # Undo the resize applied by the dataset transform.
            boxes /= scale

            if boxes.shape[0] > 0:
                # Convert (x1, y1, x2, y2) to (x, y, w, h) — MS COCO standard.
                boxes[:, 2] -= boxes[:, 0]
                boxes[:, 3] -= boxes[:, 1]

                for box_id in range(boxes.shape[0]):
                    score = float(scores[box_id])
                    label = int(labels[box_id])
                    box = boxes[box_id, :]

                    # Scores are sorted descending, so we can stop early.
                    if score < threshold:
                        break

                    # One result entry per positively scored detection.
                    results.append({
                        'image_id': dataset.image_ids[index],
                        'category_id': dataset.label_to_coco_label(label),
                        'score': float(score),
                        'bbox': box.tolist(),
                    })

            # Record the image as processed even if it had no detections.
            image_ids.append(dataset.image_ids[index])

            # FIX: report 1-based progress so the final line reads N/N.
            print('{}/{}'.format(index + 1, len(dataset)), end='\r')

        if not len(results):
            # FIX: the original early return skipped model.train(), leaving
            # the model stuck in eval mode when nothing passed the threshold.
            model.train()
            return

        # FIX: the original passed a bare open() to json.dump, leaking the
        # file handle and relying on GC to flush before loadRes re-reads it.
        results_path = '{}_bbox_results.json'.format(dataset.set_name)
        with open(results_path, 'w') as results_file:
            json.dump(results, results_file, indent=4)

        # Load detections back into the COCO evaluation tool.
        coco_true = dataset.coco
        coco_pred = coco_true.loadRes(results_path)

        # Run COCO evaluation restricted to the images we actually processed.
        coco_eval = COCOeval(coco_true, coco_pred, 'bbox')
        coco_eval.params.imgIds = image_ids
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()

        model.train()
        return