Skip to content

Commit

Permalink
[PT / TF] Add TextNet - FAST backbone (mindee#1425)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored Jan 22, 2024
1 parent fab59af commit a010972
Show file tree
Hide file tree
Showing 14 changed files with 768 additions and 1 deletion.
6 changes: 6 additions & 0 deletions docs/source/modules/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ doctr.models.classification

.. autofunction:: doctr.models.classification.vit_b

.. autofunction:: doctr.models.classification.textnet_tiny

.. autofunction:: doctr.models.classification.textnet_small

.. autofunction:: doctr.models.classification.textnet_base

.. autofunction:: doctr.models.classification.crop_orientation_predictor


Expand Down
1 change: 1 addition & 0 deletions doctr/models/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
from .vgg import *
from .magc_resnet import *
from .vit import *
from .textnet import *
from .zoo import *
6 changes: 6 additions & 0 deletions doctr/models/classification/textnet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from doctr.file_utils import is_tf_available, is_torch_available

if is_tf_available():
from .tensorflow import *
elif is_torch_available():
from .pytorch import * # type: ignore[assignment]
275 changes: 275 additions & 0 deletions doctr/models/classification/textnet/pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
# Copyright (C) 2021-2024, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.


from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple

from torch import nn

from doctr.datasets import VOCABS

from ...modules.layers.pytorch import FASTConvLayer
from ...utils import conv_sequence_pt, load_pretrained_params

__all__ = ["textnet_tiny", "textnet_small", "textnet_base"]

default_cfgs: Dict[str, Dict[str, Any]] = {
"textnet_tiny": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (3, 32, 32),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_tiny-c23a1b9a.pt&src=0",
},
"textnet_small": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (3, 32, 32),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_small-775169f7.pt&src=0",
},
"textnet_base": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (3, 32, 32),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_base-6121c044.pt&src=0",
},
}


class TextNet(nn.Sequential):
"""Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
Args:
----
stages (List[Dict[str, List[int]]]): List of dictionaries containing the parameters of each stage.
include_top (bool, optional): Whether to include the classifier head. Defaults to True.
num_classes (int, optional): Number of output classes. Defaults to 1000.
cfg (Optional[Dict[str, Any]], optional): Additional configuration. Defaults to None.
"""

def __init__(
self,
stages: List[Dict[str, List[int]]],
input_shape: Tuple[int, int, int] = (3, 32, 32),
num_classes: int = 1000,
include_top: bool = True,
cfg: Optional[Dict[str, Any]] = None,
) -> None:
_layers: List[nn.Module] = [
*conv_sequence_pt(
in_channels=3, out_channels=64, relu=True, bn=True, kernel_size=3, stride=2, padding=(1, 1)
),
nn.Sequential(*[
nn.Sequential(*[
FASTConvLayer(**params) # type: ignore[arg-type]
for params in [{key: stage[key][i] for key in stage} for i in range(len(stage["in_channels"]))]
])
for stage in stages
]),
]

if include_top:
_layers.append(
nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Flatten(1),
nn.Linear(stages[-1]["out_channels"][-1], num_classes),
)
)

super().__init__(*_layers)
self.cfg = cfg

for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)


def _textnet(
arch: str,
pretrained: bool,
ignore_keys: Optional[List[str]] = None,
**kwargs: Any,
) -> TextNet:
kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])

_cfg = deepcopy(default_cfgs[arch])
_cfg["num_classes"] = kwargs["num_classes"]
_cfg["classes"] = kwargs["classes"]
kwargs.pop("classes")

# Build the model
model = TextNet(**kwargs)
# Load pretrained parameters
if pretrained:
# The number of classes is not the same as the number of classes in the pretrained model =>
# remove the last layer weights
_ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

model.cfg = _cfg

return model


