Skip to content

Commit

Permalink
[models] Unleash FAST model (reparameterization logic) (#1494)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored Mar 4, 2024
1 parent 2bf9166 commit 62d94ff
Show file tree
Hide file tree
Showing 8 changed files with 351 additions and 84 deletions.
78 changes: 39 additions & 39 deletions docs/source/using_doctr/using_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,45 +34,45 @@ The following architectures are currently supported:
For a comprehensive comparison, we have compiled a detailed benchmark on publicly available datasets:


+-----------------------------------------------------------------------------------+----------------------------+----------------------------+--------------------+
| | FUNSD | CORD | |
+================+=================================+=================+==============+============+===============+============+===============+====================+
| **Backend** | **Architecture** | **Input shape** | **# params** | **Recall** | **Precision** | **Recall** | **Precision** | **sec/it (B: 1)** |
+----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
| TensorFlow | db_resnet50 | (1024, 1024, 3) | 25.2 M | 84.39 | 85.86 | 93.70 | 83.24 | 1.2 |
+----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
| TensorFlow | db_mobilenet_v3_large | (1024, 1024, 3) | 4.2 M | 80.29 | 70.90 | 84.70 | 67.76 | 0.5 |
+----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
| TensorFlow | linknet_resnet18 | (1024, 1024, 3) | 11.5 M | 81.37 | 84.08 | 85.71 | 83.70 | 0.7 |
+----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
| TensorFlow | linknet_resnet34 | (1024, 1024, 3) | 21.6 M | 82.20 | 85.49 | 87.63 | 87.17 | 0.8 |
+----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
| TensorFlow | linknet_resnet50 | (1024, 1024, 3) | 28.8 M | 80.70 | 83.51 | 86.46 | 84.94 | 1.1 |
+----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
| TensorFlow | fast_tiny | (1024, 1024, 3) | 13.5 M | | | | | |
+----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
| TensorFlow | fast_small | (1024, 1024, 3) | 14.7 M | | | | | |
+----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
| TensorFlow | fast_base | (1024, 1024, 3) | 16.3 M | | | | | |
+----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
| PyTorch | db_resnet34 | (1024, 1024, 3) | 22.4 M | 82.76 | 76.75 | 89.20 | 71.74 | 0.8 |
+----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
| PyTorch | db_resnet50 | (1024, 1024, 3) | 25.4 M | 83.56 | 86.68 | 92.61 | 86.39 | 1.1 |
+----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
| PyTorch | db_mobilenet_v3_large | (1024, 1024, 3) | 4.2 M | 83.41 | 84.00 | 86.70 | 79.38 | 0.5 |
+----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
| PyTorch | linknet_resnet18 | (1024, 1024, 3) | 11.5 M | 81.64 | 85.52 | 88.92 | 82.74 | 0.6 |
+----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
| PyTorch | linknet_resnet34 | (1024, 1024, 3) | 21.6 M | 81.62 | 82.95 | 86.26 | 81.06 | 0.7 |
+----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
| PyTorch | linknet_resnet50 | (1024, 1024, 3) | 28.8 M | 81.78 | 82.47 | 87.29 | 85.54 | 1.0 |
+----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
| PyTorch | fast_tiny | (1024, 1024, 3) | 13.5 M | | | | | |
+----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
| PyTorch | fast_small | (1024, 1024, 3) | 14.7 M | | | | | |
+----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
| PyTorch | fast_base | (1024, 1024, 3) | 16.3 M | | | | | |
+----------------+---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+--------------------+
+------------------------------------------------------------------------------------+----------------------------+----------------------------+--------------------+
| | FUNSD | CORD | |
+================+=================================+=================+===============+============+===============+============+===============+====================+
| **Backend** | **Architecture** | **Input shape** | **# params** | **Recall** | **Precision** | **Recall** | **Precision** | **sec/it (B: 1)** |
+----------------+---------------------------------+-----------------+---------------+------------+---------------+------------+---------------+--------------------+
| TensorFlow | db_resnet50 | (1024, 1024, 3) | 25.2 M | 84.39 | 85.86 | 93.70 | 83.24 | 1.2 |
+----------------+---------------------------------+-----------------+---------------+------------+---------------+------------+---------------+--------------------+
| TensorFlow | db_mobilenet_v3_large | (1024, 1024, 3) | 4.2 M | 80.29 | 70.90 | 84.70 | 67.76 | 0.5 |
+----------------+---------------------------------+-----------------+---------------+------------+---------------+------------+---------------+--------------------+
| TensorFlow | linknet_resnet18 | (1024, 1024, 3) | 11.5 M | 81.37 | 84.08 | 85.71 | 83.70 | 0.7 |
+----------------+---------------------------------+-----------------+---------------+------------+---------------+------------+---------------+--------------------+
| TensorFlow | linknet_resnet34 | (1024, 1024, 3) | 21.6 M | 82.20 | 85.49 | 87.63 | 87.17 | 0.8 |
+----------------+---------------------------------+-----------------+---------------+------------+---------------+------------+---------------+--------------------+
| TensorFlow | linknet_resnet50 | (1024, 1024, 3) | 28.8 M | 80.70 | 83.51 | 86.46 | 84.94 | 1.1 |
+----------------+---------------------------------+-----------------+---------------+------------+---------------+------------+---------------+--------------------+
| TensorFlow | fast_tiny | (1024, 1024, 3) | 13.5 M (8.5M) | | | | | 0.7 (0.4) |
+----------------+---------------------------------+-----------------+---------------+------------+---------------+------------+---------------+--------------------+
| TensorFlow | fast_small | (1024, 1024, 3) | 14.7 M (9.7M) | | | | | 0.7 (0.5) |
+----------------+---------------------------------+-----------------+---------------+------------+---------------+------------+---------------+--------------------+
| TensorFlow | fast_base | (1024, 1024, 3) | 16.3 M (10.6M)| | | | | 0.8 (0.5) |
+----------------+---------------------------------+-----------------+---------------+------------+---------------+------------+---------------+--------------------+
| PyTorch | db_resnet34 | (1024, 1024, 3) | 22.4 M | 82.76 | 76.75 | 89.20 | 71.74 | 0.8 |
+----------------+---------------------------------+-----------------+---------------+------------+---------------+------------+---------------+--------------------+
| PyTorch | db_resnet50 | (1024, 1024, 3) | 25.4 M | 83.56 | 86.68 | 92.61 | 86.39 | 1.1 |
+----------------+---------------------------------+-----------------+---------------+------------+---------------+------------+---------------+--------------------+
| PyTorch | db_mobilenet_v3_large | (1024, 1024, 3) | 4.2 M | 83.41 | 84.00 | 86.70 | 79.38 | 0.5 |
+----------------+---------------------------------+-----------------+---------------+------------+---------------+------------+---------------+--------------------+
| PyTorch | linknet_resnet18 | (1024, 1024, 3) | 11.5 M | 81.64 | 85.52 | 88.92 | 82.74 | 0.6 |
+----------------+---------------------------------+-----------------+---------------+------------+---------------+------------+---------------+--------------------+
| PyTorch | linknet_resnet34 | (1024, 1024, 3) | 21.6 M | 81.62 | 82.95 | 86.26 | 81.06 | 0.7 |
+----------------+---------------------------------+-----------------+---------------+------------+---------------+------------+---------------+--------------------+
| PyTorch | linknet_resnet50 | (1024, 1024, 3) | 28.8 M | 81.78 | 82.47 | 87.29 | 85.54 | 1.0 |
+----------------+---------------------------------+-----------------+---------------+------------+---------------+------------+---------------+--------------------+
| PyTorch | fast_tiny | (1024, 1024, 3) | 13.5 M (8.5M) | | | | | 0.7 (0.4) |
+----------------+---------------------------------+-----------------+---------------+------------+---------------+------------+---------------+--------------------+
| PyTorch | fast_small | (1024, 1024, 3) | 14.7 M (9.7M) | | | | | 0.7 (0.5) |
+----------------+---------------------------------+-----------------+---------------+------------+---------------+------------+---------------+--------------------+
| PyTorch | fast_base | (1024, 1024, 3) | 16.3 M (10.6M)| | | | | 0.8 (0.5) |
+----------------+---------------------------------+-----------------+---------------+------------+---------------+------------+---------------+--------------------+


All text detection models above have been evaluated using both the training and evaluation sets of FUNSD and CORD (cf. :ref:`datasets`).
Expand Down
32 changes: 14 additions & 18 deletions doctr/models/detection/differentiable_binarization/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,24 +147,20 @@ def __init__(
_inputs = [layers.Input(shape=in_shape[1:]) for in_shape in self.feat_extractor.output_shape]
output_shape = tuple(self.fpn(_inputs).shape)

self.probability_head = keras.Sequential(
[
*conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
layers.BatchNormalization(),
layers.Activation("relu"),
layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
]
)
self.threshold_head = keras.Sequential(
[
*conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
layers.BatchNormalization(),
layers.Activation("relu"),
layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
]
)
self.probability_head = keras.Sequential([
*conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
layers.BatchNormalization(),
layers.Activation("relu"),
layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
])
self.threshold_head = keras.Sequential([
*conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]),
layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"),
layers.BatchNormalization(),
layers.Activation("relu"),
layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"),
])

