diff --git a/configs/datasets/wider_face.yml b/configs/datasets/wider_face.yml
index cc01378d728..7cba5698d7b 100644
--- a/configs/datasets/wider_face.yml
+++ b/configs/datasets/wider_face.yml
@@ -9,11 +9,12 @@ TrainDataset:
     data_fields: ['image', 'gt_bbox', 'gt_class']
 
 EvalDataset:
-  !WIDERFaceDataSet
+  !WIDERFaceValDataset
     dataset_dir: dataset/wider_face
-    anno_path: wider_face_split/wider_face_val_bbx_gt.txt
     image_dir: WIDER_val/images
-    data_fields: ['image']
+    anno_path: wider_face_split/wider_face_val_bbx_gt.txt
+    gt_mat_path: WIDER_val/ground_truth
+    data_fields: ['image', 'gt_bbox', 'gt_class', 'gt_ori_bbox']
 
 TestDataset:
   !ImageFolder
diff --git a/configs/face_detection/_base_/face_reader.yml b/configs/face_detection/_base_/face_reader.yml
index 5a25e8aa0f1..c26051a4d81 100644
--- a/configs/face_detection/_base_/face_reader.yml
+++ b/configs/face_detection/_base_/face_reader.yml
@@ -1,4 +1,4 @@
-worker_num: 2
+worker_num: 8
 TrainReader:
   inputs_def:
     num_max_boxes: 90
@@ -23,7 +23,7 @@ TrainReader:
   batch_transforms:
     - NormalizeImage: {mean: [123, 117, 104], std: [127.502231, 127.502231, 127.502231], is_scale: false}
     - Permute: {}
-  batch_size: 8
+  batch_size: 16
   shuffle: true
   drop_last: true
 
@@ -34,6 +34,9 @@ EvalReader:
     - NormalizeImage: {mean: [123, 117, 104], std: [127.502231, 127.502231, 127.502231], is_scale: false}
     - Permute: {}
   batch_size: 1
+  collate_samples: false
+  shuffle: false
+  drop_last: false
 
 
 TestReader:
diff --git a/configs/face_detection/blazeface_1000e.yml b/configs/face_detection/blazeface_1000e.yml
index 58fc908f81f..9178b44a0a4 100644
--- a/configs/face_detection/blazeface_1000e.yml
+++ b/configs/face_detection/blazeface_1000e.yml
@@ -6,4 +6,4 @@ _BASE_: [
   '_base_/face_reader.yml',
 ]
 weights: output/blazeface_1000e/model_final
-multi_scale_eval: True
+snapshot_epoch: 10
diff --git a/configs/face_detection/blazeface_fpn_ssh_1000e.yml b/configs/face_detection/blazeface_fpn_ssh_1000e.yml
index 21dbd264438..632a4aaed3f 100644
--- a/configs/face_detection/blazeface_fpn_ssh_1000e.yml
+++ b/configs/face_detection/blazeface_fpn_ssh_1000e.yml
@@ -6,4 +6,4 @@ _BASE_: [
   '_base_/face_reader.yml',
 ]
 weights: output/blazeface_fpn_ssh_1000e/model_final
-multi_scale_eval: True
+snapshot_epoch: 10
diff --git a/ppdet/data/source/widerface.py b/ppdet/data/source/widerface.py
index a17c2aaf8a2..d7a95de0722 100644
--- a/ppdet/data/source/widerface.py
+++ b/ppdet/data/source/widerface.py
@@ -12,8 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections import defaultdict
 import os
 import numpy as np
+from scipy.io import loadmat
 
 from ppdet.core.workspace import register, serializable
 from .dataset import DetDataset
@@ -178,3 +180,82 @@ def _load_file_list(self, input_txt):
 def widerface_label():
     labels_map = {'face': 0}
     return labels_map
+
+
+@register
+@serializable
+class WIDERFaceValDataset(WIDERFaceDataSet):
+    def __init__(self,
+                 dataset_dir=None,
+                 image_dir=None,
+                 anno_path=None,
+                 gt_mat_path=None,
+                 data_fields=['image'],
+                 sample_num=-1,
+                 with_lmk=False):
+        super().__init__(
+            dataset_dir=dataset_dir,
+            image_dir=image_dir,
+            anno_path=anno_path,
+            data_fields=data_fields,
+            sample_num=sample_num,
+            with_lmk=with_lmk)
+        self.gt_mat_path = gt_mat_path
+        self.val_mat = os.path.join(self.dataset_dir, self.gt_mat_path, 'wider_face_val.mat')
+        self.hard_mat_path = os.path.join(self.dataset_dir, self.gt_mat_path, 'wider_hard_val.mat')
+        self.medium_mat_path = os.path.join(self.dataset_dir, self.gt_mat_path, 'wider_medium_val.mat')
+        self.easy_mat_path = os.path.join(self.dataset_dir, self.gt_mat_path, 'wider_easy_val.mat')
+
+        assert os.path.exists(self.val_mat), f'{self.val_mat} not exist'
+        assert os.path.exists(self.hard_mat_path), f'{self.hard_mat_path} not exist'
+        assert os.path.exists(self.medium_mat_path), f'{self.medium_mat_path} not exist'
+        assert os.path.exists(self.easy_mat_path), f'{self.easy_mat_path} not exist'
+
+    def parse_dataset(self):
+        super().parse_dataset()
+
+        box_list, file_list, event_list, hard_info_list, medium_info_list, \
+            easy_info_list = self.get_gt_infos()
+        setting_infos = [easy_info_list, medium_info_list, hard_info_list]
+        settings = ['easy', 'medium', 'hard']
+        info_by_name = defaultdict(dict)
+        for setting_id in range(3):
+            info_list = setting_infos[setting_id]
+            setting = settings[setting_id]
+            for i in range(len(event_list)):
+                img_list = file_list[i][0]
+                gt_box_list = box_list[i][0]
+                sub_info_list = info_list[i][0]
+                for j in range(len(img_list)):
+                    img_name = str(img_list[j][0][0])
+                    gt_boxes = gt_box_list[j][0].astype(np.float32)
+                    info_by_name[img_name]['gt_ori_bbox'] = gt_boxes
+
+                    keep_index = sub_info_list[j][0]
+                    ignore = np.zeros(gt_boxes.shape[0])
+                    if len(keep_index) != 0:
+                        ignore[keep_index - 1] = 1
+                    info_by_name[img_name][f'gt_{setting}_ignore'] = ignore
+
+        for roidb in self.roidbs:
+            img_file = roidb['im_file'].split('/')[-1]
+            img_name = ".".join(img_file.split(".")[:-1])
+            roidb.update(info_by_name[img_name])
+
+    def get_gt_infos(self):
+        """ gt dir: (wider_face_val.mat, wider_easy_val.mat, wider_medium_val.mat, wider_hard_val.mat)"""
+
+        val_mat = loadmat(self.val_mat)
+        hard_mat = loadmat(self.hard_mat_path)
+        medium_mat = loadmat(self.medium_mat_path)
+        easy_mat = loadmat(self.easy_mat_path)
+
+        box_list = val_mat['face_bbx_list']
+        file_list = val_mat['file_list']
+        event_list = val_mat['event_list']
+
+        hard_info_list = hard_mat['gt_list']
+        medium_info_list = medium_mat['gt_list']
+        easy_info_list = easy_mat['gt_list']
+
+        return box_list, file_list, event_list, hard_info_list, medium_info_list, easy_info_list
\ No newline at end of file
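The nested indexing in `parse_dataset` above mirrors the layout of the official WIDER FACE evaluation `.mat` files. A minimal sketch of that layout read with `scipy.io.loadmat` (the paths and the event/image indices below are illustrative, not part of the patch):

```python
from scipy.io import loadmat
import numpy as np

# Illustrative paths; the dataset class builds them from dataset_dir + gt_mat_path.
val_mat = loadmat('dataset/wider_face/WIDER_val/ground_truth/wider_face_val.mat')
hard_mat = loadmat('dataset/wider_face/WIDER_val/ground_truth/wider_hard_val.mat')

event_list = val_mat['event_list']    # (num_events, 1) cell array of event names
file_list = val_mat['file_list']      # per event: cell array of image names
box_list = val_mat['face_bbx_list']   # per event: cell array of (num_faces, 4) xywh boxes
hard_gt = hard_mat['gt_list']         # per event: 1-based indices of faces kept for the 'hard' setting

i, j = 0, 0                           # first event, first image (for illustration only)
img_name = str(file_list[i][0][j][0][0])
gt_boxes = box_list[i][0][j][0].astype(np.float32)
keep_index = hard_gt[i][0][j][0]

ignore = np.zeros(gt_boxes.shape[0])
if len(keep_index) != 0:
    ignore[keep_index - 1] = 1        # 1 marks faces that count for this difficulty setting
```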
diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py
index 90addeb7d52..da342a7a4c5 100644
--- a/ppdet/engine/trainer.py
+++ b/ppdet/engine/trainer.py
@@ -276,8 +276,8 @@ def _init_callbacks(self):
             self._compose_callback = ComposeCallback(self._callbacks)
         elif self.mode == 'eval':
             self._callbacks = [LogPrinter(self)]
-            if self.cfg.metric == 'WiderFace':
-                self._callbacks.append(WiferFaceEval(self))
+            # if self.cfg.metric == 'WiderFace':
+            #     self._callbacks.append(WiferFaceEval(self))
             self._compose_callback = ComposeCallback(self._callbacks)
         elif self.mode == 'test' and self.cfg.get('use_vdl', False):
             self._callbacks = [VisualDLWriter(self)]
@@ -381,13 +381,8 @@ def _init_metrics(self, validate=False):
                     save_prediction_only=save_prediction_only)
             ]
         elif self.cfg.metric == 'WiderFace':
-            multi_scale = self.cfg.multi_scale_eval if 'multi_scale_eval' in self.cfg else True
             self._metrics = [
-                WiderFaceMetric(
-                    image_dir=os.path.join(self.dataset.dataset_dir,
-                                           self.dataset.image_dir),
-                    anno_file=self.dataset.get_anno(),
-                    multi_scale=multi_scale)
+                WiderFaceMetric()
             ]
         elif self.cfg.metric == 'KeyPointTopDownCOCOEval':
             eval_dataset = self.cfg['EvalDataset']
diff --git a/ppdet/metrics/metrics.py b/ppdet/metrics/metrics.py
index 2fac677299c..a534814d0e1 100644
--- a/ppdet/metrics/metrics.py
+++ b/ppdet/metrics/metrics.py
@@ -27,7 +27,8 @@
 
 from .map_utils import prune_zero_padding, DetectionMAP
 from .coco_utils import get_infer_results, cocoapi_eval
-from .widerface_utils import face_eval_run
+from .widerface_utils import (face_eval_run, image_eval, img_pr_info,
+                              dataset_pr_info, voc_ap)
 from ppdet.data.source.category import get_categories
 from ppdet.modeling.rbox_utils import poly2rbox_np
 
@@ -337,22 +338,93 @@ def get_results(self):
 
 
 class WiderFaceMetric(Metric):
-    def __init__(self, image_dir, anno_file, multi_scale=True):
-        self.image_dir = image_dir
-        self.anno_file = anno_file
-        self.multi_scale = multi_scale
-        self.clsid2catid, self.catid2name = get_categories('widerface')
-
-    def update(self, model):
-
-        face_eval_run(
-            model,
-            self.image_dir,
-            self.anno_file,
-            pred_dir='output/pred',
-            eval_mode='widerface',
-            multi_scale=self.multi_scale)
+    def __init__(self, iou_thresh=0.5):
+        self.iou_thresh = iou_thresh
+        self.reset()
+    def reset(self):
+        self.pred_boxes_list = []
+        self.gt_boxes_list = []
+        self.aps = []
+
+        self.hard_ignore_list = []
+        self.medium_ignore_list = []
+        self.easy_ignore_list = []
+
+    def update(self, data, outs):
+        batch_pred_bboxes = outs['bbox']
+        batch_pred_bboxes_num = outs['bbox_num']
+        assert len(batch_pred_bboxes_num) == len(data['gt_bbox'])
+        batch_size = len(data['gt_bbox'])
+        box_cnt = 0
+        for batch_id in range(batch_size):
+            pred_bboxes_num = batch_pred_bboxes_num[batch_id]
+            pred_bboxes = batch_pred_bboxes[box_cnt:box_cnt +
+                                            pred_bboxes_num].numpy()
+            box_cnt += pred_bboxes_num
+
+            det_conf = pred_bboxes[:, 1]
+            det_xmin = pred_bboxes[:, 2]
+            det_ymin = pred_bboxes[:, 3]
+            det_xmax = pred_bboxes[:, 4]
+            det_ymax = pred_bboxes[:, 5]
+            det = np.column_stack((det_xmin, det_ymin, det_xmax,
+                                   det_ymax, det_conf))
+            self.pred_boxes_list.append(det)  # xyxy conf
+            self.gt_boxes_list.append(data['gt_ori_bbox'][batch_id].numpy())  # xywh
+            self.hard_ignore_list.append(
+                data['gt_hard_ignore'][batch_id].numpy())
+            self.medium_ignore_list.append(
+                data['gt_medium_ignore'][batch_id].numpy())
+            self.easy_ignore_list.append(
+                data['gt_easy_ignore'][batch_id].numpy())
+
+    def accumulate(self):
+        total_num = len(self.gt_boxes_list)
+        settings = ['easy', 'medium', 'hard']
+        setting_ignores = [self.easy_ignore_list,
+                           self.medium_ignore_list,
+                           self.hard_ignore_list]
+        thresh_num = 1000
+        aps = []
+        for setting_id in range(3):
+            count_face = 0
+            pr_curve = np.zeros((thresh_num, 2)).astype(np.float32)
+            gt_ignore_list = setting_ignores[setting_id]
+            for i in range(total_num):
+                pred_boxes = self.pred_boxes_list[i]  # xyxy conf
+                gt_boxes = self.gt_boxes_list[i]  # xywh
+                ignore = gt_ignore_list[i]
+                count_face += np.sum(ignore)
+
+                if len(gt_boxes) == 0 or len(pred_boxes) == 0:
+                    continue
+                pred_recall, proposal_list = image_eval(pred_boxes, gt_boxes,
+                                                        ignore, self.iou_thresh)
+                _img_pr_info = img_pr_info(thresh_num, pred_boxes,
+                                           proposal_list, pred_recall)
+                pr_curve += _img_pr_info
+            pr_curve = dataset_pr_info(thresh_num, pr_curve, count_face)
+
+            propose = pr_curve[:, 0]
+            recall = pr_curve[:, 1]
+
+            ap = voc_ap(recall, propose)
+            aps.append(ap)
+        self.aps = aps
+
+    def log(self):
+        logger.info("==================== Results ====================")
+        logger.info("Easy Val AP: {}".format(self.aps[0]))
+        logger.info("Medium Val AP: {}".format(self.aps[1]))
+        logger.info("Hard Val AP: {}".format(self.aps[2]))
+        logger.info("=================================================")
+
+    def get_results(self):
+        return {
+            'easy_ap': self.aps[0],
+            'medium_ap': self.aps[1],
+            'hard_ap': self.aps[2]}
 
 
 class RBoxMetric(Metric):
     def __init__(self, anno_file, **kwargs):
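With this rework, `WiderFaceMetric` follows the standard per-batch `Metric` protocol instead of launching its own multi-scale inference pass via `face_eval_run`. A schematic of how the evaluation loop drives it (`eval_loader` and `model` are placeholders; the `outs['bbox']` layout is the usual PaddleDetection `[class, score, x1, y1, x2, y2]` format that `update` assumes):

```python
# Schematic only; not part of the patch.
metric = WiderFaceMetric(iou_thresh=0.5)
metric.reset()
for data in eval_loader:             # data carries gt_ori_bbox and the per-setting ignore masks
    outs = model(data)               # outs['bbox']: (sum(bbox_num), 6), outs['bbox_num']: per-image counts
    metric.update(data, outs)        # stores xyxy+score detections and xywh ground truth per image
metric.accumulate()                  # easy/medium/hard AP over the whole val set
metric.log()                         # prints Easy/Medium/Hard Val AP
results = metric.get_results()       # {'easy_ap': ..., 'medium_ap': ..., 'hard_ap': ...}
```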
diff --git a/ppdet/metrics/widerface_utils.py b/ppdet/metrics/widerface_utils.py
index 2f64bf6d50a..7247fbe3d19 100644
--- a/ppdet/metrics/widerface_utils.py
+++ b/ppdet/metrics/widerface_utils.py
@@ -389,3 +389,113 @@ def lmk2out(results, is_bbox_normalized=False):
             xywh_res.append(lmk_res)
             k += 1
     return xywh_res
+
+def image_eval(pred, gt, ignore, iou_thresh):
+    """ single image evaluation
+    pred: Nx5 xyxys
+    gt: Nx4 xywh
+    ignore:
+    """
+    _pred = pred.copy()
+    _gt = gt.copy()
+    pred_recall = np.zeros(_pred.shape[0])
+    recall_list = np.zeros(_gt.shape[0])
+    proposal_list = np.ones(_pred.shape[0])
+
+    _gt[:, 2] = _gt[:, 2] + _gt[:, 0]
+    _gt[:, 3] = _gt[:, 3] + _gt[:, 1]
+
+    overlaps = bbox_overlaps(_pred[:, :4], _gt)
+
+    for h in range(_pred.shape[0]):
+
+        gt_overlap = overlaps[h]
+        max_overlap, max_idx = gt_overlap.max(), gt_overlap.argmax()
+        if max_overlap >= iou_thresh:
+            if ignore[max_idx] == 0:
+                recall_list[max_idx] = -1
+                proposal_list[h] = -1
+            elif recall_list[max_idx] == 0:
+                recall_list[max_idx] = 1
+
+        r_keep_index = np.where(recall_list == 1)[0]
+        pred_recall[h] = len(r_keep_index)
+    return pred_recall, proposal_list
+
+
+def bbox_overlaps(boxes1, boxes2):
+    """
+    Parameters
+    ----------
+    boxes1: (N, 4) ndarray of float
+    boxes2: (K, 4) ndarray of float
+    Returns
+    -------
+    overlaps: (N, K) ndarray of overlap between boxes1 and boxes2
+    """
+    # Calculate the area of each box
+    box_areas1 = (boxes1[:, 2] - boxes1[:, 0] + 1) * (
+        boxes1[:, 3] - boxes1[:, 1] + 1)
+    box_areas2 = (boxes2[:, 2] - boxes2[:, 0] + 1) * (
+        boxes2[:, 3] - boxes2[:, 1] + 1)
+    # Calculate the intersection areas
+    iw = np.minimum(boxes1[:, None, 2], boxes2[None, :, 2]) - np.maximum(
+        boxes1[:, None, 0], boxes2[None, :, 0]) + 1
+    ih = np.minimum(boxes1[:, None, 3], boxes2[None, :, 3]) - np.maximum(
+        boxes1[:, None, 1], boxes2[None, :, 1]) + 1
+    # Ensure that the intersection width and height are non-negative
+    iw = np.maximum(iw, 0)
+    ih = np.maximum(ih, 0)
+    # Calculate the intersection area
+    intersection = iw * ih
+    # Calculate the union area
+    union = box_areas1[:, None] + box_areas2[None, :] - intersection
+    union = np.maximum(union, 1e-8)
+    # Calculate the overlaps (intersection over union)
+    overlaps = intersection / union
+    return overlaps
+
+
+def img_pr_info(thresh_num, pred_info, proposal_list, pred_recall):
+    pr_info = np.zeros((thresh_num, 2)).astype('float')
+    for t in range(thresh_num):
+
+        thresh = 1 - (t + 1) / thresh_num
+        r_index = np.where(pred_info[:, 4] >= thresh)[0]
+        if len(r_index) == 0:
+            pr_info[t, 0] = 0
+            pr_info[t, 1] = 0
+        else:
+            r_index = r_index[-1]
+            p_index = np.where(proposal_list[:r_index + 1] == 1)[0]
+            pr_info[t, 0] = len(p_index)
+            pr_info[t, 1] = pred_recall[r_index]
+    return pr_info
+
+
+def dataset_pr_info(thresh_num, pr_curve, count_face):
+    _pr_curve = np.zeros((thresh_num, 2))
+    for i in range(thresh_num):
+        _pr_curve[i, 0] = pr_curve[i, 1] / pr_curve[i, 0]
+        _pr_curve[i, 1] = pr_curve[i, 1] / count_face
+    return _pr_curve
+
+
+def voc_ap(rec, prec):
+
+    # correct AP calculation
+    # first append sentinel values at the end
+    mrec = np.concatenate(([0.], rec, [1.]))
+    mpre = np.concatenate(([0.], prec, [0.]))
+
+    # compute the precision envelope
+    for i in range(mpre.size - 1, 0, -1):
+        mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
+
+    # to calculate area under PR curve, look for points
+    # where X axis (recall) changes value
+    i = np.where(mrec[1:] != mrec[:-1])[0]
+
+    # and sum (\Delta recall) * prec
+    ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
+    return ap
\ No newline at end of file
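`voc_ap` integrates the precision envelope over recall, i.e. the all-point interpolation used in PASCAL VOC. A small hand-checkable example (the values are arbitrary, and the import assumes this patch has been applied):

```python
import numpy as np
from ppdet.metrics.widerface_utils import voc_ap  # available once this patch is applied

rec = np.array([0.2, 0.4, 0.4, 0.8])
prec = np.array([1.0, 0.8, 0.6, 0.5])
# Envelope over [0, 1.0, 0.8, 0.6, 0.5, 0] -> [1.0, 1.0, 0.8, 0.6, 0.5, 0];
# AP = 0.2*1.0 + 0.2*0.8 + 0.4*0.5 + 0.2*0.0 = 0.56
print(voc_ap(rec, prec))  # ~0.56
```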
diff --git a/ppdet/modeling/transformers/hybrid_encoder.py b/ppdet/modeling/transformers/hybrid_encoder.py
index 9038e845c03..ead15f26b18 100644
--- a/ppdet/modeling/transformers/hybrid_encoder.py
+++ b/ppdet/modeling/transformers/hybrid_encoder.py
@@ -265,7 +265,6 @@ def forward(self, feats, for_mot=False, is_teacher=False):
             feat_heigh = self.lateral_convs[len(self.in_channels) - 1 - idx](
                 feat_heigh)
             inner_outs[0] = feat_heigh
-
             upsample_feat = F.interpolate(
                 feat_heigh, scale_factor=2., mode="nearest")
             inner_out = self.fpn_blocks[len(self.in_channels) - 1 - idx](