forked from mindee/doctr
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[PT / TF] Add TextNet - FAST backbone (mindee#1425)
- Loading branch information
1 parent
fab59af
commit a010972
Showing
14 changed files
with
768 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,4 +3,5 @@ | |
from .vgg import * | ||
from .magc_resnet import * | ||
from .vit import * | ||
from .textnet import * | ||
from .zoo import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from doctr.file_utils import is_tf_available, is_torch_available | ||
|
||
if is_tf_available(): | ||
from .tensorflow import * | ||
elif is_torch_available(): | ||
from .pytorch import * # type: ignore[assignment] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,275 @@ | ||
# Copyright (C) 2021-2024, Mindee. | ||
|
||
# This program is licensed under the Apache License 2.0. | ||
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details. | ||
|
||
|
||
from copy import deepcopy | ||
from typing import Any, Dict, List, Optional, Tuple | ||
|
||
from torch import nn | ||
|
||
from doctr.datasets import VOCABS | ||
|
||
from ...modules.layers.pytorch import FASTConvLayer | ||
from ...utils import conv_sequence_pt, load_pretrained_params | ||
|
||
# Public factory functions exposed by this module (the TextNet class itself is internal)
__all__ = ["textnet_tiny", "textnet_small", "textnet_base"]

# Per-variant defaults: normalization statistics (mean/std per RGB channel),
# expected input shape as (C, H, W), the classification vocabulary used for
# the pretraining head, and the URL of the pretrained checkpoint.
default_cfgs: Dict[str, Dict[str, Any]] = {
    "textnet_tiny": {
        "mean": (0.694, 0.695, 0.693),
        "std": (0.299, 0.296, 0.301),
        "input_shape": (3, 32, 32),
        "classes": list(VOCABS["french"]),
        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_tiny-c23a1b9a.pt&src=0",
    },
    "textnet_small": {
        "mean": (0.694, 0.695, 0.693),
        "std": (0.299, 0.296, 0.301),
        "input_shape": (3, 32, 32),
        "classes": list(VOCABS["french"]),
        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_small-775169f7.pt&src=0",
    },
    "textnet_base": {
        "mean": (0.694, 0.695, 0.693),
        "std": (0.299, 0.296, 0.301),
        "input_shape": (3, 32, 32),
        "classes": list(VOCABS["french"]),
        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_base-6121c044.pt&src=0",
    },
}
|
||
|
||
class TextNet(nn.Sequential):
    """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
    Implementation based on the official PyTorch implementation: `<https://github.com/czczup/FAST>`_.

    Args:
    ----
        stages: list of stage descriptors; each dict holds parallel lists
            ("in_channels", "out_channels", "kernel_size", "stride"), one entry per layer.
        input_shape: expected input shape as (C, H, W); the channel count is used for the
            stem convolution. Defaults to (3, 32, 32).
        num_classes: number of output classes of the classifier head. Defaults to 1000.
        include_top: whether to include the classifier head. Defaults to True.
        cfg: additional configuration dictionary, stored on the model. Defaults to None.
    """

    def __init__(
        self,
        stages: List[Dict[str, List[int]]],
        input_shape: Tuple[int, int, int] = (3, 32, 32),
        num_classes: int = 1000,
        include_top: bool = True,
        cfg: Optional[Dict[str, Any]] = None,
    ) -> None:
        _layers: List[nn.Module] = [
            # Stem: 3x3 stride-2 conv + BN + ReLU. The input channel count was previously
            # hard-coded to 3; it is now taken from input_shape (default unchanged).
            *conv_sequence_pt(
                in_channels=input_shape[0],
                out_channels=64,
                relu=True,
                bn=True,
                kernel_size=3,
                stride=2,
                padding=(1, 1),
            ),
            nn.Sequential(*[
                nn.Sequential(*[
                    FASTConvLayer(**params)  # type: ignore[arg-type]
                    # Transpose each stage's dict-of-lists into one kwargs dict per layer
                    for params in [{key: stage[key][i] for key in stage} for i in range(len(stage["in_channels"]))]
                ])
                for stage in stages
            ]),
        ]

        if include_top:
            # Classifier head: global average pool -> flatten -> linear projection
            _layers.append(
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(1),
                    nn.Flatten(1),
                    nn.Linear(stages[-1]["out_channels"][-1], num_classes),
                )
            )

        super().__init__(*_layers)
        self.cfg = cfg

        # He initialization for convolutions; BatchNorm affine params set to identity
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
|
||
|
||
def _textnet(
    arch: str,
    pretrained: bool,
    ignore_keys: Optional[List[str]] = None,
    **kwargs: Any,
) -> TextNet:
    """Instantiate a TextNet variant and optionally load its pretrained weights.

    Args:
    ----
        arch: key into ``default_cfgs`` identifying the variant
        pretrained: if True, fetch and load the pretrained checkpoint
        ignore_keys: state-dict keys to drop when the classifier head differs
            from the pretrained one
        **kwargs: forwarded to the ``TextNet`` constructor

    Returns:
    -------
        The (possibly pretrained) TextNet model, with its config attached as ``cfg``.
    """
    defaults = default_cfgs[arch]
    num_classes = kwargs.setdefault("num_classes", len(defaults["classes"]))
    # "classes" is config-only metadata — strip it before constructing the model
    classes = kwargs.pop("classes", defaults["classes"])

    _cfg = deepcopy(defaults)
    _cfg["num_classes"] = num_classes
    _cfg["classes"] = classes

    model = TextNet(**kwargs)

    if pretrained:
        # Head weights only transfer when the class count matches the checkpoint;
        # otherwise drop the listed keys so the new head is freshly initialized.
        head_changed = num_classes != len(defaults["classes"])
        load_pretrained_params(model, defaults["url"], ignore_keys=ignore_keys if head_changed else None)

    model.cfg = _cfg
    return model
