Skip to content

Commit

Permalink
Face det (#9179)
Browse files Browse the repository at this point in the history
* add "WIDERFaceEvalDataset" and "WiderFaceOnlineMetric"

* add "support widerface eval online"
  • Loading branch information
leo-q8 authored Oct 18, 2024
1 parent 7ba9f16 commit 666f597
Show file tree
Hide file tree
Showing 9 changed files with 294 additions and 32 deletions.
7 changes: 4 additions & 3 deletions configs/datasets/wider_face.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@ TrainDataset:
data_fields: ['image', 'gt_bbox', 'gt_class']

EvalDataset:
!WIDERFaceDataSet
!WIDERFaceValDataset
dataset_dir: dataset/wider_face
anno_path: wider_face_split/wider_face_val_bbx_gt.txt
image_dir: WIDER_val/images
data_fields: ['image']
anno_path: wider_face_split/wider_face_val_bbx_gt.txt
gt_mat_path: WIDER_val/ground_truth
data_fields: ['image', 'gt_bbox', 'gt_class', 'ori_gt_bbox']

TestDataset:
!ImageFolder
Expand Down
7 changes: 5 additions & 2 deletions configs/face_detection/_base_/face_reader.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
worker_num: 2
worker_num: 8
TrainReader:
inputs_def:
num_max_boxes: 90
Expand All @@ -23,7 +23,7 @@ TrainReader:
batch_transforms:
- NormalizeImage: {mean: [123, 117, 104], std: [127.502231, 127.502231, 127.502231], is_scale: false}
- Permute: {}
batch_size: 8
batch_size: 16
shuffle: true
drop_last: true

Expand All @@ -34,6 +34,9 @@ EvalReader:
- NormalizeImage: {mean: [123, 117, 104], std: [127.502231, 127.502231, 127.502231], is_scale: false}
- Permute: {}
batch_size: 1
collate_samples: false
shuffle: false
drop_last: false


TestReader:
Expand Down
2 changes: 1 addition & 1 deletion configs/face_detection/blazeface_1000e.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ _BASE_: [
'_base_/face_reader.yml',
]
weights: output/blazeface_1000e/model_final
multi_scale_eval: True
snapshot_epoch: 10
2 changes: 1 addition & 1 deletion configs/face_detection/blazeface_fpn_ssh_1000e.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ _BASE_: [
'_base_/face_reader.yml',
]
weights: output/blazeface_fpn_ssh_1000e/model_final
multi_scale_eval: True
snapshot_epoch: 10
81 changes: 81 additions & 0 deletions ppdet/data/source/widerface.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
import os
import numpy as np
from scipy.io import loadmat

from ppdet.core.workspace import register, serializable
from .dataset import DetDataset
Expand Down Expand Up @@ -178,3 +180,82 @@ def _load_file_list(self, input_txt):
def widerface_label():
labels_map = {'face': 0}
return labels_map


@register
@serializable
class WIDERFaceValDataset(WIDERFaceDataSet):
def __init__(self,
dataset_dir=None,
image_dir=None,
anno_path=None,
gt_mat_path=None,
data_fields=['image'],
sample_num=-1,
with_lmk=False):
super().__init__(
dataset_dir=dataset_dir,
image_dir=image_dir,
anno_path=anno_path,
data_fields=data_fields,
sample_num=sample_num,
with_lmk=with_lmk)
self.gt_mat_path = gt_mat_path
self.val_mat = os.path.join(self.dataset_dir, self.gt_mat_path, 'wider_face_val.mat')
self.hard_mat_path = os.path.join(self.dataset_dir, self.gt_mat_path, 'wider_hard_val.mat')
self.medium_mat_path = os.path.join(self.dataset_dir, self.gt_mat_path, 'wider_medium_val.mat')
self.easy_mat_path = os.path.join(self.dataset_dir, self.gt_mat_path, 'wider_easy_val.mat')

assert os.path.exists(self.val_mat), f'{self.val_mat} not exist'
assert os.path.exists(self.hard_mat_path), f'{self.hard_mat_path} not exist'
assert os.path.exists(self.medium_mat_path), f'{self.medium_mat_path} not exist'
assert os.path.exists(self.easy_mat_path), f'{self.easy_mat_path} not exist'

def parse_dataset(self):
super().parse_dataset()

box_list, flie_list, event_list, hard_info_list, medium_info_list, \
easy_info_list = self.get_gt_infos()
setting_infos = [easy_info_list, medium_info_list, hard_info_list]
settings = ['easy', 'medium', 'hard']
info_by_name = defaultdict(dict)
for setting_id in range(3):
info_list = setting_infos[setting_id]
setting = settings[setting_id]
for i in range(len(event_list)):
img_list = flie_list[i][0]
gt_box_list = box_list[i][0]
sub_info_list = info_list[i][0]
for j in range(len(img_list)):
img_name = str(img_list[j][0][0])
gt_boxes = gt_box_list[j][0].astype(np.float32)
info_by_name[img_name]['gt_ori_bbox'] = gt_boxes

keep_index = sub_info_list[j][0]
ignore = np.zeros(gt_boxes.shape[0])
if len(keep_index) != 0:
ignore[keep_index-1] = 1
info_by_name[img_name][f'gt_{setting}_ignore'] = ignore

for roidb in self.roidbs:
img_file = roidb['im_file'].split('/')[-1]
img_name = ".".join(img_file.split(".")[:-1])
roidb.update(info_by_name[img_name])

def get_gt_infos(self):
""" gt dir: (wider_face_val.mat, wider_easy_val.mat, wider_medium_val.mat, wider_hard_val.mat)"""

val_mat = loadmat(self.val_mat)
hard_mat = loadmat(self.hard_mat_path)
medium_mat = loadmat(self.medium_mat_path)
easy_mat = loadmat(self.easy_mat_path)

box_list = val_mat['face_bbx_list']
file_list = val_mat['file_list']
event_list = val_mat['event_list']

hard_info_list = hard_mat['gt_list']
medium_info_list = medium_mat['gt_list']
easy_info_list = easy_mat['gt_list']

return box_list, file_list, event_list, hard_info_list, medium_info_list, easy_info_list
11 changes: 3 additions & 8 deletions ppdet/engine/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,8 @@ def _init_callbacks(self):
self._compose_callback = ComposeCallback(self._callbacks)
elif self.mode == 'eval':
self._callbacks = [LogPrinter(self)]
if self.cfg.metric == 'WiderFace':
self._callbacks.append(WiferFaceEval(self))
# if self.cfg.metric == 'WiderFace':
# self._callbacks.append(WiferFaceEval(self))
self._compose_callback = ComposeCallback(self._callbacks)
elif self.mode == 'test' and self.cfg.get('use_vdl', False):
self._callbacks = [VisualDLWriter(self)]
Expand Down Expand Up @@ -381,13 +381,8 @@ def _init_metrics(self, validate=False):
save_prediction_only=save_prediction_only)
]
elif self.cfg.metric == 'WiderFace':
multi_scale = self.cfg.multi_scale_eval if 'multi_scale_eval' in self.cfg else True
self._metrics = [
WiderFaceMetric(
image_dir=os.path.join(self.dataset.dataset_dir,
self.dataset.image_dir),
anno_file=self.dataset.get_anno(),
multi_scale=multi_scale)
WiderFaceMetric()
]
elif self.cfg.metric == 'KeyPointTopDownCOCOEval':
eval_dataset = self.cfg['EvalDataset']
Expand Down
104 changes: 88 additions & 16 deletions ppdet/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@

