Don't cache input shapes because they can vary
Input shape transforms were previously only calculated once (i.e. for
one frame / one image) and then cached for all subsequent calls.
This is fine in 99% of use cases, but if you process differently shaped
inputs in the same deface process (i.e. when multiple input paths are given),
the stale cache leads to crashes or even silently incorrect outputs.

Transforms now have to be recalculated for each frame, but this is
very cheap, so it is worth it for the sake of better stability.

Fixes #41
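
To make the failure mode concrete, here is a minimal standalone sketch (made-up
sizes and a hypothetical helper mirroring the divisible-by-32 sizing in
deface/centerface.py, not code from the repository): scale factors cached from
the first input do not fit a later input of a different size, so detections are
mapped back to the wrong coordinates.

import numpy as np

def scale_factors(in_shape, orig_shape):
    # Hypothetical helper: round the network input size (w, h) up to multiples
    # of 32 and return the scale from network coordinates back to the original image.
    h_orig, w_orig = orig_shape
    w_new, h_new = in_shape
    w_new, h_new = int(np.ceil(w_new / 32) * 32), int(np.ceil(h_new / 32) * 32)
    return w_new / w_orig, h_new / h_orig

in_shape = (640, 384)  # fixed downscale target (w, h), as with a user-set in_shape

# Scale factors computed (and previously cached) for a 1280x720 first input:
cached_scale_w, cached_scale_h = scale_factors(in_shape, (720, 1280))  # (0.5, ~0.53)

# A detection at network coordinates (320, 192) on a later 1920x1080 input:
x, y = 320.0, 192.0
print(x / cached_scale_w, y / cached_scale_h)  # stale cache: lands near (640, 360)
scale_w, scale_h = scale_factors(in_shape, (1080, 1920))
print(x / scale_w, y / scale_h)                # recomputed per frame: (960, 540)
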
mdraw committed Jul 4, 2023
1 parent a023f97 commit 814e6a2
Showing 1 changed file with 11 additions and 11 deletions.
deface/centerface.py (11 additions, 11 deletions)

@@ -94,14 +94,13 @@ def dynamicize_shapes(static_model):
 
     def __call__(self, img, threshold=0.5):
         img = ensure_rgb(img)
-        self.orig_shape = img.shape[:2]
-        if self.in_shape is None:
-            self.in_shape = self.orig_shape[::-1]
-        if not hasattr(self, 'h_new'): # First call, need to compute sizes
-            self.w_new, self.h_new, self.scale_w, self.scale_h = self.transform(self.in_shape)
+        orig_shape = img.shape[:2]
+        in_shape = orig_shape[::-1] if self.in_shape is None else self.in_shape
+        # Compute sizes
+        w_new, h_new, scale_w, scale_h = self.shape_transform(in_shape, orig_shape)
 
         blob = cv2.dnn.blobFromImage(
-            img, scalefactor=1.0, size=(self.w_new, self.h_new),
+            img, scalefactor=1.0, size=(w_new, h_new),
             mean=(0, 0, 0), swapRB=False, crop=False
         )
         if self.backend == 'opencv':
@@ -111,18 +110,19 @@ def __call__(self, img, threshold=0.5):
             heatmap, scale, offset, lms = self.sess.run(self.onnx_output_names, {self.onnx_input_name: blob})
         else:
             raise RuntimeError(f'Unknown backend {self.backend}')
-        dets, lms = self.decode(heatmap, scale, offset, lms, (self.h_new, self.w_new), threshold=threshold)
+        dets, lms = self.decode(heatmap, scale, offset, lms, (h_new, w_new), threshold=threshold)
         if len(dets) > 0:
-            dets[:, 0:4:2], dets[:, 1:4:2] = dets[:, 0:4:2] / self.scale_w, dets[:, 1:4:2] / self.scale_h
-            lms[:, 0:10:2], lms[:, 1:10:2] = lms[:, 0:10:2] / self.scale_w, lms[:, 1:10:2] / self.scale_h
+            dets[:, 0:4:2], dets[:, 1:4:2] = dets[:, 0:4:2] / scale_w, dets[:, 1:4:2] / scale_h
+            lms[:, 0:10:2], lms[:, 1:10:2] = lms[:, 0:10:2] / scale_w, lms[:, 1:10:2] / scale_h
         else:
             dets = np.empty(shape=[0, 5], dtype=np.float32)
             lms = np.empty(shape=[0, 10], dtype=np.float32)
 
         return dets, lms
 
-    def transform(self, in_shape):
-        h_orig, w_orig = self.orig_shape
+    @staticmethod
+    def shape_transform(in_shape, orig_shape):
+        h_orig, w_orig = orig_shape
         w_new, h_new = in_shape
         # Make spatial dims divisible by 32
         w_new, h_new = int(np.ceil(w_new / 32) * 32), int(np.ceil(h_new / 32) * 32)
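
For context, a hypothetical usage sketch of the fixed behaviour (assuming
CenterFace can be constructed with its defaults): one detector instance now
handles inputs of different shapes, because the shape transform is recomputed
on every call instead of being cached from the first one. Previously, the
second call could crash or silently return mis-scaled detections.

import numpy as np
from deface.centerface import CenterFace

centerface = CenterFace()                            # defaults assumed here
frame_a = np.zeros((720, 1280, 3), dtype=np.uint8)   # 1280x720 input
frame_b = np.zeros((1080, 1920, 3), dtype=np.uint8)  # 1920x1080 input
for frame in (frame_a, frame_b):
    # Each call recomputes sizes and scale factors for this frame's shape
    dets, lms = centerface(frame, threshold=0.5)
    print(frame.shape[:2], dets.shape, lms.shape)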