def textnet_tiny(pretrained: bool = False, **kwargs: Any) -> TextNet:
"""Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
>>> import torch
>>> from doctr.models import textnet_tiny
>>> model = textnet_tiny(pretrained=False)
>>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
>>> out = model(input_tensor)
Args:
----
pretrained: boolean, True if model is pretrained
**kwargs: keyword arguments of the TextNet architecture
Returns:
-------
A textnet tiny model
"""
return _textnet(
"textnet_tiny",
pretrained,
stages=[
{"in_channels": [64] * 3, "out_channels": [64] * 3, "kernel_size": [(3, 3)] * 3, "stride": [1, 2, 1]},
{
"in_channels": [64, 128, 128, 128],
"out_channels": [128] * 4,
"kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1)],
"stride": [2, 1, 1, 1],
},
{
"in_channels": [128, 256, 256, 256],
"out_channels": [256] * 4,
"kernel_size": [(3, 3), (3, 3), (3, 1), (1, 3)],
"stride": [2, 1, 1, 1],
},
{
"in_channels": [256, 512, 512, 512],
"out_channels": [512] * 4,
"kernel_size": [(3, 3), (3, 1), (1, 3), (3, 3)],
"stride": [2, 1, 1, 1],
},
],
ignore_keys=["4.2.weight", "4.2.bias"],
**kwargs,
)


def textnet_small(pretrained: bool = False, **kwargs: Any) -> TextNet:
"""Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
>>> import torch
>>> from doctr.models import textnet_small
>>> model = textnet_small(pretrained=False)
>>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
>>> out = model(input_tensor)
Args:
----
pretrained: boolean, True if model is pretrained
**kwargs: keyword arguments of the TextNet architecture
Returns:
-------
A TextNet small model
"""
return _textnet(
"textnet_small",
pretrained,
stages=[
{"in_channels": [64] * 2, "out_channels": [64] * 2, "kernel_size": [(3, 3)] * 2, "stride": [1, 2]},
{
"in_channels": [64, 128, 128, 128, 128, 128, 128, 128],
"out_channels": [128] * 8,
"kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 1), (1, 3), (3, 3)],
"stride": [2, 1, 1, 1, 1, 1, 1, 1],
},
{
"in_channels": [128, 256, 256, 256, 256, 256, 256, 256],
"out_channels": [256] * 8,
"kernel_size": [(3, 3), (3, 3), (1, 3), (3, 1), (3, 3), (1, 3), (3, 1), (3, 3)],
"stride": [2, 1, 1, 1, 1, 1, 1, 1],
},
{
"in_channels": [256, 512, 512, 512, 512],
"out_channels": [512] * 5,
"kernel_size": [(3, 3), (3, 1), (1, 3), (1, 3), (3, 1)],
"stride": [2, 1, 1, 1, 1],
},
],
ignore_keys=["4.2.weight", "4.2.bias"],
**kwargs,
)


def textnet_base(pretrained: bool = False, **kwargs: Any) -> TextNet:
"""Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
>>> import torch
>>> from doctr.models import textnet_base
>>> model = textnet_base(pretrained=False)
>>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
>>> out = model(input_tensor)
Args:
----
pretrained: boolean, True if model is pretrained
**kwargs: keyword arguments of the TextNet architecture
Returns:
-------
A TextNet base model
"""
return _textnet(
"textnet_base",
pretrained,
stages=[
{
"in_channels": [64] * 10,
"out_channels": [64] * 10,
"kernel_size": [(3, 3), (3, 3), (3, 1), (3, 3), (3, 1), (3, 3), (3, 3), (1, 3), (3, 3), (3, 3)],
"stride": [1, 2, 1, 1, 1, 1, 1, 1, 1, 1],
},
{
"in_channels": [64, 128, 128, 128, 128, 128, 128, 128, 128, 128],
"out_channels": [128] * 10,
"kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 3), (3, 1), (3, 1), (3, 3), (3, 3)],
"stride": [2, 1, 1, 1, 1, 1, 1, 1, 1, 1],
},
{
"in_channels": [128, 256, 256, 256, 256, 256, 256, 256],
"out_channels": [256] * 8,
"kernel_size": [(3, 3), (3, 3), (3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 1)],
"stride": [2, 1, 1, 1, 1, 1, 1, 1],
},
{
"in_channels": [256, 512, 512, 512, 512],
"out_channels": [512] * 5,
"kernel_size": [(3, 3), (1, 3), (3, 1), (3, 1), (1, 3)],
"stride": [2, 1, 1, 1, 1],
},
],
ignore_keys=["4.2.weight", "4.2.bias"],
**kwargs,
)
Loading

0 comments on commit a010972

Please sign in to comment.