from .map_utils import prune_zero_padding, DetectionMAP
from .coco_utils import get_infer_results, cocoapi_eval
from .widerface_utils import face_eval_run
from .widerface_utils import (face_eval_run, image_eval, img_pr_info,
dataset_pr_info, voc_ap)
from ppdet.data.source.category import get_categories
from ppdet.modeling.rbox_utils import poly2rbox_np

Expand Down Expand Up @@ -337,22 +338,93 @@ def get_results(self):


class WiderFaceMetric(Metric):
def __init__(self, image_dir, anno_file, multi_scale=True):
self.image_dir = image_dir
self.anno_file = anno_file
self.multi_scale = multi_scale
self.clsid2catid, self.catid2name = get_categories('widerface')

def update(self, model):

face_eval_run(
model,
self.image_dir,
self.anno_file,
pred_dir='output/pred',
eval_mode='widerface',
multi_scale=self.multi_scale)
def __init__(self, iou_thresh=0.5):
self.iou_thresh = iou_thresh
self.reset()

def reset(self):
self.pred_boxes_list = []
self.gt_boxes_list = []
self.aps = []

self.hard_ignore_list = []
self.medium_ignore_list = []
self.easy_ignore_list = []

def update(self, data, outs):
batch_pred_bboxes = outs['bbox']
batch_pred_bboxes_num = outs['bbox_num']
assert len(batch_pred_bboxes_num) == len(data['gt_bbox'])
batch_size = len(data['gt_bbox'])
box_cnt = 0
for batch_id in range(batch_size):
pred_bboxes_num = batch_pred_bboxes_num[batch_id]
pred_bboxes = batch_pred_bboxes[box_cnt: box_cnt +
pred_bboxes_num].numpy()
box_cnt += pred_bboxes_num

