Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] align orientation train script to current orientation model (counter-clockwise instead of clockwise) & make OrientationPredictor dynamic #1559

Merged
merged 1 commit into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/modules/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ doctr.models.classification

.. autofunction:: doctr.models.classification.mobilenet_v3_large_r

.. autofunction:: doctr.models.classification.mobilenet_v3_small_orientation
.. autofunction:: doctr.models.classification.mobilenet_v3_small_crop_orientation

.. autofunction:: doctr.models.classification.magc_resnet31

Expand Down
21 changes: 14 additions & 7 deletions doctr/models/classification/mobilenet/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"mobilenet_v3_small_r",
"mobilenet_v3_large",
"mobilenet_v3_large_r",
"mobilenet_v3_small_orientation",
"mobilenet_v3_small_crop_orientation",
]

default_cfgs: Dict[str, Dict[str, Any]] = {
Expand Down Expand Up @@ -51,13 +51,20 @@
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small_r-1a8a3530.pt&src=0",
},
"mobilenet_v3_small_orientation": {
"mobilenet_v3_small_crop_orientation": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (3, 128, 128),
"classes": [0, 90, 180, 270],
"classes": [0, -90, 180, 90],
"url": "https://doctr-static.mindee.com/models?id=v0.4.1/classif_mobilenet_v3_small-24f8ff57.pt&src=0",
},
"mobilenet_v3_small_page_orientation": {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll need to expose it in zoo.py, no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct, later on. I added only the additional dict for the moment so that you have a place to put the URL :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test integration etc then let n #1553

"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (3, 512, 512),
"classes": [0, -90, 180, 90],
"url": None,
},
}


Expand Down Expand Up @@ -212,14 +219,14 @@ def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> mobilenetv3
)


def mobilenet_v3_small_orientation(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
def mobilenet_v3_small_crop_orientation(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
"""MobileNetV3-Small architecture as described in
`"Searching for MobileNetV3",
<https://arxiv.org/pdf/1905.02244.pdf>`_.

>>> import torch
>>> from doctr.models import mobilenet_v3_small_orientation
>>> model = mobilenet_v3_small_orientation(pretrained=False)
>>> from doctr.models import mobilenet_v3_small_crop_orientation
>>> model = mobilenet_v3_small_crop_orientation(pretrained=False)
>>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
>>> out = model(input_tensor)

Expand All @@ -233,7 +240,7 @@ def mobilenet_v3_small_orientation(pretrained: bool = False, **kwargs: Any) -> m
a torch.nn.Module
"""
return _mobilenet_v3(
"mobilenet_v3_small_orientation",
"mobilenet_v3_small_crop_orientation",
pretrained,
ignore_keys=["classifier.3.weight", "classifier.3.bias"],
**kwargs,
Expand Down
21 changes: 14 additions & 7 deletions doctr/models/classification/mobilenet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"mobilenet_v3_small_r",
"mobilenet_v3_large",
"mobilenet_v3_large_r",
"mobilenet_v3_small_orientation",
"mobilenet_v3_small_crop_orientation",
]


Expand Down Expand Up @@ -54,13 +54,20 @@
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small_r-3d61452e.zip&src=0",
},
"mobilenet_v3_small_orientation": {
"mobilenet_v3_small_crop_orientation": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (128, 128, 3),
"classes": [0, 90, 180, 270],
"classes": [0, -90, 180, 90],
"url": "https://doctr-static.mindee.com/models?id=v0.4.1/classif_mobilenet_v3_small-1ea8db03.zip&src=0",
},
"mobilenet_v3_small_page_orientation": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (512, 512, 3),
"classes": [0, -90, 180, 90],
"url": None,
},
}


Expand Down Expand Up @@ -386,14 +393,14 @@ def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> MobileNetV3
return _mobilenet_v3("mobilenet_v3_large_r", pretrained, True, **kwargs)