self.postprocessor = DBPostProcessor(
assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
Expand Down
45 changes: 43 additions & 2 deletions doctr/models/detection/fast/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import torch
Expand All @@ -18,7 +18,7 @@
from ...utils import _bf16_to_float32, load_pretrained_params
from .base import _FAST, FASTPostProcessor

__all__ = ["FAST", "fast_tiny", "fast_small", "fast_base"]
__all__ = ["FAST", "fast_tiny", "fast_small", "fast_base", "reparameterize"]


default_cfgs: Dict[str, Dict[str, Any]] = {
Expand Down Expand Up @@ -279,6 +279,47 @@ def ohem_sample(score: torch.Tensor, gt: torch.Tensor, mask: torch.Tensor) -> to
return text_loss + kernel_loss


def reparameterize(model: Union[FAST, nn.Module]) -> FAST:
"""Fuse batchnorm and conv layers and reparameterize the model
args:
----
model: the FAST model to reparameterize
Returns:
-------
the reparameterized model
"""
last_conv = None
last_conv_name = None

for module in model.modules():
if hasattr(module, "reparameterize_layer"):
module.reparameterize_layer()

for name, child in model.named_children():
if isinstance(child, nn.BatchNorm2d):
# fuse batchnorm only if it is followed by a conv layer
if last_conv is None:
continue
conv_w = last_conv.weight
conv_b = last_conv.bias if last_conv.bias is not None else torch.zeros_like(child.running_mean)

factor = child.weight / torch.sqrt(child.running_var + child.eps)
last_conv.weight = nn.Parameter(conv_w * factor.reshape([last_conv.out_channels, 1, 1, 1]))
last_conv.bias = nn.Parameter((conv_b - child.running_mean) * factor + child.bias)
model._modules[last_conv_name] = last_conv
model._modules[name] = nn.Identity()
last_conv = None
elif isinstance(child, nn.Conv2d):
last_conv = child
last_conv_name = name
else:
reparameterize(child)

return model # type: ignore[return-value]


def _fast(
arch: str,
pretrained: bool,
Expand Down
Loading

0 comments on commit 62d94ff

Please sign in to comment.