|
||
|
||
def textnet_tiny(pretrained: bool = False, **kwargs: Any) -> TextNet:
    """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
    Implementation based on the official PyTorch implementation: `<https://github.com/czczup/FAST>`_.

    >>> import torch
    >>> from doctr.models import textnet_tiny
    >>> model = textnet_tiny(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained: boolean, True if model is pretrained
        **kwargs: keyword arguments of the TextNet architecture

    Returns:
    -------
        A textnet tiny model
    """
    # Stage layout for the "tiny" variant: 3/4/4/4 FAST conv layers per stage.
    stage_configs: List[Dict[str, Any]] = [
        {"in_channels": [64] * 3, "out_channels": [64] * 3, "kernel_size": [(3, 3)] * 3, "stride": [1, 2, 1]},
        {
            "in_channels": [64] + [128] * 3,
            "out_channels": [128] * 4,
            "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1)],
            "stride": [2] + [1] * 3,
        },
        {
            "in_channels": [128] + [256] * 3,
            "out_channels": [256] * 4,
            "kernel_size": [(3, 3), (3, 3), (3, 1), (1, 3)],
            "stride": [2] + [1] * 3,
        },
        {
            "in_channels": [256] + [512] * 3,
            "out_channels": [512] * 4,
            "kernel_size": [(3, 3), (3, 1), (1, 3), (3, 3)],
            "stride": [2] + [1] * 3,
        },
    ]
    return _textnet(
        "textnet_tiny",
        pretrained,
        stages=stage_configs,
        # Classifier head weights, dropped when num_classes differs from the checkpoint
        ignore_keys=["4.2.weight", "4.2.bias"],
        **kwargs,
    )
|
||
|
||
def textnet_small(pretrained: bool = False, **kwargs: Any) -> TextNet:
    """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
    Implementation based on the official PyTorch implementation: `<https://github.com/czczup/FAST>`_.

    >>> import torch
    >>> from doctr.models import textnet_small
    >>> model = textnet_small(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained: boolean, True if model is pretrained
        **kwargs: keyword arguments of the TextNet architecture

    Returns:
    -------
        A TextNet small model
    """
    # Stage layout for the "small" variant: 2/8/8/5 FAST conv layers per stage.
    stage_configs: List[Dict[str, Any]] = [
        {"in_channels": [64] * 2, "out_channels": [64] * 2, "kernel_size": [(3, 3)] * 2, "stride": [1, 2]},
        {
            "in_channels": [64] + [128] * 7,
            "out_channels": [128] * 8,
            "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 1), (1, 3), (3, 3)],
            "stride": [2] + [1] * 7,
        },
        {
            "in_channels": [128] + [256] * 7,
            "out_channels": [256] * 8,
            "kernel_size": [(3, 3), (3, 3), (1, 3), (3, 1), (3, 3), (1, 3), (3, 1), (3, 3)],
            "stride": [2] + [1] * 7,
        },
        {
            "in_channels": [256] + [512] * 4,
            "out_channels": [512] * 5,
            "kernel_size": [(3, 3), (3, 1), (1, 3), (1, 3), (3, 1)],
            "stride": [2] + [1] * 4,
        },
    ]
    return _textnet(
        "textnet_small",
        pretrained,
        stages=stage_configs,
        # Classifier head weights, dropped when num_classes differs from the checkpoint
        ignore_keys=["4.2.weight", "4.2.bias"],
        **kwargs,
    )
|
||
|
||
def textnet_base(pretrained: bool = False, **kwargs: Any) -> TextNet:
    """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
    Implementation based on the official PyTorch implementation: `<https://github.com/czczup/FAST>`_.

    >>> import torch
    >>> from doctr.models import textnet_base
    >>> model = textnet_base(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained: boolean, True if model is pretrained
        **kwargs: keyword arguments of the TextNet architecture

    Returns:
    -------
        A TextNet base model
    """
    # Stage layout for the "base" variant: 10/10/8/5 FAST conv layers per stage.
    stage_configs: List[Dict[str, Any]] = [
        {
            "in_channels": [64] * 10,
            "out_channels": [64] * 10,
            "kernel_size": [(3, 3), (3, 3), (3, 1), (3, 3), (3, 1), (3, 3), (3, 3), (1, 3), (3, 3), (3, 3)],
            "stride": [1, 2] + [1] * 8,
        },
        {
            "in_channels": [64] + [128] * 9,
            "out_channels": [128] * 10,
            "kernel_size": [(3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 3), (3, 1), (3, 1), (3, 3), (3, 3)],
            "stride": [2] + [1] * 9,
        },
        {
            "in_channels": [128] + [256] * 7,
            "out_channels": [256] * 8,
            "kernel_size": [(3, 3), (3, 3), (3, 3), (1, 3), (3, 3), (3, 1), (3, 3), (3, 1)],
            "stride": [2] + [1] * 7,
        },
        {
            "in_channels": [256] + [512] * 4,
            "out_channels": [512] * 5,
            "kernel_size": [(3, 3), (1, 3), (3, 1), (3, 1), (1, 3)],
            "stride": [2] + [1] * 4,
        },
    ]
    return _textnet(
        "textnet_base",
        pretrained,
        stages=stage_configs,
        # Classifier head weights, dropped when num_classes differs from the checkpoint
        ignore_keys=["4.2.weight", "4.2.bias"],
        **kwargs,
    )
Oops, something went wrong.