det_conf = pred_bboxes[:, 1]
det_xmin = pred_bboxes[:, 2]
det_ymin = pred_bboxes[:, 3]
det_xmax = pred_bboxes[:, 4]
det_ymax = pred_bboxes[:, 5]
det = np.column_stack((det_xmin, det_ymin, det_xmax,
det_ymax, det_conf))
self.pred_boxes_list.append(det) # xyxy conf
self.gt_boxes_list.append(data['gt_ori_bbox'][batch_id].numpy()) # xywh
self.hard_ignore_list.append(
data['gt_hard_ignore'][batch_id].numpy())
self.medium_ignore_list.append(
data['gt_medium_ignore'][batch_id].numpy())
self.easy_ignore_list.append(
data['gt_easy_ignore'][batch_id].numpy())

def accumulate(self):
total_num = len(self.gt_boxes_list)
settings = ['easy', 'medium', 'hard']
setting_ingores = [self.easy_ignore_list,
self.medium_ignore_list,
self.hard_ignore_list]
thresh_num = 1000
aps = []
for setting_id in range(3):
count_face = 0
pr_curve = np.zeros((thresh_num, 2)).astype(np.float32)
gt_ignore_list = setting_ingores[setting_id]
for i in range(total_num):
pred_boxes = self.pred_boxes_list[i] # xyxy conf
gt_boxes = self.gt_boxes_list[i] # xywh
ignore = gt_ignore_list[i]
count_face += np.sum(ignore)

if len(gt_boxes) == 0 or len(pred_boxes) == 0:
continue
pred_recall, proposal_list = image_eval(pred_boxes, gt_boxes,
ignore, self.iou_thresh)
_img_pr_info = img_pr_info(thresh_num, pred_boxes,
proposal_list, pred_recall)
pr_curve += _img_pr_info
pr_curve = dataset_pr_info(thresh_num, pr_curve, count_face)

propose = pr_curve[:, 0]
recall = pr_curve[:, 1]

ap = voc_ap(recall, propose)
aps.append(ap)
self.aps = aps

def log(self):
logger.info("==================== Results ====================")
logger.info("Easy Val AP: {}".format(self.aps[0]))
logger.info("Medium Val AP: {}".format(self.aps[1]))
logger.info("Hard Val AP: {}".format(self.aps[2]))
logger.info("=================================================")

def get_results(self):
return {
'easy_ap': self.aps[0],
'medium_ap': self.aps[1],
'hard_ap': self.aps[2]}

class RBoxMetric(Metric):
def __init__(self, anno_file, **kwargs):
Expand Down
Loading

0 comments on commit 666f597

Please sign in to comment.