Skip to content

Commit 709cc89

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] Cleaned-up coco evaluation code (#4453)
Reviewed By: datumbox Differential Revision: D31268048 fbshipit-source-id: e209ac9447e172972baeb322d592de15e4806383
1 parent 7c021da commit 709cc89

File tree

1 file changed

+15
-172
lines changed

1 file changed

+15
-172
lines changed

references/detection/coco_eval.py

Lines changed: 15 additions & 172 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,17 @@
1-
import json
1+
import copy
2+
import io
3+
from contextlib import redirect_stdout
24

35
import numpy as np
4-
import copy
6+
import pycocotools.mask as mask_util
57
import torch
6-
import torch._six
7-
88
from pycocotools.cocoeval import COCOeval
99
from pycocotools.coco import COCO
10-
import pycocotools.mask as mask_util
11-
12-
from collections import defaultdict
1310

1411
import utils
1512

1613

17-
class CocoEvaluator(object):
14+
class CocoEvaluator:
1815
def __init__(self, coco_gt, iou_types):
1916
assert isinstance(iou_types, (list, tuple))
2017
coco_gt = copy.deepcopy(coco_gt)
@@ -34,7 +31,8 @@ def update(self, predictions):
3431

3532
for iou_type in self.iou_types:
3633
results = self.prepare(predictions, iou_type)
37-
coco_dt = loadRes(self.coco_gt, results) if results else COCO()
34+
with redirect_stdout(io.StringIO()):
35+
coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
3836
coco_eval = self.coco_eval[iou_type]
3937

4038
coco_eval.cocoDt = coco_dt
@@ -54,18 +52,17 @@ def accumulate(self):
5452

5553
def summarize(self):
5654
for iou_type, coco_eval in self.coco_eval.items():
57-
print("IoU metric: {}".format(iou_type))
55+
print(f"IoU metric: {iou_type}")
5856
coco_eval.summarize()
5957

6058
def prepare(self, predictions, iou_type):
6159
if iou_type == "bbox":
6260
return self.prepare_for_coco_detection(predictions)
63-
elif iou_type == "segm":
61+
if iou_type == "segm":
6462
return self.prepare_for_coco_segmentation(predictions)
65-
elif iou_type == "keypoints":
63+
if iou_type == "keypoints":
6664
return self.prepare_for_coco_keypoint(predictions)
67-
else:
68-
raise ValueError("Unknown iou type {}".format(iou_type))
65+
raise ValueError(f"Unknown iou type {iou_type}")
6966

7067
def prepare_for_coco_detection(self, predictions):
7168
coco_results = []
@@ -190,161 +187,7 @@ def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
190187
coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
191188

192189

193-
#################################################################
194-
# From pycocotools, just removed the prints and fixed
195-
# a Python3 bug about unicode not defined
196-
#################################################################
197-
198-
# Ideally, pycocotools wouldn't have hard-coded prints
199-
# so that we could avoid copy-pasting those two functions
200-
201-
def createIndex(self):
202-
# create index
203-
# print('creating index...')
204-
anns, cats, imgs = {}, {}, {}
205-
imgToAnns, catToImgs = defaultdict(list), defaultdict(list)
206-
if 'annotations' in self.dataset:
207-
for ann in self.dataset['annotations']:
208-
imgToAnns[ann['image_id']].append(ann)
209-
anns[ann['id']] = ann
210-
211-
if 'images' in self.dataset:
212-
for img in self.dataset['images']:
213-
imgs[img['id']] = img
214-
215-
if 'categories' in self.dataset:
216-
for cat in self.dataset['categories']:
217-
cats[cat['id']] = cat
218-
219-
if 'annotations' in self.dataset and 'categories' in self.dataset:
220-
for ann in self.dataset['annotations']:
221-
catToImgs[ann['category_id']].append(ann['image_id'])
222-
223-
# print('index created!')
224-
225-
# create class members
226-
self.anns = anns
227-
self.imgToAnns = imgToAnns
228-
self.catToImgs = catToImgs
229-
self.imgs = imgs
230-
self.cats = cats
231-
232-
233-
maskUtils = mask_util
234-
235-
236-
def loadRes(self, resFile):
237-
"""
238-
Load result file and return a result api object.
239-
Args:
240-
self (obj): coco object with ground truth annotations
241-
resFile (str): file name of result file
242-
Returns:
243-
res (obj): result api object
244-
"""
245-
res = COCO()
246-
res.dataset['images'] = [img for img in self.dataset['images']]
247-
248-
# print('Loading and preparing results...')
249-
# tic = time.time()
250-
if isinstance(resFile, torch._six.string_classes):
251-
anns = json.load(open(resFile))
252-
elif type(resFile) == np.ndarray:
253-
anns = self.loadNumpyAnnotations(resFile)
254-
else:
255-
anns = resFile
256-
assert type(anns) == list, 'results in not an array of objects'
257-
annsImgIds = [ann['image_id'] for ann in anns]
258-
assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), \
259-
'Results do not correspond to current coco set'
260-
if 'caption' in anns[0]:
261-
imgIds = set([img['id'] for img in res.dataset['images']]) & set([ann['image_id'] for ann in anns])
262-
res.dataset['images'] = [img for img in res.dataset['images'] if img['id'] in imgIds]
263-
for id, ann in enumerate(anns):
264-
ann['id'] = id + 1
265-
elif 'bbox' in anns[0] and not anns[0]['bbox'] == []:
266-
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
267-
for id, ann in enumerate(anns):
268-
bb = ann['bbox']
269-
x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]]
270-
if 'segmentation' not in ann:
271-
ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
272-
ann['area'] = bb[2] * bb[3]
273-
ann['id'] = id + 1
274-
ann['iscrowd'] = 0
275-
elif 'segmentation' in anns[0]:
276-
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
277-
for id, ann in enumerate(anns):
278-
# now only support compressed RLE format as segmentation results
279-
ann['area'] = maskUtils.area(ann['segmentation'])
280-
if 'bbox' not in ann:
281-
ann['bbox'] = maskUtils.toBbox(ann['segmentation'])
282-
ann['id'] = id + 1
283-
ann['iscrowd'] = 0
284-
elif 'keypoints' in anns[0]:
285-
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
286-
for id, ann in enumerate(anns):
287-
s = ann['keypoints']
288-
x = s[0::3]
289-
y = s[1::3]
290-
x1, x2, y1, y2 = np.min(x), np.max(x), np.min(y), np.max(y)
291-
ann['area'] = (x2 - x1) * (y2 - y1)
292-
ann['id'] = id + 1
293-
ann['bbox'] = [x1, y1, x2 - x1, y2 - y1]
294-
# print('DONE (t={:0.2f}s)'.format(time.time()- tic))
295-
296-
res.dataset['annotations'] = anns
297-
createIndex(res)
298-
return res
299-
300-
301-
def evaluate(self):
302-
'''
303-
Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
304-
:return: None
305-
'''
306-
# tic = time.time()
307-
# print('Running per image evaluation...')
308-
p = self.params
309-
# add backward compatibility if useSegm is specified in params
310-
if p.useSegm is not None:
311-
p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
312-
print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
313-
# print('Evaluate annotation type *{}*'.format(p.iouType))
314-
p.imgIds = list(np.unique(p.imgIds))
315-
if p.useCats:
316-
p.catIds = list(np.unique(p.catIds))
317-
p.maxDets = sorted(p.maxDets)
318-
self.params = p
319-
320-
self._prepare()
321-
# loop through images, area range, max detection number
322-
catIds = p.catIds if p.useCats else [-1]
323-
324-
if p.iouType == 'segm' or p.iouType == 'bbox':
325-
computeIoU = self.computeIoU
326-
elif p.iouType == 'keypoints':
327-
computeIoU = self.computeOks
328-
self.ious = {
329-
(imgId, catId): computeIoU(imgId, catId)
330-
for imgId in p.imgIds
331-
for catId in catIds}
332-
333-
evaluateImg = self.evaluateImg
334-
maxDet = p.maxDets[-1]
335-
evalImgs = [
336-
evaluateImg(imgId, catId, areaRng, maxDet)
337-
for catId in catIds
338-
for areaRng in p.areaRng
339-
for imgId in p.imgIds
340-
]
341-
# this is NOT in the pycocotools code, but could be done outside
342-
evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
343-
self._paramsEval = copy.deepcopy(self.params)
344-
# toc = time.time()
345-
# print('DONE (t={:0.2f}s).'.format(toc-tic))
346-
return p.imgIds, evalImgs
347-
348-
#################################################################
349-
# end of straight copy from pycocotools, just removing the prints
350-
#################################################################
190+
def evaluate(imgs):
191+
with redirect_stdout(io.StringIO()):
192+
imgs.evaluate()
193+
return imgs.params.imgIds, np.asarray(imgs.evalImgs).reshape(-1, len(imgs.params.areaRng), len(imgs.params.imgIds))

0 commit comments

Comments
 (0)