Skip to content

Commit

Permalink
[models] Change Resize kwargs to args for each zoo predictor (#1765)
Browse files Browse the repository at this point in the history
  • Loading branch information
cmoscardi authored Nov 4, 2024
1 parent f0ab8c0 commit b411109
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 8 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ Install all additional dependencies with the following command:

```shell
python -m pip install --upgrade pip
pip install -e .[dev]
pip install -e '.[dev]'
pre-commit install
```

Expand Down
10 changes: 6 additions & 4 deletions doctr/models/classification/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _orientation_predictor(


def crop_orientation_predictor(
arch: Any = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, **kwargs: Any
arch: Any = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, batch_size: int = 128, **kwargs: Any
) -> OrientationPredictor:
"""Crop orientation classification architecture.
Expand All @@ -77,17 +77,18 @@ def crop_orientation_predictor(
----
arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation')
pretrained: If True, returns a model pre-trained on our recognition crops dataset
batch_size: number of samples the model processes in parallel
**kwargs: keyword arguments to be passed to the OrientationPredictor
Returns:
-------
OrientationPredictor
"""
return _orientation_predictor(arch, pretrained, model_type="crop", **kwargs)
return _orientation_predictor(arch=arch, pretrained=pretrained, batch_size=batch_size, model_type="crop", **kwargs)


def page_orientation_predictor(
arch: Any = "mobilenet_v3_small_page_orientation", pretrained: bool = False, **kwargs: Any
arch: Any = "mobilenet_v3_small_page_orientation", pretrained: bool = False, batch_size: int = 4, **kwargs: Any
) -> OrientationPredictor:
"""Page orientation classification architecture.
Expand All @@ -101,10 +102,11 @@ def page_orientation_predictor(
----
arch: name of the architecture to use (e.g. 'mobilenet_v3_small_page_orientation')
pretrained: If True, returns a model pre-trained on our recognition crops dataset
batch_size: number of samples the model processes in parallel
**kwargs: keyword arguments to be passed to the OrientationPredictor
Returns:
-------
OrientationPredictor
"""
return _orientation_predictor(arch, pretrained, model_type="page", **kwargs)
return _orientation_predictor(arch=arch, pretrained=pretrained, batch_size=batch_size, model_type="page", **kwargs)
17 changes: 16 additions & 1 deletion doctr/models/detection/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ def detection_predictor(
arch: Any = "fast_base",
pretrained: bool = False,
assume_straight_pages: bool = True,
preserve_aspect_ratio: bool = True,
symmetric_pad: bool = True,
batch_size: int = 2,
**kwargs: Any,
) -> DetectionPredictor:
"""Text detection architecture.
Expand All @@ -94,10 +97,22 @@ def detection_predictor(
arch: name of the architecture or model itself to use (e.g. 'db_resnet50')
pretrained: If True, returns a model pre-trained on our text detection dataset
assume_straight_pages: If True, fit straight boxes to the page
preserve_aspect_ratio: If True, pad the input document image to preserve the aspect ratio before
running the detection model on it
symmetric_pad: If True, pad the image symmetrically instead of padding at the bottom-right
batch_size: number of samples the model processes in parallel
**kwargs: optional keyword arguments passed to the architecture
Returns:
-------
Detection predictor
"""
return _predictor(arch, pretrained, assume_straight_pages, **kwargs)
return _predictor(
arch=arch,
pretrained=pretrained,
assume_straight_pages=assume_straight_pages,
preserve_aspect_ratio=preserve_aspect_ratio,
symmetric_pad=symmetric_pad,
batch_size=batch_size,
**kwargs,
)
1 change: 1 addition & 0 deletions doctr/models/preprocessor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class PreProcessor(nn.Module):
batch_size: the size of page batches
mean: mean value of the training distribution by channel
std: standard deviation of the training distribution by channel
**kwargs: additional arguments for the resizing operation
"""

def __init__(
Expand Down
1 change: 1 addition & 0 deletions doctr/models/preprocessor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class PreProcessor(NestedObject):
batch_size: the size of page batches
mean: mean value of the training distribution by channel
std: standard deviation of the training distribution by channel
**kwargs: additional arguments for the resizing operation
"""

_children_names: List[str] = ["resize", "normalize"]
Expand Down
12 changes: 10 additions & 2 deletions doctr/models/recognition/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,13 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict
return predictor


def recognition_predictor(arch: Any = "crnn_vgg16_bn", pretrained: bool = False, **kwargs: Any) -> RecognitionPredictor:
def recognition_predictor(
arch: Any = "crnn_vgg16_bn",
pretrained: bool = False,
symmetric_pad: bool = False,
batch_size: int = 128,
**kwargs: Any,
) -> RecognitionPredictor:
"""Text recognition architecture.
Example::
Expand All @@ -66,10 +72,12 @@ def recognition_predictor(arch: Any = "crnn_vgg16_bn", pretrained: bool = False,
----
arch: name of the architecture or model itself to use (e.g. 'crnn_vgg16_bn')
pretrained: If True, returns a model pre-trained on our text recognition dataset
symmetric_pad: If True, pad the image symmetrically instead of padding at the bottom-right
batch_size: number of samples the model processes in parallel
**kwargs: optional parameters to be passed to the architecture
Returns:
-------
Recognition predictor
"""
return _predictor(arch, pretrained, **kwargs)
return _predictor(arch=arch, pretrained=pretrained, symmetric_pad=symmetric_pad, batch_size=batch_size, **kwargs)

0 comments on commit b411109

Please sign in to comment.