def mobilenet_v3_small_orientation(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
def mobilenet_v3_small_crop_orientation(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
"""MobileNetV3-Small architecture as described in
`"Searching for MobileNetV3",
<https://arxiv.org/pdf/1905.02244.pdf>`_.

>>> import tensorflow as tf
>>> from doctr.models import mobilenet_v3_small_orientation
>>> model = mobilenet_v3_small_orientation(pretrained=False)
>>> from doctr.models import mobilenet_v3_small_crop_orientation
>>> model = mobilenet_v3_small_crop_orientation(pretrained=False)
>>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
>>> out = model(input_tensor)

Expand All @@ -406,4 +413,4 @@ def mobilenet_v3_small_orientation(pretrained: bool = False, **kwargs: Any) -> M
-------
a keras.Model
"""
return _mobilenet_v3("mobilenet_v3_small_orientation", pretrained, include_top=True, **kwargs)
return _mobilenet_v3("mobilenet_v3_small_crop_orientation", pretrained, include_top=True, **kwargs)
10 changes: 3 additions & 7 deletions doctr/models/classification/predictor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from doctr.models.preprocessor import PreProcessor
from doctr.models.utils import set_device_and_dtype

__all__ = ["CropOrientationPredictor"]
__all__ = ["OrientationPredictor"]


class CropOrientationPredictor(nn.Module):
class OrientationPredictor(nn.Module):
"""Implements an object able to detect the reading direction of a text box.
4 possible orientations: 0, 90, 180, 270 degrees counter clockwise.

Expand Down Expand Up @@ -57,11 +57,7 @@ def forward(
predicted_batches = [out_batch.argmax(dim=1).cpu().detach().numpy() for out_batch in predicted_batches]

class_idxs = [int(pred) for batch in predicted_batches for pred in batch]
# Keep unified with page orientation range (counter clock rotation => negative) so 270 -> -90
classes = [
int(self.model.cfg["classes"][idx]) if int(self.model.cfg["classes"][idx]) != 270 else -90
for idx in class_idxs
]
classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs]
confs = [round(float(p), 2) for prob in probs for p in prob]

return [class_idxs, classes, confs]
10 changes: 3 additions & 7 deletions doctr/models/classification/predictor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from doctr.models.preprocessor import PreProcessor
from doctr.utils.repr import NestedObject

__all__ = ["CropOrientationPredictor"]
__all__ = ["OrientationPredictor"]


class CropOrientationPredictor(NestedObject):
class OrientationPredictor(NestedObject):
"""Implements an object able to detect the reading direction of a text box.
4 possible orientations: 0, 90, 180, 270 degrees counter clockwise.

Expand Down Expand Up @@ -52,11 +52,7 @@ def __call__(
predicted_batches = [out_batch.numpy().argmax(1) for out_batch in predicted_batches]

class_idxs = [int(pred) for batch in predicted_batches for pred in batch]
# Keep unified with page orientation range (counter clock rotation => negative) so 270 -> -90
classes = [
int(self.model.cfg["classes"][idx]) if int(self.model.cfg["classes"][idx]) != 270 else -90
for idx in class_idxs
]
classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs]
confs = [round(float(p), 2) for prob in probs for p in prob]

return [class_idxs, classes, confs]
18 changes: 9 additions & 9 deletions doctr/models/classification/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from .. import classification
from ..preprocessor import PreProcessor
from .predictor import CropOrientationPredictor
from .predictor import OrientationPredictor

__all__ = ["crop_orientation_predictor"]

Expand All @@ -31,10 +31,10 @@
"vit_s",
"vit_b",
]
ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_orientation"]
ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation"]


def _crop_orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> CropOrientationPredictor:
def _orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> OrientationPredictor:
if arch not in ORIENTATION_ARCHS:
raise ValueError(f"unknown architecture '{arch}'")

Expand All @@ -44,15 +44,15 @@ def _crop_orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> C
kwargs["std"] = kwargs.get("std", _model.cfg["std"])
kwargs["batch_size"] = kwargs.get("batch_size", 128)
input_shape = _model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:]
predictor = CropOrientationPredictor(
predictor = OrientationPredictor(
PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model
)
return predictor


def crop_orientation_predictor(
arch: str = "mobilenet_v3_small_orientation", pretrained: bool = False, **kwargs: Any
) -> CropOrientationPredictor:
arch: str = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, **kwargs: Any
) -> OrientationPredictor:
"""Orientation classification architecture.

>>> import numpy as np
Expand All @@ -65,10 +65,10 @@ def crop_orientation_predictor(
----
arch: name of the architecture to use (e.g. 'mobilenet_v3_small')
pretrained: If True, returns a model pre-trained on our recognition crops dataset
**kwargs: keyword arguments to be passed to the CropOrientationPredictor
**kwargs: keyword arguments to be passed to the OrientationPredictor

Returns:
-------
CropOrientationPredictor
OrientationPredictor
"""
return _crop_orientation_predictor(arch, pretrained, **kwargs)
return _orientation_predictor(arch, pretrained, **kwargs)
4 changes: 2 additions & 2 deletions doctr/models/kie_predictor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from doctr.models.builder import KIEDocumentBuilder

from ..classification.predictor import CropOrientationPredictor
from ..classification.predictor import OrientationPredictor
from ..predictor.base import _OCRPredictor

__all__ = ["_KIEPredictor"]
Expand All @@ -28,7 +28,7 @@ class _KIEPredictor(_OCRPredictor):
kwargs: keyword args of `DocumentBuilder`
"""

crop_orientation_predictor: Optional[CropOrientationPredictor]
crop_orientation_predictor: Optional[OrientationPredictor]

def __init__(
self,
Expand Down
4 changes: 2 additions & 2 deletions doctr/models/predictor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from .._utils import rectify_crops, rectify_loc_preds
from ..classification import crop_orientation_predictor
from ..classification.predictor import CropOrientationPredictor
from ..classification.predictor import OrientationPredictor

__all__ = ["_OCRPredictor"]

Expand All @@ -32,7 +32,7 @@ class _OCRPredictor:
**kwargs: keyword args of `DocumentBuilder`
"""

crop_orientation_predictor: Optional[CropOrientationPredictor]
crop_orientation_predictor: Optional[OrientationPredictor]

def __init__(
self,
Expand Down
4 changes: 2 additions & 2 deletions references/classification/train_pytorch_orientation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from doctr.models.utils import export_model_to_onnx
from utils import EarlyStopper, plot_recorder, plot_samples

CLASSES = [0, 90, 180, 270]
CLASSES = [0, -90, 180, 90]


def rnd_rotate(img: torch.Tensor, target):
Expand Down Expand Up @@ -191,7 +191,7 @@ def main(args):

torch.backends.cudnn.benchmark = True

input_size = (256, 256) if args.type == "page" else (32, 32)
input_size = (512, 512) if args.type == "page" else (256, 256)

# Load val data generator
st = time.time()
Expand Down
4 changes: 2 additions & 2 deletions references/classification/train_tensorflow_orientation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from doctr.transforms.functional import rotated_img_tensor
from utils import EarlyStopper, plot_recorder, plot_samples

CLASSES = [0, 90, 180, 270]
CLASSES = [0, -90, 180, 90]


def rnd_rotate(img: tf.Tensor, target):
Expand Down Expand Up @@ -147,7 +147,7 @@ def main(args):
if not isinstance(args.workers, int):
args.workers = min(16, mp.cpu_count())

input_size = (256, 256) if args.type == "page" else (32, 32)
input_size = (512, 512) if args.type == "page" else (256, 256)

# AMP
if args.amp:
Expand Down
23 changes: 12 additions & 11 deletions tests/pytorch/test_models_classification_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch

from doctr.models import classification
from doctr.models.classification.predictor import CropOrientationPredictor
from doctr.models.classification.predictor import OrientationPredictor
from doctr.models.utils import export_model_to_onnx


Expand Down Expand Up @@ -60,7 +60,7 @@ def test_classification_architectures(arch_name, input_shape, output_size):
@pytest.mark.parametrize(
"arch_name, input_shape",
[
["mobilenet_v3_small_orientation", (3, 128, 128)],
["mobilenet_v3_small_crop_orientation", (3, 128, 128)],
],
)
def test_classification_models(arch_name, input_shape):
Expand All @@ -80,7 +80,7 @@ def test_classification_models(arch_name, input_shape):
@pytest.mark.parametrize(
"arch_name",
[
"mobilenet_v3_small_orientation",
"mobilenet_v3_small_crop_orientation",
],
)
def test_classification_zoo(arch_name):
Expand All @@ -92,7 +92,7 @@ def test_classification_zoo(arch_name):
with pytest.raises(ValueError):
predictor = classification.zoo.crop_orientation_predictor(arch="wrong_model", pretrained=False)
# object check
assert isinstance(predictor, CropOrientationPredictor)
assert isinstance(predictor, OrientationPredictor)
input_tensor = torch.rand((batch_size, 3, 128, 128))
if torch.cuda.is_available():
predictor.model.cuda()
Expand All @@ -112,14 +112,15 @@ def test_classification_zoo(arch_name):

def test_crop_orientation_model(mock_text_box):
text_box_0 = cv2.imread(mock_text_box)
text_box_90 = np.rot90(text_box_0, 1)
# rotates counter-clockwise
text_box_270 = np.rot90(text_box_0, 1)
text_box_180 = np.rot90(text_box_0, 2)
text_box_270 = np.rot90(text_box_0, 3)
classifier = classification.crop_orientation_predictor("mobilenet_v3_small_orientation", pretrained=True)
assert classifier([text_box_0, text_box_90, text_box_180, text_box_270])[0] == [0, 1, 2, 3]
text_box_90 = np.rot90(text_box_0, 3)
classifier = classification.crop_orientation_predictor("mobilenet_v3_small_crop_orientation", pretrained=True)
assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[0] == [0, 1, 2, 3]
# 270 degrees is equivalent to -90 degrees
assert classifier([text_box_0, text_box_90, text_box_180, text_box_270])[1] == [0, 90, 180, -90]
assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_90, text_box_180, text_box_270])[2])
assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 180, 90]
assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2])


@pytest.mark.parametrize(
Expand All @@ -134,7 +135,7 @@ def test_crop_orientation_model(mock_text_box):
["magc_resnet31", (3, 32, 32), (126,)],
["mobilenet_v3_small", (3, 32, 32), (126,)],
["mobilenet_v3_large", (3, 32, 32), (126,)],
["mobilenet_v3_small_orientation", (3, 128, 128), (4,)],
["mobilenet_v3_small_crop_orientation", (3, 128, 128), (4,)],
["vit_s", (3, 32, 32), (126,)],
["vit_b", (3, 32, 32), (126,)],
["textnet_tiny", (3, 32, 32), (126,)],
Expand Down
Loading
Loading