From c56bf413c7b580b069a029b235eb032f94b4256a Mon Sep 17 00:00:00 2001 From: Felix Dittrich Date: Mon, 18 Nov 2024 10:36:27 +0100 Subject: [PATCH] [misc] Change prefered backend from tf to torch (#1779) --- .github/workflows/docs.yml | 2 +- .pre-commit-config.yaml | 2 +- api/Makefile | 2 +- api/app/utils.py | 2 - api/app/vision.py | 30 +-- api/docker-compose.yml | 2 - api/pyproject.toml | 2 +- api/tests/conftest.py | 201 +++++++++--------- demo/backend/pytorch.py | 4 - demo/backend/tensorflow.py | 4 - docs/source/conf.py | 2 +- docs/source/getting_started/installing.rst | 21 +- docs/source/modules/datasets.rst | 19 +- docs/source/modules/transforms.rst | 23 +- docs/source/modules/utils.rst | 5 + .../using_doctr/custom_models_training.rst | 90 ++++---- docs/source/using_doctr/sharing_models.rst | 20 +- .../source/using_doctr/using_model_export.rst | 33 ++- doctr/contrib/__init__.py | 1 + doctr/contrib/artefacts.py | 2 - doctr/contrib/base.py | 9 - doctr/datasets/cord.py | 1 - doctr/datasets/datasets/__init__.py | 8 +- doctr/datasets/datasets/base.py | 1 - doctr/datasets/detection.py | 3 - doctr/datasets/doc_artefacts.py | 1 - doctr/datasets/funsd.py | 1 - doctr/datasets/generator/__init__.py | 8 +- doctr/datasets/generator/base.py | 2 - doctr/datasets/generator/pytorch.py | 2 - doctr/datasets/generator/tensorflow.py | 2 - doctr/datasets/ic03.py | 1 - doctr/datasets/ic13.py | 1 - doctr/datasets/iiit5k.py | 1 - doctr/datasets/iiithws.py | 1 - doctr/datasets/imgur5k.py | 1 - doctr/datasets/loader.py | 3 - doctr/datasets/mjsynth.py | 1 - doctr/datasets/ocr.py | 1 - doctr/datasets/orientation.py | 1 - doctr/datasets/recognition.py | 1 - doctr/datasets/sroie.py | 1 - doctr/datasets/svhn.py | 1 - doctr/datasets/svt.py | 1 - doctr/datasets/synthtext.py | 1 - doctr/datasets/utils.py | 35 +-- doctr/datasets/wildreceipt.py | 8 +- doctr/file_utils.py | 1 - doctr/io/elements.py | 24 +-- doctr/io/html.py | 2 - doctr/io/image/__init__.py | 6 +- doctr/io/image/base.py | 2 - doctr/io/image/pytorch.py | 8 - doctr/io/image/tensorflow.py | 8 - doctr/io/pdf.py | 2 - doctr/io/reader.py | 6 - doctr/models/_utils.py | 8 - doctr/models/builder.py | 18 -- .../classification/magc_resnet/__init__.py | 6 +- .../classification/magc_resnet/pytorch.py | 3 - .../classification/magc_resnet/tensorflow.py | 3 - .../classification/mobilenet/__init__.py | 6 +- .../classification/mobilenet/pytorch.py | 12 -- .../classification/mobilenet/tensorflow.py | 13 -- .../classification/predictor/__init__.py | 8 +- .../classification/predictor/pytorch.py | 1 - .../classification/predictor/tensorflow.py | 1 - .../models/classification/resnet/__init__.py | 8 +- doctr/models/classification/resnet/pytorch.py | 11 - .../classification/resnet/tensorflow.py | 12 -- .../models/classification/textnet/__init__.py | 6 +- .../models/classification/textnet/pytorch.py | 7 - .../classification/textnet/tensorflow.py | 7 - doctr/models/classification/vgg/__init__.py | 6 +- doctr/models/classification/vgg/pytorch.py | 2 - doctr/models/classification/vgg/tensorflow.py | 3 - doctr/models/classification/vit/__init__.py | 6 +- doctr/models/classification/vit/pytorch.py | 6 - doctr/models/classification/vit/tensorflow.py | 6 - doctr/models/classification/zoo.py | 4 - doctr/models/detection/_utils/__init__.py | 8 +- doctr/models/detection/_utils/base.py | 2 - doctr/models/detection/_utils/pytorch.py | 4 - doctr/models/detection/_utils/tensorflow.py | 4 - doctr/models/detection/core.py | 5 - .../differentiable_binarization/__init__.py | 8 +- 
.../differentiable_binarization/base.py | 9 - .../differentiable_binarization/pytorch.py | 9 - .../differentiable_binarization/tensorflow.py | 10 - doctr/models/detection/fast/__init__.py | 8 +- doctr/models/detection/fast/base.py | 7 - doctr/models/detection/fast/pytorch.py | 15 +- doctr/models/detection/fast/tensorflow.py | 14 +- doctr/models/detection/linknet/__init__.py | 8 +- doctr/models/detection/linknet/base.py | 8 - doctr/models/detection/linknet/pytorch.py | 9 - doctr/models/detection/linknet/tensorflow.py | 9 - doctr/models/detection/predictor/__init__.py | 10 +- doctr/models/detection/predictor/pytorch.py | 1 - .../models/detection/predictor/tensorflow.py | 1 - doctr/models/detection/zoo.py | 2 - doctr/models/factory/hub.py | 8 +- doctr/models/kie_predictor/__init__.py | 10 +- doctr/models/kie_predictor/base.py | 1 - doctr/models/kie_predictor/pytorch.py | 5 +- doctr/models/kie_predictor/tensorflow.py | 3 +- doctr/models/modules/layers/__init__.py | 6 +- doctr/models/modules/transformer/__init__.py | 6 +- doctr/models/modules/transformer/pytorch.py | 4 +- .../models/modules/transformer/tensorflow.py | 4 +- .../modules/vision_transformer/__init__.py | 6 +- doctr/models/predictor/__init__.py | 10 +- doctr/models/predictor/base.py | 2 - doctr/models/predictor/pytorch.py | 5 +- doctr/models/predictor/tensorflow.py | 5 +- doctr/models/preprocessor/__init__.py | 8 +- doctr/models/preprocessor/pytorch.py | 5 - doctr/models/preprocessor/tensorflow.py | 5 - doctr/models/recognition/core.py | 3 - doctr/models/recognition/crnn/__init__.py | 8 +- doctr/models/recognition/crnn/pytorch.py | 14 -- doctr/models/recognition/crnn/tensorflow.py | 12 -- doctr/models/recognition/master/__init__.py | 6 +- doctr/models/recognition/master/base.py | 3 - doctr/models/recognition/master/pytorch.py | 9 - doctr/models/recognition/master/tensorflow.py | 10 - doctr/models/recognition/parseq/__init__.py | 6 +- doctr/models/recognition/parseq/base.py | 3 - doctr/models/recognition/parseq/pytorch.py | 6 - doctr/models/recognition/parseq/tensorflow.py | 7 +- .../models/recognition/predictor/__init__.py | 10 +- doctr/models/recognition/predictor/_utils.py | 2 - doctr/models/recognition/predictor/pytorch.py | 3 +- .../recognition/predictor/tensorflow.py | 1 - doctr/models/recognition/sar/__init__.py | 8 +- doctr/models/recognition/sar/pytorch.py | 7 - doctr/models/recognition/sar/tensorflow.py | 9 - doctr/models/recognition/utils.py | 4 - doctr/models/recognition/vitstr/__init__.py | 8 +- doctr/models/recognition/vitstr/base.py | 3 - doctr/models/recognition/vitstr/pytorch.py | 8 - doctr/models/recognition/vitstr/tensorflow.py | 8 - doctr/models/recognition/zoo.py | 2 - doctr/models/utils/__init__.py | 8 +- doctr/models/utils/pytorch.py | 7 - doctr/models/utils/tensorflow.py | 7 - doctr/models/zoo.py | 4 - doctr/transforms/functional/__init__.py | 6 +- doctr/transforms/functional/base.py | 6 - doctr/transforms/functional/pytorch.py | 8 - doctr/transforms/functional/tensorflow.py | 12 -- doctr/transforms/modules/__init__.py | 8 +- doctr/transforms/modules/base.py | 74 +++---- doctr/transforms/modules/pytorch.py | 7 +- doctr/transforms/modules/tensorflow.py | 19 +- doctr/utils/data.py | 3 - doctr/utils/fonts.py | 2 - doctr/utils/geometry.py | 30 --- doctr/utils/metrics.py | 27 +-- doctr/utils/multithreading.py | 3 - doctr/utils/reconstitution.py | 4 - doctr/utils/visualization.py | 13 -- pyproject.toml | 2 +- references/classification/utils.py | 1 - references/detection/utils.py | 1 - 
references/recognition/train_pytorch_ddp.py | 4 +- references/recognition/utils.py | 1 - 167 files changed, 431 insertions(+), 1016 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 2b2c056a99..5e228fb444 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -27,7 +27,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -e .[tf,viz,html] + pip install -e .[torch,viz,html] pip install -e .[docs] - name: Build documentation diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bbbdbdf2b1..5b283e9a6a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,7 +16,7 @@ repos: - id: no-commit-to-branch args: ['--branch', 'main'] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.7.1 + rev: v0.7.4 hooks: - id: ruff args: [ --fix ] diff --git a/api/Makefile b/api/Makefile index 09e9841e91..d8044e2126 100644 --- a/api/Makefile +++ b/api/Makefile @@ -3,7 +3,7 @@ .PHONY: lock run stop test # Pin the dependencies lock: - pip install poetry>=1.0 + pip install poetry>=1.0 poetry-plugin-export poetry lock poetry export -f requirements.txt --without-hashes --output requirements.txt poetry export -f requirements.txt --without-hashes --with dev --output requirements-dev.txt diff --git a/api/app/utils.py b/api/app/utils.py index 511a75ad9e..472bcb2985 100644 --- a/api/app/utils.py +++ b/api/app/utils.py @@ -24,11 +24,9 @@ async def get_documents(files: List[UploadFile]) -> Tuple[List[np.ndarray], List """Convert a list of UploadFile objects to lists of numpy arrays and their corresponding filenames Args: - ---- files: list of UploadFile objects Returns: - ------- Tuple[List[np.ndarray], List[str]]: list of numpy arrays and their corresponding filenames """ diff --git a/api/app/vision.py b/api/app/vision.py index 144b5e4c3b..99f9c5e8e2 100644 --- a/api/app/vision.py +++ b/api/app/vision.py @@ -4,28 +4,34 @@ # See LICENSE or go to for full license details. 
-import tensorflow as tf - -gpu_devices = tf.config.list_physical_devices("GPU") -if any(gpu_devices): - tf.config.experimental.set_memory_growth(gpu_devices[0], True) - from typing import Callable, Union +import torch + from doctr.models import kie_predictor, ocr_predictor from .schemas import DetectionIn, KIEIn, OCRIn, RecognitionIn +def _move_to_device(predictor: Callable) -> Callable: + """Move the predictor to the desired device + + Args: + predictor: the predictor to move + + Returns: + Callable: the predictor moved to the desired device + """ + return predictor.to(torch.device("cuda" if torch.cuda.is_available() else "cpu")) + + def init_predictor(request: Union[KIEIn, OCRIn, RecognitionIn, DetectionIn]) -> Callable: """Initialize the predictor based on the request Args: - ---- request: input request Returns: - ------- Callable: the predictor """ params = request.model_dump() @@ -36,12 +42,12 @@ def init_predictor(request: Union[KIEIn, OCRIn, RecognitionIn, DetectionIn]) -> predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh predictor.det_predictor.model.postprocessor.box_thresh = box_thresh if isinstance(request, DetectionIn): - return predictor.det_predictor + return _move_to_device(predictor.det_predictor) elif isinstance(request, RecognitionIn): - return predictor.reco_predictor - return predictor + return _move_to_device(predictor.reco_predictor) + return _move_to_device(predictor) elif isinstance(request, KIEIn): predictor = kie_predictor(pretrained=True, **params) predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh predictor.det_predictor.model.postprocessor.box_thresh = box_thresh - return predictor + return _move_to_device(predictor) diff --git a/api/docker-compose.yml b/api/docker-compose.yml index 4140ed9cbb..ba1c129547 100644 --- a/api/docker-compose.yml +++ b/api/docker-compose.yml @@ -1,5 +1,3 @@ -version: '3.8' - services: web: container_name: api_web diff --git a/api/pyproject.toml b/api/pyproject.toml index 3f21402540..7ec2f42c0f 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -11,7 +11,7 @@ license = "Apache-2.0" [tool.poetry.dependencies] python = ">=3.10,<3.13" -python-doctr = {git = "https://github.com/mindee/doctr.git", extras = ['tf'], branch = "main" } +python-doctr = {git = "https://github.com/mindee/doctr.git", extras = ['torch'], branch = "main" } # Fastapi: minimum version required to avoid pydantic error # cf. 
https://github.com/tiangolo/fastapi/issues/4168 fastapi = ">=0.73.0" diff --git a/api/tests/conftest.py b/api/tests/conftest.py index 5fee0d5894..dedfcae2b2 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -37,32 +37,32 @@ def mock_detection_response(): "box": { "name": "117319856-fc35bf00-ae8b-11eb-9b51-ca5aba673466.jpg", "geometries": [ - [0.8176307908857315, 0.1787109375, 0.9101580212741838, 0.2080078125], - [0.7471996155154171, 0.1796875, 0.8272978149561669, 0.20703125], + [0.8203927977629988, 0.181640625, 0.9087770178355502, 0.2041015625], + [0.7471996155154171, 0.1806640625, 0.8245358080788996, 0.2060546875], ], }, "poly": { "name": "117319856-fc35bf00-ae8b-11eb-9b51-ca5aba673466.jpg", "geometries": [ [ - 0.9063061475753784, - 0.17740710079669952, - 0.9078840017318726, - 0.20474515855312347, - 0.8173396587371826, - 0.20735852420330048, - 0.8157618045806885, - 0.18002046644687653, + 0.8203927977629988, + 0.181640625, + 0.906015010958283, + 0.181640625, + 0.906015010958283, + 0.2021484375, + 0.8203927977629988, + 0.2021484375, ], [ - 0.8233299851417542, - 0.17740298807621002, - 0.8250390291213989, - 0.2027825564146042, - 0.7470247745513916, - 0.20540954172611237, - 0.7453157305717468, - 0.1800299733877182, + 0.7482568619833604, + 0.17938309907913208, + 0.8208542842026056, + 0.1819499135017395, + 0.8193355512950555, + 0.2034294307231903, + 0.7467381290758103, + 0.20086261630058289, ], ], }, @@ -82,17 +82,17 @@ def mock_kie_response(): "class_name": "words", "items": [ { - "value": "Hello", - "geometry": [0.7471996155154171, 0.1796875, 0.8272978149561669, 0.20703125], - "objectness_score": 0.39, - "confidence": 1, + "value": "world!", + "geometry": [0.8203927977629988, 0.181640625, 0.9087770178355502, 0.2041015625], + "objectness_score": 0.46, + "confidence": 0.94, "crop_orientation": {"value": 0, "confidence": None}, }, { - "value": "world!", - "geometry": [0.8176307908857315, 0.1787109375, 0.9101580212741838, 0.2080078125], - "objectness_score": 0.39, - "confidence": 1, + "value": "Hello", + "geometry": [0.7471996155154171, 0.1806640625, 0.8245358080788996, 0.2060546875], + "objectness_score": 0.46, + "confidence": 0.66, "crop_orientation": {"value": 0, "confidence": None}, }, ], @@ -109,35 +109,35 @@ def mock_kie_response(): "class_name": "words", "items": [ { - "value": "Hello", + "value": "world!", "geometry": [ - 0.7453157305717468, - 0.1800299733877182, - 0.8233299851417542, - 0.17740298807621002, - 0.8250390291213989, - 0.2027825564146042, - 0.7470247745513916, - 0.20540954172611237, + 0.8203927977629988, + 0.181640625, + 0.906015010958283, + 0.181640625, + 0.906015010958283, + 0.2021484375, + 0.8203927977629988, + 0.2021484375, ], - "objectness_score": 0.5, - "confidence": 0.99, + "objectness_score": 0.52, + "confidence": 1, "crop_orientation": {"value": 0, "confidence": 1}, }, { - "value": "world!", + "value": "Hello", "geometry": [ - 0.8157618045806885, - 0.18002046644687653, - 0.9063061475753784, - 0.17740710079669952, - 0.9078840017318726, - 0.20474515855312347, - 0.8173396587371826, - 0.20735852420330048, + 0.7482568619833604, + 0.17938309907913208, + 0.8208542842026056, + 0.1819499135017395, + 0.8193355512950555, + 0.2034294307231903, + 0.7467381290758103, + 0.20086261630058289, ], - "objectness_score": 0.5, - "confidence": 1, + "objectness_score": 0.57, + "confidence": 0.65, "crop_orientation": {"value": 0, "confidence": 1}, }, ], @@ -159,30 +159,35 @@ def mock_ocr_response(): { "blocks": [ { - "geometry": [0.7471996155154171, 0.1787109375, 
0.9101580212741838, 0.2080078125], - "objectness_score": 0.39, + "geometry": [0.7471996155154171, 0.1806640625, 0.9087770178355502, 0.2060546875], + "objectness_score": 0.46, "lines": [ { - "geometry": [0.7471996155154171, 0.1787109375, 0.9101580212741838, 0.2080078125], - "objectness_score": 0.39, + "geometry": [0.7471996155154171, 0.1806640625, 0.9087770178355502, 0.2060546875], + "objectness_score": 0.46, "words": [ { "value": "Hello", - "geometry": [0.7471996155154171, 0.1796875, 0.8272978149561669, 0.20703125], - "objectness_score": 0.39, - "confidence": 1, + "geometry": [ + 0.7471996155154171, + 0.1806640625, + 0.8245358080788996, + 0.2060546875, + ], + "objectness_score": 0.46, + "confidence": 0.66, "crop_orientation": {"value": 0, "confidence": None}, }, { "value": "world!", "geometry": [ - 0.8176307908857315, - 0.1787109375, - 0.9101580212741838, - 0.2080078125, + 0.8203927977629988, + 0.181640625, + 0.9087770178355502, + 0.2041015625, ], - "objectness_score": 0.39, - "confidence": 1, + "objectness_score": 0.46, + "confidence": 0.94, "crop_orientation": {"value": 0, "confidence": None}, }, ], @@ -203,59 +208,59 @@ def mock_ocr_response(): "blocks": [ { "geometry": [ - 0.7451040148735046, - 0.17927837371826172, - 0.9062581658363342, - 0.17407986521720886, - 0.9072266221046448, - 0.2041015625, - 0.7460724711418152, - 0.20930007100105286, + 0.7460642457008362, + 0.2017778754234314, + 0.7464945912361145, + 0.17868199944496155, + 0.9056554436683655, + 0.18164771795272827, + 0.9052250981330872, + 0.20474359393119812, ], - "objectness_score": 0.5, + "objectness_score": 0.54, "lines": [ { "geometry": [ - 0.7451040148735046, - 0.17927837371826172, - 0.9062581658363342, - 0.17407986521720886, - 0.9072266221046448, - 0.2041015625, - 0.7460724711418152, - 0.20930007100105286, + 0.7460642457008362, + 0.2017778754234314, + 0.7464945912361145, + 0.17868199944496155, + 0.9056554436683655, + 0.18164771795272827, + 0.9052250981330872, + 0.20474359393119812, ], - "objectness_score": 0.5, + "objectness_score": 0.54, "words": [ { "value": "Hello", "geometry": [ - 0.7453157305717468, - 0.1800299733877182, - 0.8233299851417542, - 0.17740298807621002, - 0.8250390291213989, - 0.2027825564146042, - 0.7470247745513916, - 0.20540954172611237, + 0.7482568619833604, + 0.17938309907913208, + 0.8208542842026056, + 0.1819499135017395, + 0.8193355512950555, + 0.2034294307231903, + 0.7467381290758103, + 0.20086261630058289, ], - "objectness_score": 0.5, - "confidence": 0.99, + "objectness_score": 0.57, + "confidence": 0.65, "crop_orientation": {"value": 0, "confidence": 1}, }, { "value": "world!", "geometry": [ - 0.8157618045806885, - 0.18002046644687653, - 0.9063061475753784, - 0.17740710079669952, - 0.9078840017318726, - 0.20474515855312347, - 0.8173396587371826, - 0.20735852420330048, + 0.8203927977629988, + 0.181640625, + 0.906015010958283, + 0.181640625, + 0.906015010958283, + 0.2021484375, + 0.8203927977629988, + 0.2021484375, ], - "objectness_score": 0.5, + "objectness_score": 0.52, "confidence": 1, "crop_orientation": {"value": 0, "confidence": 1}, }, diff --git a/demo/backend/pytorch.py b/demo/backend/pytorch.py index 548d696dde..ece5fb48fa 100644 --- a/demo/backend/pytorch.py +++ b/demo/backend/pytorch.py @@ -47,7 +47,6 @@ def load_predictor( """Load a predictor from doctr.models Args: - ---- det_arch: detection architecture reco_arch: recognition architecture assume_straight_pages: whether to assume straight pages or not @@ -60,7 +59,6 @@ def load_predictor( device: torch.device, the device to load the 
predictor on Returns: - ------- instance of OCRPredictor """ predictor = ocr_predictor( @@ -83,13 +81,11 @@ def forward_image(predictor: OCRPredictor, image: np.ndarray, device: torch.devi """Forward an image through the predictor Args: - ---- predictor: instance of OCRPredictor image: image to process device: torch.device, the device to process the image on Returns: - ------- segmentation map """ with torch.no_grad(): diff --git a/demo/backend/tensorflow.py b/demo/backend/tensorflow.py index 9fecfce3bc..82a569fd5f 100644 --- a/demo/backend/tensorflow.py +++ b/demo/backend/tensorflow.py @@ -46,7 +46,6 @@ def load_predictor( """Load a predictor from doctr.models Args: - ---- det_arch: detection architecture reco_arch: recognition architecture assume_straight_pages: whether to assume straight pages or not @@ -59,7 +58,6 @@ def load_predictor( device: tf.device, the device to load the predictor on Returns: - ------- instance of OCRPredictor """ with device: @@ -83,13 +81,11 @@ def forward_image(predictor: OCRPredictor, image: np.ndarray, device: tf.device) """Forward an image through the predictor Args: - ---- predictor: instance of OCRPredictor image: image to process as numpy array device: tf.device, the device to process the image on Returns: - ------- segmentation map """ with device: diff --git a/docs/source/conf.py b/docs/source/conf.py index 43a072df99..40a9a87f4f 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -28,7 +28,7 @@ project = "docTR" _copyright_str = f"-{datetime.now().year}" if datetime.now().year > 2021 else "" copyright = f"2021{_copyright_str}, Mindee" -author = "François-Guillaume Fernandez, Charles Gaillard" +author = "François-Guillaume Fernandez, Charles Gaillard, Olivier Dulcy, Felix Dittrich" # The full version, including alpha/beta/rc tags version = doctr.__version__ diff --git a/docs/source/getting_started/installing.rst b/docs/source/getting_started/installing.rst index 39e79aa3dd..db948f506e 100644 --- a/docs/source/getting_started/installing.rst +++ b/docs/source/getting_started/installing.rst @@ -33,14 +33,6 @@ We strive towards reducing framework-specific dependencies to a minimum, but som .. tabs:: - .. tab:: TensorFlow - - .. code:: bash - - pip install "python-doctr[tf]" - # or with preinstalled packages for visualization & html & contrib module support - pip install "python-doctr[tf,viz,html,contib]" - .. tab:: PyTorch .. code:: bash @@ -49,8 +41,13 @@ We strive towards reducing framework-specific dependencies to a minimum, but som # or with preinstalled packages for visualization & html & contrib module support pip install "python-doctr[torch,viz,html,contrib]" + .. tab:: TensorFlow + .. code:: bash + pip install "python-doctr[tf]" + # or with preinstalled packages for visualization & html & contrib module support + pip install "python-doctr[tf,viz,html,contib]" Via Conda (Only for Linux) ========================== @@ -70,16 +67,16 @@ Install the library in developer mode: .. tabs:: - .. tab:: TensorFlow + .. tab:: PyTorch .. code:: bash git clone https://github.com/mindee/doctr.git - pip install -e doctr/.[tf] + pip install -e doctr/.[torch] - .. tab:: PyTorch + .. tab:: TensorFlow .. code:: bash git clone https://github.com/mindee/doctr.git - pip install -e doctr/.[torch] + pip install -e doctr/.[tf] diff --git a/docs/source/modules/datasets.rst b/docs/source/modules/datasets.rst index 872212a121..4d40443166 100644 --- a/docs/source/modules/datasets.rst +++ b/docs/source/modules/datasets.rst @@ -52,11 +52,22 @@ Custom dataset loader .. 
autoclass:: OCRDataset -Dataloader ---------------------- +Dataset utils +------------- + +.. autofunction:: translate + +.. autofunction:: encode_string -.. autoclass:: doctr.datasets.loader.DataLoader +.. autofunction:: decode_sequence + +.. autofunction:: encode_sequences +.. autofunction:: pre_transform_multiclass + +.. autofunction:: crop_bboxes_from_image + +.. autofunction:: convert_target_to_relative .. _vocabs: @@ -172,5 +183,3 @@ of vocabs. * - multilingual - 195 - english & french & german & italian & spanish & portuguese & czech & polish & dutch & norwegian & danish & finnish & swedish & § - -.. autofunction:: encode_sequences diff --git a/docs/source/modules/transforms.rst b/docs/source/modules/transforms.rst index 7fc02f4cc4..d23fc5b7a7 100644 --- a/docs/source/modules/transforms.rst +++ b/docs/source/modules/transforms.rst @@ -10,22 +10,11 @@ Supported transformations ------------------------- Here are all transformations that are available through docTR: +.. currentmodule:: doctr.transforms.modules + .. autoclass:: Resize -.. autoclass:: Normalize -.. autoclass:: LambdaTransformation -.. autoclass:: ToGray -.. autoclass:: ColorInversion -.. autoclass:: RandomBrightness -.. autoclass:: RandomContrast -.. autoclass:: RandomSaturation -.. autoclass:: RandomHue -.. autoclass:: RandomGamma -.. autoclass:: RandomJpegQuality -.. autoclass:: RandomRotate -.. autoclass:: RandomCrop -.. autoclass:: GaussianBlur -.. autoclass:: ChannelShuffle .. autoclass:: GaussianNoise +.. autoclass:: ChannelShuffle .. autoclass:: RandomHorizontalFlip .. autoclass:: RandomShadow .. autoclass:: RandomResize @@ -35,6 +24,10 @@ Composing transformations --------------------------------------------- It is common to require several transformations to be performed consecutively. -.. autoclass:: Compose +.. autoclass:: SampleCompose +.. autoclass:: ImageTransform +.. autoclass:: ColorInversion .. autoclass:: OneOf .. autoclass:: RandomApply +.. autoclass:: RandomRotate +.. autoclass:: RandomCrop diff --git a/docs/source/modules/utils.rst b/docs/source/modules/utils.rst index ac0b13d9df..c4b99f356b 100644 --- a/docs/source/modules/utils.rst +++ b/docs/source/modules/utils.rst @@ -14,6 +14,11 @@ Easy-to-use functions to make sense of your model's predictions. .. autofunction:: visualize_page +Reconstitution +--------------- + +.. currentmodule:: doctr.utils.reconstitution + .. autofunction:: synthesize_page diff --git a/docs/source/using_doctr/custom_models_training.rst b/docs/source/using_doctr/custom_models_training.rst index 13e4640a36..c70eeb7c44 100644 --- a/docs/source/using_doctr/custom_models_training.rst +++ b/docs/source/using_doctr/custom_models_training.rst @@ -14,84 +14,84 @@ This section shows how you can easily load a custom trained model in docTR. .. tabs:: - .. tab:: TensorFlow + .. tab:: PyTorch .. 
code:: python3 + import torch from doctr.models import ocr_predictor, db_resnet50, crnn_vgg16_bn # Load custom detection model det_model = db_resnet50(pretrained=False, pretrained_backbone=False) - det_model.load_weights("") + det_params = torch.load('', map_location="cpu") + det_model.load_state_dict(det_params) predictor = ocr_predictor(det_arch=det_model, reco_arch="vitstr_small", pretrained=True) # Load custom recognition model reco_model = crnn_vgg16_bn(pretrained=False, pretrained_backbone=False) - reco_model.load_weights("") + reco_params = torch.load('', map_location="cpu") + reco_model.load_state_dict(reco_params) predictor = ocr_predictor(det_arch="linknet_resnet18", reco_arch=reco_model, pretrained=True) # Load custom detection and recognition model det_model = db_resnet50(pretrained=False, pretrained_backbone=False) - det_model.load_weights("") + det_params = torch.load('', map_location="cpu") + det_model.load_state_dict(det_params) reco_model = crnn_vgg16_bn(pretrained=False, pretrained_backbone=False) - reco_model.load_weights("") + reco_params = torch.load('', map_location="cpu") + reco_model.load_state_dict(reco_params) predictor = ocr_predictor(det_arch=det_model, reco_arch=reco_model, pretrained=False) - .. tab:: PyTorch + .. tab:: TensorFlow .. code:: python3 - import torch from doctr.models import ocr_predictor, db_resnet50, crnn_vgg16_bn # Load custom detection model det_model = db_resnet50(pretrained=False, pretrained_backbone=False) - det_params = torch.load('', map_location="cpu") - det_model.load_state_dict(det_params) + det_model.load_weights("") predictor = ocr_predictor(det_arch=det_model, reco_arch="vitstr_small", pretrained=True) # Load custom recognition model reco_model = crnn_vgg16_bn(pretrained=False, pretrained_backbone=False) - reco_params = torch.load('', map_location="cpu") - reco_model.load_state_dict(reco_params) + reco_model.load_weights("") predictor = ocr_predictor(det_arch="linknet_resnet18", reco_arch=reco_model, pretrained=True) # Load custom detection and recognition model det_model = db_resnet50(pretrained=False, pretrained_backbone=False) - det_params = torch.load('', map_location="cpu") - det_model.load_state_dict(det_params) + det_model.load_weights("") reco_model = crnn_vgg16_bn(pretrained=False, pretrained_backbone=False) - reco_params = torch.load('', map_location="cpu") - reco_model.load_state_dict(reco_params) + reco_model.load_weights("") predictor = ocr_predictor(det_arch=det_model, reco_arch=reco_model, pretrained=False) Load a custom recognition model trained on another vocabulary as the default one (French): .. tabs:: - .. tab:: TensorFlow + .. tab:: PyTorch .. code:: python3 + import torch from doctr.models import ocr_predictor, crnn_vgg16_bn from doctr.datasets import VOCABS reco_model = crnn_vgg16_bn(pretrained=False, pretrained_backbone=False, vocab=VOCABS["german"]) - reco_model.load_weights("") + reco_params = torch.load('', map_location="cpu") + reco_model.load_state_dict(reco_params) predictor = ocr_predictor(det_arch='linknet_resnet18', reco_arch=reco_model, pretrained=True) - .. tab:: PyTorch + .. tab:: TensorFlow .. 
code:: python3 - import torch from doctr.models import ocr_predictor, crnn_vgg16_bn from doctr.datasets import VOCABS reco_model = crnn_vgg16_bn(pretrained=False, pretrained_backbone=False, vocab=VOCABS["german"]) - reco_params = torch.load('', map_location="cpu") - reco_model.load_state_dict(reco_params) + reco_model.load_weights("") predictor = ocr_predictor(det_arch='linknet_resnet18', reco_arch=reco_model, pretrained=True) @@ -99,36 +99,37 @@ Load a custom trained KIE detection model: .. tabs:: - .. tab:: TensorFlow + .. tab:: PyTorch .. code:: python3 + import torch from doctr.models import kie_predictor, db_resnet50 det_model = db_resnet50(pretrained=False, pretrained_backbone=False, class_names=['total', 'date']) - det_model.load_weights("") + det_params = torch.load('', map_location="cpu") + det_model.load_state_dict(det_params) kie_predictor(det_arch=det_model, reco_arch='crnn_vgg16_bn', pretrained=True) - .. tab:: PyTorch + .. tab:: TensorFlow .. code:: python3 - import torch from doctr.models import kie_predictor, db_resnet50 det_model = db_resnet50(pretrained=False, pretrained_backbone=False, class_names=['total', 'date']) - det_params = torch.load('', map_location="cpu") - det_model.load_state_dict(det_params) + det_model.load_weights("") kie_predictor(det_arch=det_model, reco_arch='crnn_vgg16_bn', pretrained=True) Load a model with customized Preprocessor: .. tabs:: - .. tab:: TensorFlow + .. tab:: PyTorch .. code:: python3 + import torch from doctr.models.predictor import OCRPredictor from doctr.models.detection.predictor import DetectionPredictor from doctr.models.recognition.predictor import RecognitionPredictor @@ -136,9 +137,11 @@ Load a model with customized Preprocessor: from doctr.models import db_resnet50, crnn_vgg16_bn det_model = db_resnet50(pretrained=False, pretrained_backbone=False) - det_model.load_weights("") + det_params = torch.load('', map_location="cpu") + det_model.load_state_dict(det_params) reco_model = crnn_vgg16_bn(pretrained=False, pretrained_backbone=False) - reco_model.load_weights("") + reco_params = torch.load(, map_location="cpu") + reco_model.load_state_dict(reco_params) det_predictor = DetectionPredictor( PreProcessor( @@ -163,11 +166,10 @@ Load a model with customized Preprocessor: predictor = OCRPredictor(det_predictor, reco_predictor) - .. tab:: PyTorch + .. tab:: TensorFlow .. code:: python3 - import torch from doctr.models.predictor import OCRPredictor from doctr.models.detection.predictor import DetectionPredictor from doctr.models.recognition.predictor import RecognitionPredictor @@ -175,11 +177,9 @@ Load a model with customized Preprocessor: from doctr.models import db_resnet50, crnn_vgg16_bn det_model = db_resnet50(pretrained=False, pretrained_backbone=False) - det_params = torch.load('', map_location="cpu") - det_model.load_state_dict(det_params) + det_model.load_weights("") reco_model = crnn_vgg16_bn(pretrained=False, pretrained_backbone=False) - reco_params = torch.load(, map_location="cpu") - reco_model.load_state_dict(reco_params) + reco_model.load_weights("") det_predictor = DetectionPredictor( PreProcessor( @@ -224,18 +224,21 @@ Loading your custom trained orientation classification model .. tabs:: - .. tab:: TensorFlow + .. tab:: PyTorch .. 
code:: python3 + import torch from doctr.io import DocumentFile from doctr.models import ocr_predictor, mobilenet_v3_small_page_orientation, mobilenet_v3_small_crop_orientation from doctr.models.classification.zoo import crop_orientation_predictor, page_orientation_predictor custom_page_orientation_model = mobilenet_v3_small_page_orientation(pretrained=False) - custom_page_orientation_model.load_weights("") + page_params = torch.load('', map_location="cpu") + custom_page_orientation_model.load_state_dict(page_params) custom_crop_orientation_model = mobilenet_v3_small_crop_orientation(pretrained=False) - custom_crop_orientation_model.load_weights("") + crop_params = torch.load('', map_location="cpu") + custom_crop_orientation_model.load_state_dict(crop_params) predictor = ocr_predictor( pretrained=True, @@ -248,21 +251,18 @@ Loading your custom trained orientation classification model predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model) predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model) - .. tab:: PyTorch + .. tab:: TensorFlow .. code:: python3 - import torch from doctr.io import DocumentFile from doctr.models import ocr_predictor, mobilenet_v3_small_page_orientation, mobilenet_v3_small_crop_orientation from doctr.models.classification.zoo import crop_orientation_predictor, page_orientation_predictor custom_page_orientation_model = mobilenet_v3_small_page_orientation(pretrained=False) - page_params = torch.load('', map_location="cpu") - custom_page_orientation_model.load_state_dict(page_params) + custom_page_orientation_model.load_weights("") custom_crop_orientation_model = mobilenet_v3_small_crop_orientation(pretrained=False) - crop_params = torch.load('', map_location="cpu") - custom_crop_orientation_model.load_state_dict(crop_params) + custom_crop_orientation_model.load_weights("") predictor = ocr_predictor( pretrained=True, diff --git a/docs/source/using_doctr/sharing_models.rst b/docs/source/using_doctr/sharing_models.rst index 7ff09f08f3..3c633baf9d 100644 --- a/docs/source/using_doctr/sharing_models.rst +++ b/docs/source/using_doctr/sharing_models.rst @@ -14,7 +14,7 @@ This section shows how you can easily load a pretrained model from the Huggingfa .. tabs:: - .. tab:: TensorFlow + .. tab:: PyTorch .. code:: python3 @@ -22,14 +22,14 @@ This section shows how you can easily load a pretrained model from the Huggingfa from doctr.models import ocr_predictor, from_hub image = DocumentFile.from_images(['data/example.jpg']) # Load a custom detection model from huggingface hub - det_model = from_hub('Felix92/doctr-tf-db-resnet50') + det_model = from_hub('Felix92/doctr-torch-db-mobilenet-v3-large') # Load a custom recognition model from huggingface hub - reco_model = from_hub('Felix92/doctr-tf-crnn-vgg16-bn-french') + reco_model = from_hub('Felix92/doctr-torch-crnn-mobilenet-v3-large-french') # You can easily plug in this models to the OCR predictor predictor = ocr_predictor(det_arch=det_model, reco_arch=reco_model) result = predictor(image) - .. tab:: PyTorch + .. tab:: TensorFlow .. 
code:: python3 @@ -37,9 +37,9 @@ This section shows how you can easily load a pretrained model from the Huggingfa from doctr.models import ocr_predictor, from_hub image = DocumentFile.from_images(['data/example.jpg']) # Load a custom detection model from huggingface hub - det_model = from_hub('Felix92/doctr-torch-db-mobilenet-v3-large') + det_model = from_hub('Felix92/doctr-tf-db-resnet50') # Load a custom recognition model from huggingface hub - reco_model = from_hub('Felix92/doctr-torch-crnn-mobilenet-v3-large-french') + reco_model = from_hub('Felix92/doctr-tf-crnn-vgg16-bn-french') # You can easily plug in this models to the OCR predictor predictor = ocr_predictor(det_arch=det_model, reco_arch=reco_model) result = predictor(image) @@ -67,17 +67,17 @@ It is also possible to push your model directly after training. .. tabs:: - .. tab:: TensorFlow + .. tab:: PyTorch .. code:: bash - python3 ~/doctr/references/recognition/train_tensorflow.py crnn_mobilenet_v3_large --name doctr-crnn-mobilenet-v3-large --push-to-hub + python3 ~/doctr/references/recognition/train_pytorch.py crnn_mobilenet_v3_large --name doctr-crnn-mobilenet-v3-large --push-to-hub - .. tab:: PyTorch + .. tab:: TensorFlow .. code:: bash - python3 ~/doctr/references/recognition/train_pytorch.py crnn_mobilenet_v3_large --name doctr-crnn-mobilenet-v3-large --push-to-hub + python3 ~/doctr/references/recognition/train_tensorflow.py crnn_mobilenet_v3_large --name doctr-crnn-mobilenet-v3-large --push-to-hub Pretrained community models diff --git a/docs/source/using_doctr/using_model_export.rst b/docs/source/using_doctr/using_model_export.rst index c62c36169b..db632701ba 100644 --- a/docs/source/using_doctr/using_model_export.rst +++ b/docs/source/using_doctr/using_model_export.rst @@ -26,6 +26,14 @@ Advantages: .. tabs:: + .. tab:: PyTorch + + .. code:: python3 + + import torch + predictor = ocr_predictor(reco_arch="crnn_mobilenet_v3_small", det_arch="linknet_resnet34", pretrained=True).cuda().half() + res = predictor(doc) + .. tab:: TensorFlow .. code:: python3 @@ -35,14 +43,6 @@ Advantages: mixed_precision.set_global_policy('mixed_float16') predictor = ocr_predictor(reco_arch="crnn_mobilenet_v3_small", det_arch="linknet_resnet34", pretrained=True) - .. tab:: PyTorch - - .. code:: python3 - - import torch - predictor = ocr_predictor(reco_arch="crnn_mobilenet_v3_small", det_arch="linknet_resnet34", pretrained=True).cuda().half() - res = predictor(doc) - Export to ONNX ^^^^^^^^^^^^^^ @@ -52,34 +52,33 @@ It defines a common format for representing models, including the network struct .. tabs:: - .. tab:: TensorFlow + .. tab:: PyTorch .. code:: python3 - import tensorflow as tf + import torch from doctr.models import vitstr_small from doctr.models.utils import export_model_to_onnx batch_size = 16 input_shape = (3, 32, 128) model = vitstr_small(pretrained=True, exportable=True) - dummy_input = [tf.TensorSpec([batch_size, input_shape], tf.float32, name="input")] - model_path, output = export_model_to_onnx(model, model_name="vitstr.onnx", dummy_input=dummy_input) - + dummy_input = torch.rand((batch_size, input_shape), dtype=torch.float32) + model_path = export_model_to_onnx(model, model_name="vitstr.onnx, dummy_input=dummy_input) - .. tab:: PyTorch + .. tab:: TensorFlow .. 
code:: python3 - import torch + import tensorflow as tf from doctr.models import vitstr_small from doctr.models.utils import export_model_to_onnx batch_size = 16 input_shape = (32, 128, 3) model = vitstr_small(pretrained=True, exportable=True) - dummy_input = torch.rand((batch_size, input_shape), dtype=torch.float32) - model_path = export_model_to_onnx(model, model_name="vitstr.onnx, dummy_input=dummy_input) + dummy_input = [tf.TensorSpec([batch_size, input_shape], tf.float32, name="input")] + model_path, output = export_model_to_onnx(model, model_name="vitstr.onnx", dummy_input=dummy_input) Using your ONNX exported model diff --git a/doctr/contrib/__init__.py b/doctr/contrib/__init__.py index e69de29bb2..dd46199ccc 100644 --- a/doctr/contrib/__init__.py +++ b/doctr/contrib/__init__.py @@ -0,0 +1 @@ +from .artefacts import ArtefactDetector diff --git a/doctr/contrib/artefacts.py b/doctr/contrib/artefacts.py index 646e199186..cbc819e568 100644 --- a/doctr/contrib/artefacts.py +++ b/doctr/contrib/artefacts.py @@ -34,7 +34,6 @@ class ArtefactDetector(_BasePredictor): >>> results = detector(doc) Args: - ---- arch: the architecture to use batch_size: the batch size to use model_path: the path to the model to use @@ -109,7 +108,6 @@ def show(self, **kwargs: Any) -> None: Display the results Args: - ---- **kwargs: additional keyword arguments to be passed to `plt.show` """ requires_package("matplotlib", "`.show()` requires matplotlib installed") diff --git a/doctr/contrib/base.py b/doctr/contrib/base.py index 4b65834383..806b109d43 100644 --- a/doctr/contrib/base.py +++ b/doctr/contrib/base.py @@ -16,7 +16,6 @@ class _BasePredictor: Base class for all predictors Args: - ---- batch_size: the batch size to use url: the url to use to download a model if needed model_path: the path to the model to use @@ -35,13 +34,11 @@ def _init_model(self, url: Optional[str] = None, model_path: Optional[str] = Non Download the model from the given url if needed Args: - ---- url: the url to use model_path: the path to the model to use **kwargs: additional arguments to be passed to `download_from_url` Returns: - ------- Any: the ONNX loaded model """ requires_package("onnxruntime", "`.contrib` module requires `onnxruntime` to be installed.") @@ -57,11 +54,9 @@ def preprocess(self, img: np.ndarray) -> np.ndarray: Preprocess the input image Args: - ---- img: the input image to preprocess Returns: - ------- np.ndarray: the preprocessed image """ raise NotImplementedError @@ -71,12 +66,10 @@ def postprocess(self, output: List[np.ndarray], input_images: List[List[np.ndarr Postprocess the model output Args: - ---- output: the model output to postprocess input_images: the input images used to generate the output Returns: - ------- Any: the postprocessed output """ raise NotImplementedError @@ -86,11 +79,9 @@ def __call__(self, inputs: List[np.ndarray]) -> Any: Call the model on the given inputs Args: - ---- inputs: the inputs to use Returns: - ------- Any: the postprocessed output """ self._inputs = inputs diff --git a/doctr/datasets/cord.py b/doctr/datasets/cord.py index 9e2188727d..244d16d9a2 100644 --- a/doctr/datasets/cord.py +++ b/doctr/datasets/cord.py @@ -29,7 +29,6 @@ class CORD(VisionDataset): >>> img, target = train_set[0] Args: - ---- train: whether the subset should be the training one use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) recognition_task: whether the dataset should be used for recognition task diff --git a/doctr/datasets/datasets/__init__.py 
b/doctr/datasets/datasets/__init__.py index c7110f5669..5bc3a74b47 100644 --- a/doctr/datasets/datasets/__init__.py +++ b/doctr/datasets/datasets/__init__.py @@ -1,6 +1,6 @@ from doctr.file_utils import is_tf_available, is_torch_available -if is_tf_available(): - from .tensorflow import * -elif is_torch_available(): - from .pytorch import * # type: ignore[assignment] +if is_torch_available(): + from .pytorch import * +elif is_tf_available(): + from .tensorflow import * # type: ignore[assignment] diff --git a/doctr/datasets/datasets/base.py b/doctr/datasets/datasets/base.py index 58f1ca29f6..61ed53eae8 100644 --- a/doctr/datasets/datasets/base.py +++ b/doctr/datasets/datasets/base.py @@ -82,7 +82,6 @@ class _VisionDataset(_AbstractDataset): """Implements an abstract dataset Args: - ---- url: URL of the dataset file_name: name of the file once downloaded file_hash: expected SHA256 of the file diff --git a/doctr/datasets/detection.py b/doctr/datasets/detection.py index 0000704dfa..a023b2f9a0 100644 --- a/doctr/datasets/detection.py +++ b/doctr/datasets/detection.py @@ -26,7 +26,6 @@ class DetectionDataset(AbstractDataset): >>> img, target = train_set[0] Args: - ---- img_folder: folder with all the images of the dataset label_path: path to the annotations of each image use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) @@ -70,13 +69,11 @@ def format_polygons( """Format polygons into an array Args: - ---- polygons: the bounding boxes use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) np_dtype: dtype of array Returns: - ------- geoms: bounding boxes as np array polygons_classes: list of classes for each bounding box """ diff --git a/doctr/datasets/doc_artefacts.py b/doctr/datasets/doc_artefacts.py index 6a05a01150..5830f89f33 100644 --- a/doctr/datasets/doc_artefacts.py +++ b/doctr/datasets/doc_artefacts.py @@ -26,7 +26,6 @@ class DocArtefacts(VisionDataset): >>> img, target = train_set[0] Args: - ---- train: whether the subset should be the training one use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) **kwargs: keyword arguments from `VisionDataset`. 
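The `doctr/datasets/datasets/__init__.py` hunk above — a flip repeated verbatim for the generator and model packages later in this patch — inverts the backend dispatch: `is_torch_available()` is now checked before `is_tf_available()`, so the PyTorch implementations are imported whenever both frameworks are present. A minimal sketch of pinning a backend explicitly under the new default, assuming your docTR build honors the `USE_TF`/`USE_TORCH` environment switches read by `doctr.file_utils` (they must be set before the first doctr import):

.. code:: python3

    import os

    # Assumption: doctr.file_utils reads these switches at import time,
    # so set them before importing anything from doctr.
    os.environ["USE_TF"] = "1"
    os.environ["USE_TORCH"] = "0"

    from doctr.models import ocr_predictor  # resolves to the TensorFlow implementation

    predictor = ocr_predictor(pretrained=True)

Without the switches, the reordered `if`/`elif` above simply makes PyTorch the winner in dual-framework environments, which is the behavioral core of this patch.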
diff --git a/doctr/datasets/funsd.py b/doctr/datasets/funsd.py index 3bd8b088f9..4529d0a18f 100644 --- a/doctr/datasets/funsd.py +++ b/doctr/datasets/funsd.py @@ -29,7 +29,6 @@ class FUNSD(VisionDataset): >>> img, target = train_set[0] Args: - ---- train: whether the subset should be the training one use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) recognition_task: whether the dataset should be used for recognition task diff --git a/doctr/datasets/generator/__init__.py b/doctr/datasets/generator/__init__.py index c7110f5669..5bc3a74b47 100644 --- a/doctr/datasets/generator/__init__.py +++ b/doctr/datasets/generator/__init__.py @@ -1,6 +1,6 @@ from doctr.file_utils import is_tf_available, is_torch_available -if is_tf_available(): - from .tensorflow import * -elif is_torch_available(): - from .pytorch import * # type: ignore[assignment] +if is_torch_available(): + from .pytorch import * +elif is_tf_available(): + from .tensorflow import * # type: ignore[assignment] diff --git a/doctr/datasets/generator/base.py b/doctr/datasets/generator/base.py index 424f59563d..2b868b26ee 100644 --- a/doctr/datasets/generator/base.py +++ b/doctr/datasets/generator/base.py @@ -24,7 +24,6 @@ def synthesize_text_img( """Generate a synthetic text image Args: - ---- text: the text to render as an image font_size: the size of the font font_family: the font family (has to be installed on your system) @@ -32,7 +31,6 @@ def synthesize_text_img( text_color: text color on the final image Returns: - ------- PIL image of the text """ background_color = (0, 0, 0) if background_color is None else background_color diff --git a/doctr/datasets/generator/pytorch.py b/doctr/datasets/generator/pytorch.py index b254c91e4a..d3dbecbd75 100644 --- a/doctr/datasets/generator/pytorch.py +++ b/doctr/datasets/generator/pytorch.py @@ -18,7 +18,6 @@ class CharacterGenerator(_CharacterGenerator): >>> img, target = ds[0] Args: - ---- vocab: vocabulary to take the character from num_samples: number of samples that will be generated iterating over the dataset cache_samples: whether generated images should be cached firsthand @@ -40,7 +39,6 @@ class WordGenerator(_WordGenerator): >>> img, target = ds[0] Args: - ---- vocab: vocabulary to take the character from min_chars: minimum number of characters in a word max_chars: maximum number of characters in a word diff --git a/doctr/datasets/generator/tensorflow.py b/doctr/datasets/generator/tensorflow.py index 82e205e038..a71cb9bb5d 100644 --- a/doctr/datasets/generator/tensorflow.py +++ b/doctr/datasets/generator/tensorflow.py @@ -18,7 +18,6 @@ class CharacterGenerator(_CharacterGenerator): >>> img, target = ds[0] Args: - ---- vocab: vocabulary to take the character from num_samples: number of samples that will be generated iterating over the dataset cache_samples: whether generated images should be cached firsthand @@ -46,7 +45,6 @@ class WordGenerator(_WordGenerator): >>> img, target = ds[0] Args: - ---- vocab: vocabulary to take the character from min_chars: minimum number of characters in a word max_chars: maximum number of characters in a word diff --git a/doctr/datasets/ic03.py b/doctr/datasets/ic03.py index b3af8d958c..50920952b5 100644 --- a/doctr/datasets/ic03.py +++ b/doctr/datasets/ic03.py @@ -28,7 +28,6 @@ class IC03(VisionDataset): >>> img, target = train_set[0] Args: - ---- train: whether the subset should be the training one use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) 
recognition_task: whether the dataset should be used for recognition task diff --git a/doctr/datasets/ic13.py b/doctr/datasets/ic13.py index 0082d92316..725b665758 100644 --- a/doctr/datasets/ic13.py +++ b/doctr/datasets/ic13.py @@ -33,7 +33,6 @@ class IC13(AbstractDataset): >>> img, target = test_set[0] Args: - ---- img_folder: folder with all the images of the dataset label_folder: folder with all annotation files for the images use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) diff --git a/doctr/datasets/iiit5k.py b/doctr/datasets/iiit5k.py index 89619dd8aa..a87d454b42 100644 --- a/doctr/datasets/iiit5k.py +++ b/doctr/datasets/iiit5k.py @@ -30,7 +30,6 @@ class IIIT5K(VisionDataset): >>> img, target = train_set[0] Args: - ---- train: whether the subset should be the training one use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) recognition_task: whether the dataset should be used for recognition task diff --git a/doctr/datasets/iiithws.py b/doctr/datasets/iiithws.py index e33e3acd53..0066b9a489 100644 --- a/doctr/datasets/iiithws.py +++ b/doctr/datasets/iiithws.py @@ -32,7 +32,6 @@ class IIITHWS(AbstractDataset): >>> img, target = test_set[0] Args: - ---- img_folder: folder with all the images of the dataset label_path: path to the file with the labels train: whether the subset should be the training one diff --git a/doctr/datasets/imgur5k.py b/doctr/datasets/imgur5k.py index 4dcfec02b8..b99d8b4152 100644 --- a/doctr/datasets/imgur5k.py +++ b/doctr/datasets/imgur5k.py @@ -40,7 +40,6 @@ class IMGUR5K(AbstractDataset): >>> img, target = test_set[0] Args: - ---- img_folder: folder with all the images of the dataset label_path: path to the annotations file of the dataset train: whether the subset should be the training one diff --git a/doctr/datasets/loader.py b/doctr/datasets/loader.py index c8c584df85..583c1ee50d 100644 --- a/doctr/datasets/loader.py +++ b/doctr/datasets/loader.py @@ -16,11 +16,9 @@ def default_collate(samples): """Collate multiple elements into batches Args: - ---- samples: list of N tuples containing M elements Returns: - ------- Tuple of M sequences contianing N elements each """ batch_data = zip(*samples) @@ -40,7 +38,6 @@ class DataLoader: >>> images, targets = next(train_iter) Args: - ---- dataset: the dataset shuffle: whether the samples should be shuffled before passing it to the iterator batch_size: number of elements in each batch diff --git a/doctr/datasets/mjsynth.py b/doctr/datasets/mjsynth.py index a8b16caebe..650cc01f40 100644 --- a/doctr/datasets/mjsynth.py +++ b/doctr/datasets/mjsynth.py @@ -30,7 +30,6 @@ class MJSynth(AbstractDataset): >>> img, target = test_set[0] Args: - ---- img_folder: folder with all the images of the dataset label_path: path to the file with the labels train: whether the subset should be the training one diff --git a/doctr/datasets/ocr.py b/doctr/datasets/ocr.py index b93c124ce7..69a8471eb7 100644 --- a/doctr/datasets/ocr.py +++ b/doctr/datasets/ocr.py @@ -24,7 +24,6 @@ class OCRDataset(AbstractDataset): >>> img, target = train_set[0] Args: - ---- img_folder: local path to image folder (all jpg at the root) label_file: local path to the label file use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) diff --git a/doctr/datasets/orientation.py b/doctr/datasets/orientation.py index 10bd55444e..2c008240f6 100644 --- a/doctr/datasets/orientation.py +++ b/doctr/datasets/orientation.py 
@@ -21,7 +21,6 @@ class OrientationDataset(AbstractDataset): >>> img, target = train_set[0] Args: - ---- img_folder: folder with all the images of the dataset **kwargs: keyword arguments from `AbstractDataset`. """ diff --git a/doctr/datasets/recognition.py b/doctr/datasets/recognition.py index ebf37a20ac..b06eb1e264 100644 --- a/doctr/datasets/recognition.py +++ b/doctr/datasets/recognition.py @@ -22,7 +22,6 @@ class RecognitionDataset(AbstractDataset): >>> img, target = train_set[0] Args: - ---- img_folder: path to the images folder labels_path: pathe to the json file containing all labels (character sequences) **kwargs: keyword arguments from `AbstractDataset`. diff --git a/doctr/datasets/sroie.py b/doctr/datasets/sroie.py index d6e7dac83b..83e9e64442 100644 --- a/doctr/datasets/sroie.py +++ b/doctr/datasets/sroie.py @@ -29,7 +29,6 @@ class SROIE(VisionDataset): >>> img, target = train_set[0] Args: - ---- train: whether the subset should be the training one use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) recognition_task: whether the dataset should be used for recognition task diff --git a/doctr/datasets/svhn.py b/doctr/datasets/svhn.py index 595113a42d..872c77d3c3 100644 --- a/doctr/datasets/svhn.py +++ b/doctr/datasets/svhn.py @@ -28,7 +28,6 @@ class SVHN(VisionDataset): >>> img, target = train_set[0] Args: - ---- train: whether the subset should be the training one use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) recognition_task: whether the dataset should be used for recognition task diff --git a/doctr/datasets/svt.py b/doctr/datasets/svt.py index b9e88b4cc1..89b6e552bb 100644 --- a/doctr/datasets/svt.py +++ b/doctr/datasets/svt.py @@ -28,7 +28,6 @@ class SVT(VisionDataset): >>> img, target = train_set[0] Args: - ---- train: whether the subset should be the training one use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) recognition_task: whether the dataset should be used for recognition task diff --git a/doctr/datasets/synthtext.py b/doctr/datasets/synthtext.py index 8be11e2303..f8ceaadfdf 100644 --- a/doctr/datasets/synthtext.py +++ b/doctr/datasets/synthtext.py @@ -31,7 +31,6 @@ class SynthText(VisionDataset): >>> img, target = train_set[0] Args: - ---- train: whether the subset should be the training one use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) recognition_task: whether the dataset should be used for recognition task diff --git a/doctr/datasets/utils.py b/doctr/datasets/utils.py index 75182a227a..a897faee88 100644 --- a/doctr/datasets/utils.py +++ b/doctr/datasets/utils.py @@ -19,7 +19,15 @@ from .vocabs import VOCABS -__all__ = ["translate", "encode_string", "decode_sequence", "encode_sequences", "pre_transform_multiclass"] +__all__ = [ + "translate", + "encode_string", + "decode_sequence", + "encode_sequences", + "pre_transform_multiclass", + "crop_bboxes_from_image", + "convert_target_to_relative", +] ImageTensor = TypeVar("ImageTensor") @@ -32,13 +40,11 @@ def translate( """Translate a string input in a given vocabulary Args: - ---- input_string: input string to translate vocab_name: vocabulary to use (french, latin, ...) 
unknown_char: unknown character for non-translatable characters Returns: - ------- A string translated in a given vocab """ if VOCABS.get(vocab_name) is None: @@ -67,12 +73,10 @@ def encode_string( """Given a predefined mapping, encode the string to a sequence of numbers Args: - ---- input_string: string to encode vocab: vocabulary (string), the encoding is given by the indexing of the character sequence Returns: - ------- A list encoding the input_string """ try: @@ -91,12 +95,10 @@ def decode_sequence( """Given a predefined mapping, decode the sequence of numbers to a string Args: - ---- input_seq: array to decode mapping: vocabulary (string), the encoding is given by the indexing of the character sequence Returns: - ------- A string, decoded from input_seq """ if not isinstance(input_seq, (Sequence, np.ndarray)): @@ -119,7 +121,6 @@ def encode_sequences( """Encode character sequences using a given vocab as mapping Args: - ---- sequences: the list of character sequences of size N vocab: the ordered vocab to use for encoding target_size: maximum length of the encoded data @@ -129,7 +130,6 @@ def encode_sequences( dynamic_seq_length: if `target_size` is specified, uses it as upper bound and enables dynamic sequence size Returns: - ------- the padded encoded data as a tensor """ if 0 <= eos < len(vocab): @@ -172,10 +172,19 @@ def encode_sequences( def convert_target_to_relative( img: ImageTensor, target: Union[np.ndarray, Dict[str, Any]] ) -> Tuple[ImageTensor, Union[Dict[str, Any], np.ndarray]]: + """Converts target to relative coordinates + + Args: + img: tf.Tensor or torch.Tensor representing the image + target: target to convert to relative coordinates (boxes (N, 4) or polygons (N, 4, 2)) + + Returns: + The image and the target in relative coordinates + """ if isinstance(target, np.ndarray): - target = convert_to_relative_coords(target, get_img_shape(img)) + target = convert_to_relative_coords(target, get_img_shape(img)) # type: ignore[arg-type] else: - target["boxes"] = convert_to_relative_coords(target["boxes"], get_img_shape(img)) + target["boxes"] = convert_to_relative_coords(target["boxes"], get_img_shape(img)) # type: ignore[arg-type] return img, target @@ -183,12 +192,10 @@ def crop_bboxes_from_image(img_path: Union[str, Path], geoms: np.ndarray) -> Lis """Crop a set of bounding boxes from an image Args: - ---- img_path: path to the image geoms: a array of polygons of shape (N, 4, 2) or of straight boxes of shape (N, 4) Returns: - ------- a list of cropped images """ with Image.open(img_path) as pil_img: @@ -205,12 +212,10 @@ def pre_transform_multiclass(img, target: Tuple[np.ndarray, List]) -> Tuple[np.n """Converts multiclass target to relative coordinates. Args: - ---- img: Image target: tuple of target polygons and their classes names Returns: - ------- Image and dictionary of boxes, with class names as keys """ boxes = convert_to_relative_coords(target[0], get_img_shape(img)) diff --git a/doctr/datasets/wildreceipt.py b/doctr/datasets/wildreceipt.py index 685266931a..f46b5da301 100644 --- a/doctr/datasets/wildreceipt.py +++ b/doctr/datasets/wildreceipt.py @@ -17,9 +17,10 @@ class WILDRECEIPT(AbstractDataset): - """WildReceipt dataset from `"Spatial Dual-Modality Graph Reasoning for Key Information Extraction" - `_ | - `repository `_. + """ + WildReceipt dataset from `"Spatial Dual-Modality Graph Reasoning for Key Information Extraction" + `_ | + `"repository" `_. .. 
image:: https://doctr-static.mindee.com/models?id=v0.7.0/wildreceipt-dataset.jpg&src=0 :align: center @@ -34,7 +35,6 @@ class WILDRECEIPT(AbstractDataset): >>> img, target = test_set[0] Args: - ---- img_folder: folder with all the images of the dataset label_path: path to the annotations file of the dataset train: whether the subset should be the training one diff --git a/doctr/file_utils.py b/doctr/file_utils.py index 53263345b1..79858a3ed9 100644 --- a/doctr/file_utils.py +++ b/doctr/file_utils.py @@ -98,7 +98,6 @@ def requires_package(name: str, extra_message: Optional[str] = None) -> None: # package requirement helper Args: - ---- name: name of the package extra_message: additional message to display if the package is not found """ diff --git a/doctr/io/elements.py b/doctr/io/elements.py index 324d70f0b4..f846b197fb 100644 --- a/doctr/io/elements.py +++ b/doctr/io/elements.py @@ -67,7 +67,6 @@ class Word(Element): """Implements a word element Args: - ---- value: the text string of the word confidence: the confidence associated with the text prediction geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to @@ -111,7 +110,6 @@ class Artefact(Element): """Implements a non-textual element Args: - ---- artefact_type: the type of artefact confidence: the confidence of the type prediction geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to @@ -144,7 +142,6 @@ class Line(Element): """Implements a line element as a collection of words Args: - ---- words: list of word elements geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to the page's size. If not specified, it will be resolved by default to the smallest bounding box enclosing @@ -202,7 +199,6 @@ class Block(Element): """Implements a block element as a collection of lines and artefacts Args: - ---- lines: list of line elements artefacts: list of artefacts geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to @@ -256,7 +252,6 @@ class Page(Element): """Implements a page element as a collection of blocks Args: - ---- page: image encoded as a numpy array in uint8 blocks: list of block elements page_idx: the index of the page in the input raw document @@ -311,11 +306,9 @@ def synthesize(self, **kwargs) -> np.ndarray: """Synthesize the page from the predictions Args: - ---- **kwargs: keyword arguments passed to the `synthesize_page` method - Returns - ------- + Returns: synthesized page """ return synthesize_page(self.export(), **kwargs) @@ -325,11 +318,9 @@ def export_as_xml(self, file_title: str = "docTR - XML export (hOCR)") -> Tuple[ convention: https://github.com/kba/hocr-spec/blob/master/1.2/spec.md Args: - ---- file_title: the title of the XML file Returns: - ------- a tuple of the XML byte string, and its ElementTree """ p_idx = self.page_idx @@ -437,7 +428,6 @@ class KIEPage(Element): """Implements a KIE page element as a collection of predictions Args: - ---- predictions: Dictionary with list of block elements for each detection class page: image encoded as a numpy array in uint8 page_idx: the index of the page in the input raw document @@ -496,11 +486,9 @@ def synthesize(self, **kwargs) -> np.ndarray: """Synthesize the page from the predictions Args: - ---- **kwargs: keyword arguments passed to the `synthesize_kie_page` method Returns: - ------- synthesized page """ return synthesize_kie_page(self.export(), **kwargs) @@ 
-510,11 +498,9 @@ def export_as_xml(self, file_title: str = "docTR - XML export (hOCR)") -> Tuple[ convention: https://github.com/kba/hocr-spec/blob/master/1.2/spec.md Args: - ---- file_title: the title of the XML file Returns: - ------- a tuple of the XML byte string, and its ElementTree """ p_idx = self.page_idx @@ -582,7 +568,6 @@ class Document(Element): """Implements a document element as a collection of pages Args: - ---- pages: list of page elements """ @@ -608,11 +593,9 @@ def synthesize(self, **kwargs) -> List[np.ndarray]: """Synthesize all pages from their predictions Args: - ---- **kwargs: keyword arguments passed to the `Page.synthesize` method - Returns - ------- + Returns: list of synthesized pages """ return [page.synthesize(**kwargs) for page in self.pages] @@ -621,11 +604,9 @@ def export_as_xml(self, **kwargs) -> List[Tuple[bytes, ET.ElementTree]]: """Export the document as XML (hOCR-format) Args: - ---- **kwargs: additional keyword arguments passed to the Page.export_as_xml method Returns: - ------- list of tuple of (bytes, ElementTree) """ return [page.export_as_xml(**kwargs) for page in self.pages] @@ -641,7 +622,6 @@ class KIEDocument(Document): """Implements a document element as a collection of pages Args: - ---- pages: list of page elements """ diff --git a/doctr/io/html.py b/doctr/io/html.py index f8a8da237d..21c0340a88 100644 --- a/doctr/io/html.py +++ b/doctr/io/html.py @@ -15,12 +15,10 @@ def read_html(url: str, **kwargs: Any) -> bytes: >>> doc = read_html("https://www.yoursite.com") Args: - ---- url: URL of the target web page **kwargs: keyword arguments from `weasyprint.HTML` Returns: - ------- decoded PDF file as a bytes stream """ from weasyprint import HTML diff --git a/doctr/io/image/__init__.py b/doctr/io/image/__init__.py index 393c70c359..a8fde5fef9 100644 --- a/doctr/io/image/__init__.py +++ b/doctr/io/image/__init__.py @@ -2,7 +2,7 @@ from .base import * -if is_tf_available(): - from .tensorflow import * -elif is_torch_available(): +if is_torch_available(): from .pytorch import * +elif is_tf_available(): + from .tensorflow import * diff --git a/doctr/io/image/base.py b/doctr/io/image/base.py index b4c2ed3065..c11caba034 100644 --- a/doctr/io/image/base.py +++ b/doctr/io/image/base.py @@ -25,13 +25,11 @@ def read_img_as_numpy( >>> page = read_img_as_numpy("path/to/your/doc.jpg") Args: - ---- file: the path to the image file output_size: the expected output size of each page in format H x W rgb_output: whether the output ndarray channel order should be RGB instead of BGR. Returns: - ------- the page decoded as numpy ndarray of shape H x W x 3 """ if isinstance(file, (str, Path)): diff --git a/doctr/io/image/pytorch.py b/doctr/io/image/pytorch.py index 67d37e46b3..26167f81f5 100644 --- a/doctr/io/image/pytorch.py +++ b/doctr/io/image/pytorch.py @@ -20,12 +20,10 @@ def tensor_from_pil(pil_img: Image.Image, dtype: torch.dtype = torch.float32) -> """Convert a PIL Image to a PyTorch tensor Args: - ---- pil_img: a PIL image dtype: the output tensor data type Returns: - ------- decoded image as tensor """ if dtype == torch.float32: @@ -40,12 +38,10 @@ def read_img_as_tensor(img_path: AbstractPath, dtype: torch.dtype = torch.float3 """Read an image file as a PyTorch tensor Args: - ---- img_path: location of the image file dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255. 
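As a usage sketch for the loader documented here, assuming the PyTorch backend is the one resolved at import time (the file path is a placeholder):

>>> import torch
>>> from doctr.io.image import read_img_as_tensor
>>> img = read_img_as_tensor("path/to/page.jpg", dtype=torch.float32)
>>> img.shape, img.dtype  # (C, H, W) tensor, float values divided by 255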
Returns: - ------- decoded image as a tensor """ if dtype not in (torch.uint8, torch.float16, torch.float32): @@ -59,12 +55,10 @@ def decode_img_as_tensor(img_content: bytes, dtype: torch.dtype = torch.float32) """Read a byte stream as a PyTorch tensor Args: - ---- img_content: bytes of a decoded image dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255. Returns: - ------- decoded image as a tensor """ if dtype not in (torch.uint8, torch.float16, torch.float32): @@ -78,12 +72,10 @@ def tensor_from_numpy(npy_img: np.ndarray, dtype: torch.dtype = torch.float32) - """Read an image file as a PyTorch tensor Args: - ---- npy_img: image encoded as a numpy array of shape (H, W, C) in np.uint8 dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255. Returns: - ------- same image as a tensor of shape (C, H, W) """ if dtype not in (torch.uint8, torch.float16, torch.float32): diff --git a/doctr/io/image/tensorflow.py b/doctr/io/image/tensorflow.py index 28fb2fadd5..2b4435abc7 100644 --- a/doctr/io/image/tensorflow.py +++ b/doctr/io/image/tensorflow.py @@ -19,12 +19,10 @@ def tensor_from_pil(pil_img: Image.Image, dtype: tf.dtypes.DType = tf.float32) - """Convert a PIL Image to a TensorFlow tensor Args: - ---- pil_img: a PIL image dtype: the output tensor data type Returns: - ------- decoded image as tensor """ npy_img = img_to_array(pil_img) @@ -36,12 +34,10 @@ def read_img_as_tensor(img_path: AbstractPath, dtype: tf.dtypes.DType = tf.float """Read an image file as a TensorFlow tensor Args: - ---- img_path: location of the image file dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255. Returns: - ------- decoded image as a tensor """ if dtype not in (tf.uint8, tf.float16, tf.float32): @@ -61,12 +57,10 @@ def decode_img_as_tensor(img_content: bytes, dtype: tf.dtypes.DType = tf.float32 """Read a byte stream as a TensorFlow tensor Args: - ---- img_content: bytes of a decoded image dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255. Returns: - ------- decoded image as a tensor """ if dtype not in (tf.uint8, tf.float16, tf.float32): @@ -85,12 +79,10 @@ def tensor_from_numpy(npy_img: np.ndarray, dtype: tf.dtypes.DType = tf.float32) """Read an image file as a TensorFlow tensor Args: - ---- npy_img: image encoded as a numpy array of shape (H, W, C) in np.uint8 dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255. 
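A matching sketch for the TensorFlow counterpart, assuming the TensorFlow backend is the one in use:

>>> import numpy as np
>>> import tensorflow as tf
>>> from doctr.io.image import tensor_from_numpy
>>> npy_img = np.zeros((32, 32, 3), dtype=np.uint8)
>>> tensor = tensor_from_numpy(npy_img, dtype=tf.float32)  # (H, W, C), values scaled to [0, 1]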
Returns: - ------- same image as a tensor of shape (H, W, C) """ if dtype not in (tf.uint8, tf.float16, tf.float32): diff --git a/doctr/io/pdf.py b/doctr/io/pdf.py index 4c72b7438a..51545f07c0 100644 --- a/doctr/io/pdf.py +++ b/doctr/io/pdf.py @@ -26,7 +26,6 @@ def read_pdf( >>> doc = read_pdf("path/to/your/doc.pdf") Args: - ---- file: the path to the PDF file scale: rendering scale (1 corresponds to 72dpi) rgb_mode: if True, the output will be RGB, otherwise BGR @@ -34,7 +33,6 @@ def read_pdf( **kwargs: additional parameters to :meth:`pypdfium2.PdfPage.render` Returns: - ------- the list of pages decoded as numpy ndarray of shape H x W x C """ # Rasterise pages to numpy ndarrays with pypdfium2 diff --git a/doctr/io/reader.py b/doctr/io/reader.py index 76f7317cb1..cc969ff48a 100644 --- a/doctr/io/reader.py +++ b/doctr/io/reader.py @@ -29,12 +29,10 @@ def from_pdf(cls, file: AbstractFile, **kwargs) -> List[np.ndarray]: >>> doc = DocumentFile.from_pdf("path/to/your/doc.pdf") Args: - ---- file: the path to the PDF file or a binary stream **kwargs: additional parameters to :meth:`pypdfium2.PdfPage.render` Returns: - ------- the list of pages decoded as numpy ndarray of shape H x W x 3 """ return read_pdf(file, **kwargs) @@ -47,12 +45,10 @@ def from_url(cls, url: str, **kwargs) -> List[np.ndarray]: >>> doc = DocumentFile.from_url("https://www.yoursite.com") Args: - ---- url: the URL of the target web page **kwargs: additional parameters to :meth:`pypdfium2.PdfPage.render` Returns: - ------- the list of pages decoded as numpy ndarray of shape H x W x 3 """ requires_package( @@ -71,12 +67,10 @@ def from_images(cls, files: Union[Sequence[AbstractFile], AbstractFile], **kwarg >>> pages = DocumentFile.from_images(["path/to/your/page1.png", "path/to/your/page2.png"]) Args: - ---- files: the path to the image file or a binary stream, or a collection of those **kwargs: additional parameters to :meth:`doctr.io.image.read_img_as_numpy` Returns: - ------- the list of pages decoded as numpy ndarray of shape H x W x 3 """ if isinstance(files, (str, Path, bytes)): diff --git a/doctr/models/_utils.py b/doctr/models/_utils.py index ab1922c905..9fa7638e97 100644 --- a/doctr/models/_utils.py +++ b/doctr/models/_utils.py @@ -20,11 +20,9 @@ def get_max_width_length_ratio(contour: np.ndarray) -> float: """Get the maximum shape ratio of a contour. Args: - ---- contour: the contour from cv2.findContour Returns: - ------- the maximum shape ratio """ _, (w, h), _ = cv2.minAreaRect(contour) @@ -43,7 +41,6 @@ def estimate_orientation( lines of the document and the assumption that they should be horizontal. 
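A rough usage sketch, assuming the default contour thresholds (the blank page is only a placeholder input):

>>> import numpy as np
>>> from doctr.models._utils import estimate_orientation
>>> page = np.zeros((512, 512, 3), dtype=np.uint8)
>>> angle = estimate_orientation(page)  # clockwise angle in degrees, negative for left-side rotation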
Args: - ---- img: the img or bitmap to analyze (H, W, C) general_page_orientation: the general orientation of the page (angle [0, 90, 180, 270 (-90)], confidence) estimated by a model @@ -53,7 +50,6 @@ def estimate_orientation( lower_area: the minimum area of a contour to be considered Returns: - ------- the estimated angle of the page (clockwise, negative for left side rotation, positive for right side rotation) """ assert len(img.shape) == 3 and img.shape[-1] in [1, 3], f"Image shape {img.shape} not supported" @@ -162,11 +158,9 @@ def get_language(text: str) -> Tuple[str, float]: Get the language with the highest probability or no language if only a few words or a low probability Args: - ---- text (str): text Returns: - ------- The detected language in ISO 639 code and confidence score """ try: @@ -184,11 +178,9 @@ def invert_data_structure( """Invert a List of Dict of elements to a Dict of list of elements and the other way around Args: - ---- x: a list of dictionaries with the same keys or a dictionary of lists of the same length Returns: - ------- dictionary of list when x is a list of dictionaries or a list of dictionaries when x is dictionary of lists """ if isinstance(x, dict): diff --git a/doctr/models/builder.py b/doctr/models/builder.py index 8dfcafcc9d..ac93d4b2cd 100644 --- a/doctr/models/builder.py +++ b/doctr/models/builder.py @@ -20,7 +20,6 @@ class DocumentBuilder(NestedObject): """Implements a document builder Args: - ---- resolve_lines: whether words should be automatically grouped into lines resolve_blocks: whether lines should be automatically grouped into blocks paragraph_break: relative length of the minimum space separating paragraphs @@ -45,11 +44,9 @@ def _sort_boxes(boxes: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: """Sort bounding boxes from top to bottom, left to right Args: - ---- boxes: bounding boxes of shape (N, 4) or (N, 4, 2) (in case of rotated bbox) Returns: - ------- tuple: indices of ordered boxes of shape (N,), boxes If straight boxes are passed tpo the function, boxes are unchanged else: boxes returned are straight boxes fitted to the straightened rotated boxes @@ -69,12 +66,10 @@ def _resolve_sub_lines(self, boxes: np.ndarray, word_idcs: List[int]) -> List[Li """Split a line in sub_lines Args: - ---- boxes: bounding boxes of shape (N, 4) word_idcs: list of indexes for the words of the line Returns: - ------- A list of (sub-)lines computed from the original line (words) """ lines = [] @@ -109,11 +104,9 @@ def _resolve_lines(self, boxes: np.ndarray) -> List[List[int]]: """Order boxes to group them in lines Args: - ---- boxes: bounding boxes of shape (N, 4) or (N, 4, 2) in case of rotated bbox Returns: - ------- nested list of box indices """ # Sort boxes, and straighten the boxes if they are rotated @@ -157,12 +150,10 @@ def _resolve_blocks(boxes: np.ndarray, lines: List[List[int]]) -> List[List[List """Order lines to group them in blocks Args: - ---- boxes: bounding boxes of shape (N, 4) or (N, 4, 2) lines: list of lines, each line is a list of idx Returns: - ------- nested list of box indices """ # Resolve enclosing boxes of lines @@ -230,7 +221,6 @@ def _build_blocks( """Gather independent words in structured blocks Args: - ---- boxes: bounding boxes of all detected words of the page, of shape (N, 4) or (N, 4, 2) objectness_scores: objectness scores of all detected words of the page, of shape N word_preds: list of all detected words of the page, of shape N @@ -238,7 +228,6 @@ def _build_blocks( the general orientation (orientations + 
confidences) of the crops Returns: - ------- list of block elements """ if boxes.shape[0] != len(word_preds): @@ -307,7 +296,6 @@ def __call__( """Re-arrange detected words into structured blocks Args: - ---- pages: list of N elements, where each element represents the page image boxes: list of N elements, where each element represents the localization predictions, of shape (*, 4) or (*, 4, 2) for all words for a given page @@ -322,7 +310,6 @@ def __call__( where each element is a dictionary containing the language (language + confidence) Returns: - ------- document object """ if len(boxes) != len(text_preds) != len(crop_orientations) != len(objectness_scores) or len(boxes) != len( @@ -374,7 +361,6 @@ class KIEDocumentBuilder(DocumentBuilder): """Implements a KIE document builder Args: - ---- resolve_lines: whether words should be automatically grouped into lines resolve_blocks: whether lines should be automatically grouped into blocks paragraph_break: relative length of the minimum space separating paragraphs @@ -396,7 +382,6 @@ def __call__( # type: ignore[override] """Re-arrange detected words into structured predictions Args: - ---- pages: list of N elements, where each element represents the page image boxes: list of N dictionaries, where each element represents the localization predictions for a class, of shape (*, 5) or (*, 6) for all predictions @@ -411,7 +396,6 @@ def __call__( # type: ignore[override] where each element is a dictionary containing the language (language + confidence) Returns: - ------- document object """ if len(boxes) != len(text_preds) != len(crop_orientations) != len(objectness_scores) or len(boxes) != len( @@ -477,14 +461,12 @@ def _build_blocks( # type: ignore[override] """Gather independent words in structured blocks Args: - ---- boxes: bounding boxes of all detected words of the page, of shape (N, 4) or (N, 4, 2) objectness_scores: objectness scores of all detected words of the page word_preds: list of all detected words of the page, of shape N crop_orientations: list of orientations for each word crop Returns: - ------- list of block elements """ if boxes.shape[0] != len(word_preds): diff --git a/doctr/models/classification/magc_resnet/__init__.py b/doctr/models/classification/magc_resnet/__init__.py index c7110f5669..38ab32543e 100644 --- a/doctr/models/classification/magc_resnet/__init__.py +++ b/doctr/models/classification/magc_resnet/__init__.py @@ -1,6 +1,6 @@ from doctr.file_utils import is_tf_available, is_torch_available -if is_tf_available(): +if is_torch_available(): + from .pytorch import * +elif is_tf_available(): from .tensorflow import * -elif is_torch_available(): - from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/classification/magc_resnet/pytorch.py b/doctr/models/classification/magc_resnet/pytorch.py index f503d7c7fa..e51c4b0fbf 100644 --- a/doctr/models/classification/magc_resnet/pytorch.py +++ b/doctr/models/classification/magc_resnet/pytorch.py @@ -36,7 +36,6 @@ class MAGC(nn.Module): `_. 
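Spelled out at call-site level, the reordered guard in the module `__init__` above now resolves the backends in this order:

>>> from doctr.file_utils import is_tf_available, is_torch_available
>>> if is_torch_available():
...     from doctr.models.classification.magc_resnet.pytorch import magc_resnet31
... elif is_tf_available():
...     from doctr.models.classification.magc_resnet.tensorflow import magc_resnet31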
Args: - ---- inplanes: input channels headers: number of headers to split channels attn_scale: if True, re-scale attention to counteract the variance distibutions @@ -154,12 +153,10 @@ def magc_resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet: >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the ResNet architecture Returns: - ------- A feature extractor model """ return _magc_resnet( diff --git a/doctr/models/classification/magc_resnet/tensorflow.py b/doctr/models/classification/magc_resnet/tensorflow.py index d920ca44a4..442af37474 100644 --- a/doctr/models/classification/magc_resnet/tensorflow.py +++ b/doctr/models/classification/magc_resnet/tensorflow.py @@ -36,7 +36,6 @@ class MAGC(layers.Layer): `_. Args: - ---- inplanes: input channels headers: number of headers to split channels attn_scale: if True, re-scale attention to counteract the variance distibutions @@ -177,12 +176,10 @@ def magc_resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet: >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the ResNet architecture Returns: - ------- A feature extractor model """ return _magc_resnet( diff --git a/doctr/models/classification/mobilenet/__init__.py b/doctr/models/classification/mobilenet/__init__.py index 64556e403a..38ab32543e 100644 --- a/doctr/models/classification/mobilenet/__init__.py +++ b/doctr/models/classification/mobilenet/__init__.py @@ -1,6 +1,6 @@ from doctr.file_utils import is_tf_available, is_torch_available -if is_tf_available(): - from .tensorflow import * -elif is_torch_available(): +if is_torch_available(): from .pytorch import * +elif is_tf_available(): + from .tensorflow import * diff --git a/doctr/models/classification/mobilenet/pytorch.py b/doctr/models/classification/mobilenet/pytorch.py index 18470fdf11..fb8a1ac20c 100644 --- a/doctr/models/classification/mobilenet/pytorch.py +++ b/doctr/models/classification/mobilenet/pytorch.py @@ -123,12 +123,10 @@ def mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.M >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the MobileNetV3 architecture Returns: - ------- a torch.nn.Module """ return _mobilenet_v3( @@ -148,12 +146,10 @@ def mobilenet_v3_small_r(pretrained: bool = False, **kwargs: Any) -> mobilenetv3 >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the MobileNetV3 architecture Returns: - ------- a torch.nn.Module """ return _mobilenet_v3( @@ -177,12 +173,10 @@ def mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.M >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the MobileNetV3 architecture Returns: - ------- a torch.nn.Module """ return _mobilenet_v3( @@ -205,12 +199,10 @@ def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> mobilenetv3 >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the MobileNetV3 architecture Returns: - ------- a torch.nn.Module """ return _mobilenet_v3( @@ -234,12 +226,10 @@ def mobilenet_v3_small_crop_orientation(pretrained: bool = False, **kwargs: Any) >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of 
the MobileNetV3 architecture Returns: - ------- a torch.nn.Module """ return _mobilenet_v3( @@ -262,12 +252,10 @@ def mobilenet_v3_small_page_orientation(pretrained: bool = False, **kwargs: Any) >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the MobileNetV3 architecture Returns: - ------- a torch.nn.Module """ return _mobilenet_v3( diff --git a/doctr/models/classification/mobilenet/tensorflow.py b/doctr/models/classification/mobilenet/tensorflow.py index ae3535d947..6b6532a345 100644 --- a/doctr/models/classification/mobilenet/tensorflow.py +++ b/doctr/models/classification/mobilenet/tensorflow.py @@ -132,7 +132,6 @@ class InvertedResidual(layers.Layer): """InvertedResidual for mobilenet Args: - ---- conf: configuration object for inverted residual """ @@ -320,12 +319,10 @@ def mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> MobileNetV3: >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the MobileNetV3 architecture Returns: - ------- a keras.Model """ return _mobilenet_v3("mobilenet_v3_small", pretrained, False, **kwargs) @@ -343,12 +340,10 @@ def mobilenet_v3_small_r(pretrained: bool = False, **kwargs: Any) -> MobileNetV3 >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the MobileNetV3 architecture Returns: - ------- a keras.Model """ return _mobilenet_v3("mobilenet_v3_small_r", pretrained, True, **kwargs) @@ -366,12 +361,10 @@ def mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> MobileNetV3: >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the MobileNetV3 architecture Returns: - ------- a keras.Model """ return _mobilenet_v3("mobilenet_v3_large", pretrained, False, **kwargs) @@ -389,12 +382,10 @@ def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> MobileNetV3 >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the MobileNetV3 architecture Returns: - ------- a keras.Model """ return _mobilenet_v3("mobilenet_v3_large_r", pretrained, True, **kwargs) @@ -412,12 +403,10 @@ def mobilenet_v3_small_crop_orientation(pretrained: bool = False, **kwargs: Any) >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the MobileNetV3 architecture Returns: - ------- a keras.Model """ return _mobilenet_v3("mobilenet_v3_small_crop_orientation", pretrained, include_top=True, **kwargs) @@ -435,12 +424,10 @@ def mobilenet_v3_small_page_orientation(pretrained: bool = False, **kwargs: Any) >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the MobileNetV3 architecture Returns: - ------- a keras.Model """ return _mobilenet_v3("mobilenet_v3_small_page_orientation", pretrained, include_top=True, **kwargs) diff --git a/doctr/models/classification/predictor/__init__.py b/doctr/models/classification/predictor/__init__.py index c7110f5669..5bc3a74b47 100644 --- a/doctr/models/classification/predictor/__init__.py +++ b/doctr/models/classification/predictor/__init__.py @@ -1,6 +1,6 @@ from doctr.file_utils import is_tf_available, is_torch_available -if is_tf_available(): - from .tensorflow import * -elif is_torch_available(): - from .pytorch import * # type: ignore[assignment] 
+if is_torch_available(): + from .pytorch import * +elif is_tf_available(): + from .tensorflow import * # type: ignore[assignment] diff --git a/doctr/models/classification/predictor/pytorch.py b/doctr/models/classification/predictor/pytorch.py index 96f5c468ff..7a3e73af7a 100644 --- a/doctr/models/classification/predictor/pytorch.py +++ b/doctr/models/classification/predictor/pytorch.py @@ -20,7 +20,6 @@ class OrientationPredictor(nn.Module): 4 possible orientations: 0, 90, 180, 270 (-90) degrees counter clockwise. Args: - ---- pre_processor: transform inputs for easier batched model inference model: core classification architecture (backbone + classification head) """ diff --git a/doctr/models/classification/predictor/tensorflow.py b/doctr/models/classification/predictor/tensorflow.py index 23efbf6579..ec1337c1ec 100644 --- a/doctr/models/classification/predictor/tensorflow.py +++ b/doctr/models/classification/predictor/tensorflow.py @@ -20,7 +20,6 @@ class OrientationPredictor(NestedObject): 4 possible orientations: 0, 90, 180, 270 (-90) degrees counter clockwise. Args: - ---- pre_processor: transform inputs for easier batched model inference model: core classification architecture (backbone + classification head) """ diff --git a/doctr/models/classification/resnet/__init__.py b/doctr/models/classification/resnet/__init__.py index c7110f5669..5bc3a74b47 100644 --- a/doctr/models/classification/resnet/__init__.py +++ b/doctr/models/classification/resnet/__init__.py @@ -1,6 +1,6 @@ from doctr.file_utils import is_tf_available, is_torch_available -if is_tf_available(): - from .tensorflow import * -elif is_torch_available(): - from .pytorch import * # type: ignore[assignment] +if is_torch_available(): + from .pytorch import * +elif is_tf_available(): + from .tensorflow import * # type: ignore[assignment] diff --git a/doctr/models/classification/resnet/pytorch.py b/doctr/models/classification/resnet/pytorch.py index 7591741c29..10fea5302d 100644 --- a/doctr/models/classification/resnet/pytorch.py +++ b/doctr/models/classification/resnet/pytorch.py @@ -84,7 +84,6 @@ class ResNet(nn.Sequential): Text Recognition" `_. 
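For the OrientationPredictor wrapper documented a few hunks up, a hand-assembled sketch; the 256x256 crop size and batch size are illustrative assumptions:

>>> from doctr.models.classification import mobilenet_v3_small_crop_orientation
>>> from doctr.models.classification.predictor import OrientationPredictor
>>> from doctr.models.preprocessor import PreProcessor
>>> model = mobilenet_v3_small_crop_orientation(pretrained=False)
>>> predictor = OrientationPredictor(PreProcessor((256, 256), batch_size=4), model)
>>> # predictor([crop]) -> class indices, angles among [0, 90, 180, 270 (-90)], confidences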
Args: - ---- num_blocks: number of resnet block in each stage output_channels: number of channels in each stage stage_conv: whether to add a conv_sequence after each stage @@ -224,12 +223,10 @@ def resnet18(pretrained: bool = False, **kwargs: Any) -> TVResNet: >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the ResNet architecture Returns: - ------- A resnet18 model """ return _tv_resnet( @@ -253,12 +250,10 @@ def resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet: >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the ResNet architecture Returns: - ------- A resnet31 model """ return _resnet( @@ -287,12 +282,10 @@ def resnet34(pretrained: bool = False, **kwargs: Any) -> TVResNet: >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the ResNet architecture Returns: - ------- A resnet34 model """ return _tv_resnet( @@ -315,12 +308,10 @@ def resnet34_wide(pretrained: bool = False, **kwargs: Any) -> ResNet: >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the ResNet architecture Returns: - ------- A resnet34_wide model """ return _resnet( @@ -349,12 +340,10 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> TVResNet: >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the ResNet architecture Returns: - ------- A resnet50 model """ return _tv_resnet( diff --git a/doctr/models/classification/resnet/tensorflow.py b/doctr/models/classification/resnet/tensorflow.py index 662a43c3a0..b800272a7d 100644 --- a/doctr/models/classification/resnet/tensorflow.py +++ b/doctr/models/classification/resnet/tensorflow.py @@ -61,7 +61,6 @@ class ResnetBlock(layers.Layer): """Implements a resnet31 block with shortcut Args: - ---- conv_shortcut: Use of shortcut output_channels: number of channels to use in Conv2D kernel_size: size of square kernels @@ -121,7 +120,6 @@ class ResNet(Sequential): """Implements a ResNet architecture Args: - ---- num_blocks: number of resnet block in each stage output_channels: number of channels in each stage stage_downsample: whether the first residual block of a stage should downsample @@ -234,12 +232,10 @@ def resnet18(pretrained: bool = False, **kwargs: Any) -> ResNet: >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the ResNet architecture Returns: - ------- A classification model """ return _resnet( @@ -267,12 +263,10 @@ def resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet: >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the ResNet architecture Returns: - ------- A classification model """ return _resnet( @@ -300,12 +294,10 @@ def resnet34(pretrained: bool = False, **kwargs: Any) -> ResNet: >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the ResNet architecture Returns: - ------- A classification model """ return _resnet( @@ -332,12 +324,10 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet: >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the ResNet architecture Returns: - ------- A 
classification model """ kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs["resnet50"]["classes"])) @@ -386,12 +376,10 @@ def resnet34_wide(pretrained: bool = False, **kwargs: Any) -> ResNet: >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the ResNet architecture Returns: - ------- A classification model """ return _resnet( diff --git a/doctr/models/classification/textnet/__init__.py b/doctr/models/classification/textnet/__init__.py index c7110f5669..38ab32543e 100644 --- a/doctr/models/classification/textnet/__init__.py +++ b/doctr/models/classification/textnet/__init__.py @@ -1,6 +1,6 @@ from doctr.file_utils import is_tf_available, is_torch_available -if is_tf_available(): +if is_torch_available(): + from .pytorch import * +elif is_tf_available(): from .tensorflow import * -elif is_torch_available(): - from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/classification/textnet/pytorch.py b/doctr/models/classification/textnet/pytorch.py index cdbb719f8b..5dabb7586f 100644 --- a/doctr/models/classification/textnet/pytorch.py +++ b/doctr/models/classification/textnet/pytorch.py @@ -47,7 +47,6 @@ class TextNet(nn.Sequential): Implementation based on the official Pytorch implementation: `_. Args: - ---- stages (List[Dict[str, List[int]]]): List of dictionaries containing the parameters of each stage. include_top (bool, optional): Whether to include the classifier head. Defaults to True. num_classes (int, optional): Number of output classes. Defaults to 1000. @@ -135,12 +134,10 @@ def textnet_tiny(pretrained: bool = False, **kwargs: Any) -> TextNet: >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the TextNet architecture Returns: - ------- A textnet tiny model """ return _textnet( @@ -184,12 +181,10 @@ def textnet_small(pretrained: bool = False, **kwargs: Any) -> TextNet: >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the TextNet architecture Returns: - ------- A TextNet small model """ return _textnet( @@ -233,12 +228,10 @@ def textnet_base(pretrained: bool = False, **kwargs: Any) -> TextNet: >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the TextNet architecture Returns: - ------- A TextNet base model """ return _textnet( diff --git a/doctr/models/classification/textnet/tensorflow.py b/doctr/models/classification/textnet/tensorflow.py index e5e6105a7e..8e11f66435 100644 --- a/doctr/models/classification/textnet/tensorflow.py +++ b/doctr/models/classification/textnet/tensorflow.py @@ -47,7 +47,6 @@ class TextNet(Sequential): Implementation based on the official Pytorch implementation: `_. Args: - ---- stages (List[Dict[str, List[int]]]): List of dictionaries containing the parameters of each stage. include_top (bool, optional): Whether to include the classifier head. Defaults to True. num_classes (int, optional): Number of output classes. Defaults to 1000. 
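Since `include_top` is part of the TextNet signature, a backbone-only sketch looks like this, assuming the flag is forwarded by the `textnet_tiny` factory helper:

>>> import torch
>>> from doctr.models import textnet_tiny
>>> backbone = textnet_tiny(pretrained=False, include_top=False)
>>> feats = backbone(torch.rand((2, 3, 32, 32)))  # feature maps without the classifier head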
@@ -136,12 +135,10 @@ def textnet_tiny(pretrained: bool = False, **kwargs: Any) -> TextNet: >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the TextNet architecture Returns: - ------- A textnet tiny model """ return _textnet( @@ -184,12 +181,10 @@ def textnet_small(pretrained: bool = False, **kwargs: Any) -> TextNet: >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the TextNet architecture Returns: - ------- A TextNet small model """ return _textnet( @@ -232,12 +227,10 @@ def textnet_base(pretrained: bool = False, **kwargs: Any) -> TextNet: >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the TextNet architecture Returns: - ------- A TextNet base model """ return _textnet( diff --git a/doctr/models/classification/vgg/__init__.py b/doctr/models/classification/vgg/__init__.py index 64556e403a..38ab32543e 100644 --- a/doctr/models/classification/vgg/__init__.py +++ b/doctr/models/classification/vgg/__init__.py @@ -1,6 +1,6 @@ from doctr.file_utils import is_tf_available, is_torch_available -if is_tf_available(): - from .tensorflow import * -elif is_torch_available(): +if is_torch_available(): from .pytorch import * +elif is_tf_available(): + from .tensorflow import * diff --git a/doctr/models/classification/vgg/pytorch.py b/doctr/models/classification/vgg/pytorch.py index 2e16b11788..2bea77ef14 100644 --- a/doctr/models/classification/vgg/pytorch.py +++ b/doctr/models/classification/vgg/pytorch.py @@ -77,12 +77,10 @@ def vgg16_bn_r(pretrained: bool = False, **kwargs: Any) -> tv_vgg.VGG: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on ImageNet **kwargs: keyword arguments of the VGG architecture Returns: - ------- VGG feature extractor """ return _vgg( diff --git a/doctr/models/classification/vgg/tensorflow.py b/doctr/models/classification/vgg/tensorflow.py index c42e369bcd..74c991fa4b 100644 --- a/doctr/models/classification/vgg/tensorflow.py +++ b/doctr/models/classification/vgg/tensorflow.py @@ -32,7 +32,6 @@ class VGG(Sequential): `_. 
Args: - ---- num_blocks: number of convolutional block in each stage planes: number of output channels in each stage rect_pools: whether pooling square kernels should be replace with rectangular ones @@ -106,12 +105,10 @@ def vgg16_bn_r(pretrained: bool = False, **kwargs: Any) -> VGG: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on ImageNet **kwargs: keyword arguments of the VGG architecture Returns: - ------- VGG feature extractor """ return _vgg( diff --git a/doctr/models/classification/vit/__init__.py b/doctr/models/classification/vit/__init__.py index c7110f5669..38ab32543e 100644 --- a/doctr/models/classification/vit/__init__.py +++ b/doctr/models/classification/vit/__init__.py @@ -1,6 +1,6 @@ from doctr.file_utils import is_tf_available, is_torch_available -if is_tf_available(): +if is_torch_available(): + from .pytorch import * +elif is_tf_available(): from .tensorflow import * -elif is_torch_available(): - from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/classification/vit/pytorch.py b/doctr/models/classification/vit/pytorch.py index 335e92559f..f63eef8d13 100644 --- a/doctr/models/classification/vit/pytorch.py +++ b/doctr/models/classification/vit/pytorch.py @@ -40,7 +40,6 @@ class ClassifierHead(nn.Module): """Classifier head for Vision Transformer Args: - ---- in_channels: number of input channels num_classes: number of output classes """ @@ -65,7 +64,6 @@ class VisionTransformer(nn.Sequential): `_. Args: - ---- d_model: dimension of the transformer layers num_layers: number of transformer layers num_heads: number of attention heads @@ -143,12 +141,10 @@ def vit_s(pretrained: bool = False, **kwargs: Any) -> VisionTransformer: >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the VisionTransformer architecture Returns: - ------- A feature extractor model """ return _vit( @@ -175,12 +171,10 @@ def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer: >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the VisionTransformer architecture Returns: - ------- A feature extractor model """ return _vit( diff --git a/doctr/models/classification/vit/tensorflow.py b/doctr/models/classification/vit/tensorflow.py index 386065bca6..572877bc7e 100644 --- a/doctr/models/classification/vit/tensorflow.py +++ b/doctr/models/classification/vit/tensorflow.py @@ -41,7 +41,6 @@ class ClassifierHead(layers.Layer, NestedObject): """Classifier head for Vision Transformer Args: - ---- num_classes: number of output classes """ @@ -61,7 +60,6 @@ class VisionTransformer(Sequential): `_. 
Args: - ---- d_model: dimension of the transformer layers num_layers: number of transformer layers num_heads: number of attention heads @@ -148,12 +146,10 @@ def vit_s(pretrained: bool = False, **kwargs: Any) -> VisionTransformer: >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the VisionTransformer architecture Returns: - ------- A feature extractor model """ return _vit( @@ -179,12 +175,10 @@ def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer: >>> out = model(input_tensor) Args: - ---- pretrained: boolean, True if model is pretrained **kwargs: keyword arguments of the VisionTransformer architecture Returns: - ------- A feature extractor model """ return _vit( diff --git a/doctr/models/classification/zoo.py b/doctr/models/classification/zoo.py index 7df839484b..13050a9d9d 100644 --- a/doctr/models/classification/zoo.py +++ b/doctr/models/classification/zoo.py @@ -74,14 +74,12 @@ def crop_orientation_predictor( >>> out = model([input_crop]) Args: - ---- arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation') pretrained: If True, returns a model pre-trained on our recognition crops dataset batch_size: number of samples the model processes in parallel **kwargs: keyword arguments to be passed to the OrientationPredictor Returns: - ------- OrientationPredictor """ return _orientation_predictor(arch=arch, pretrained=pretrained, batch_size=batch_size, model_type="crop", **kwargs) @@ -99,14 +97,12 @@ def page_orientation_predictor( >>> out = model([input_page]) Args: - ---- arch: name of the architecture to use (e.g. 'mobilenet_v3_small_page_orientation') pretrained: If True, returns a model pre-trained on our recognition crops dataset batch_size: number of samples the model processes in parallel **kwargs: keyword arguments to be passed to the OrientationPredictor Returns: - ------- OrientationPredictor """ return _orientation_predictor(arch=arch, pretrained=pretrained, batch_size=batch_size, model_type="page", **kwargs) diff --git a/doctr/models/detection/_utils/__init__.py b/doctr/models/detection/_utils/__init__.py index fbeba301bc..0a48729503 100644 --- a/doctr/models/detection/_utils/__init__.py +++ b/doctr/models/detection/_utils/__init__.py @@ -1,7 +1,7 @@ -from doctr.file_utils import is_tf_available +from doctr.file_utils import is_tf_available, is_torch_available from .base import * -if is_tf_available(): - from .tensorflow import * -else: +if is_torch_available(): from .pytorch import * +elif is_tf_available(): + from .tensorflow import * diff --git a/doctr/models/detection/_utils/base.py b/doctr/models/detection/_utils/base.py index 71fdd2759e..86f5caebaf 100644 --- a/doctr/models/detection/_utils/base.py +++ b/doctr/models/detection/_utils/base.py @@ -20,7 +20,6 @@ def _remove_padding( """Remove padding from the localization predictions Args: - ---- pages: list of pages loc_preds: list of localization predictions preserve_aspect_ratio: whether the aspect ratio was preserved during padding @@ -28,7 +27,6 @@ def _remove_padding( assume_straight_pages: whether the pages are assumed to be straight Returns: - ------- list of unpaded localization predictions """ if preserve_aspect_ratio: diff --git a/doctr/models/detection/_utils/pytorch.py b/doctr/models/detection/_utils/pytorch.py index 0ac99f4690..a0ef618422 100644 --- a/doctr/models/detection/_utils/pytorch.py +++ b/doctr/models/detection/_utils/pytorch.py @@ -13,12 +13,10 @@ def erode(x: Tensor, kernel_size: int) 
-> Tensor: """Performs erosion on a given tensor Args: - ---- x: boolean tensor of shape (N, C, H, W) kernel_size: the size of the kernel to use for erosion Returns: - ------- the eroded tensor """ _pad = (kernel_size - 1) // 2 @@ -30,12 +28,10 @@ def dilate(x: Tensor, kernel_size: int) -> Tensor: """Performs dilation on a given tensor Args: - ---- x: boolean tensor of shape (N, C, H, W) kernel_size: the size of the kernel to use for dilation Returns: - ------- the dilated tensor """ _pad = (kernel_size - 1) // 2 diff --git a/doctr/models/detection/_utils/tensorflow.py b/doctr/models/detection/_utils/tensorflow.py index 6f5ec21749..b02ed7cd22 100644 --- a/doctr/models/detection/_utils/tensorflow.py +++ b/doctr/models/detection/_utils/tensorflow.py @@ -12,12 +12,10 @@ def erode(x: tf.Tensor, kernel_size: int) -> tf.Tensor: """Performs erosion on a given tensor Args: - ---- x: boolean tensor of shape (N, H, W, C) kernel_size: the size of the kernel to use for erosion Returns: - ------- the eroded tensor """ return 1 - tf.nn.max_pool2d(1 - x, kernel_size, strides=1, padding="SAME") @@ -27,12 +25,10 @@ def dilate(x: tf.Tensor, kernel_size: int) -> tf.Tensor: """Performs dilation on a given tensor Args: - ---- x: boolean tensor of shape (N, H, W, C) kernel_size: the size of the kernel to use for dilation Returns: - ------- the dilated tensor """ return tf.nn.max_pool2d(x, kernel_size, strides=1, padding="SAME") diff --git a/doctr/models/detection/core.py b/doctr/models/detection/core.py index 63fa786151..2c3189f57f 100644 --- a/doctr/models/detection/core.py +++ b/doctr/models/detection/core.py @@ -17,7 +17,6 @@ class DetectionPostProcessor(NestedObject): """Abstract class to postprocess the raw output of the model Args: - ---- box_thresh (float): minimal objectness score to consider a box bin_thresh (float): threshold to apply to segmentation raw heatmap assume straight_pages (bool): if True, fit straight boxes only @@ -37,13 +36,11 @@ def box_score(pred: np.ndarray, points: np.ndarray, assume_straight_pages: bool """Compute the confidence score for a polygon : mean of the p values on the polygon Args: - ---- pred (np.ndarray): p map returned by the model points: coordinates of the polygon assume_straight_pages: if True, fit straight boxes only Returns: - ------- polygon objectness """ h, w = pred.shape[:2] @@ -75,11 +72,9 @@ def __call__( """Performs postprocessing for a list of model outputs Args: - ---- proba_map: probability map of shape (N, H, W, C) Returns: - ------- list of N class predictions (for each input sample), where each class predictions is a list of C tensors of shape (*, 5) or (*, 6) """ diff --git a/doctr/models/detection/differentiable_binarization/__init__.py b/doctr/models/detection/differentiable_binarization/__init__.py index c7110f5669..5bc3a74b47 100644 --- a/doctr/models/detection/differentiable_binarization/__init__.py +++ b/doctr/models/detection/differentiable_binarization/__init__.py @@ -1,6 +1,6 @@ from doctr.file_utils import is_tf_available, is_torch_available -if is_tf_available(): - from .tensorflow import * -elif is_torch_available(): - from .pytorch import * # type: ignore[assignment] +if is_torch_available(): + from .pytorch import * +elif is_tf_available(): + from .tensorflow import * # type: ignore[assignment] diff --git a/doctr/models/detection/differentiable_binarization/base.py b/doctr/models/detection/differentiable_binarization/base.py index 414471146c..08e6339967 100644 --- a/doctr/models/detection/differentiable_binarization/base.py +++ 
b/doctr/models/detection/differentiable_binarization/base.py @@ -22,7 +22,6 @@ class DBPostProcessor(DetectionPostProcessor): `_. Args: - ---- unclip ratio: ratio used to unshrink polygons min_size_box: minimal length (pix) to keep a box max_candidates: maximum boxes to consider in a single page @@ -47,11 +46,9 @@ def polygon_to_box( """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon Args: - ---- points: The first parameter. Returns: - ------- a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle) """ if not self.assume_straight_pages: @@ -96,14 +93,12 @@ def bitmap_to_boxes( """Compute boxes from a bitmap/pred_map: find connected components then filter boxes Args: - ---- pred: Pred map from differentiable binarization output bitmap: Bitmap map computed from pred (binarized) angle_tol: Comparison tolerance of the angle with the median angle across the page ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop Returns: - ------- np tensor boxes for the bitmap, each box is a 5-element list containing x, y, w, h, score for the box """ @@ -164,7 +159,6 @@ class _DBNet: `_. Args: - ---- feature extractor: the backbone serving as feature extractor fpn_channels: number of channels each extracted feature maps is mapped to """ @@ -186,7 +180,6 @@ def compute_distance( """Compute the distance for each point of the map (xs, ys) to the (a, b) segment Args: - ---- xs : map of x coordinates (height, width) ys : map of y coordinates (height, width) a: first point defining the [ab] segment @@ -194,7 +187,6 @@ def compute_distance( eps: epsilon to avoid division by zero Returns: - ------- The computed distance """ @@ -218,7 +210,6 @@ def draw_thresh_map( """Draw a polygon treshold map on a canvas, as described in the DB paper Args: - ---- polygon : array of coord., to draw the boundary of the polygon canvas : threshold map to fill with polygons mask : mask for training on threshold polygons diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py index 427ada0e31..cb9daab2dc 100644 --- a/doctr/models/detection/differentiable_binarization/pytorch.py +++ b/doctr/models/detection/differentiable_binarization/pytorch.py @@ -96,7 +96,6 @@ class DBNet(_DBNet, nn.Module): `_. Args: - ---- feature extractor: the backbone serving as feature extractor head_chans: the number of channels in the head deform_conv: whether to use deformable convolution @@ -231,7 +230,6 @@ def compute_loss( and a list of masks for each image. 
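The dice term at the heart of this loss, in isolation (a self-contained illustration, not doctr's exact masked implementation):

>>> import torch
>>> def dice_loss(pred, gt, eps=1e-8):
...     inter = (pred * gt).sum()
...     return 1 - 2 * inter / (pred.sum() + gt.sum() + eps)
>>> dice_loss(torch.ones(4, 4), torch.ones(4, 4))  # perfect overlap -> ~0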
From there it computes the loss with the model output Args: - ---- out_map: output feature map of the model of shape (N, C, H, W) thresh_map: threshold map of shape (N, C, H, W) target: list of dictionary where each dict has a `boxes` and a `flags` entry @@ -240,7 +238,6 @@ def compute_loss( eps: epsilon factor in dice loss Returns: - ------- A loss tensor """ if gamma < 0: @@ -341,12 +338,10 @@ def db_resnet34(pretrained: bool = False, **kwargs: Any) -> DBNet: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text detection dataset **kwargs: keyword arguments of the DBNet architecture Returns: - ------- text detection architecture """ return _dbnet( @@ -376,12 +371,10 @@ def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text detection dataset **kwargs: keyword arguments of the DBNet architecture Returns: - ------- text detection architecture """ return _dbnet( @@ -411,12 +404,10 @@ def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text detection dataset **kwargs: keyword arguments of the DBNet architecture Returns: - ------- text detection architecture """ return _dbnet( diff --git a/doctr/models/detection/differentiable_binarization/tensorflow.py b/doctr/models/detection/differentiable_binarization/tensorflow.py index b0ca1f08e5..dc94977fb8 100644 --- a/doctr/models/detection/differentiable_binarization/tensorflow.py +++ b/doctr/models/detection/differentiable_binarization/tensorflow.py @@ -50,7 +50,6 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject): `_. Args: - ---- channels: number of channel to output """ @@ -72,12 +71,10 @@ def build_upsampling( """Module which performs a 3x3 convolution followed by up-sampling Args: - ---- channels: number of output channels dilation_factor (int): dilation factor to scale the convolution output before concatenation Returns: - ------- a keras.layers.Layer object, wrapping these operations in a sequential module """ @@ -114,7 +111,6 @@ class DBNet(_DBNet, Model, NestedObject): `_. Args: - ---- feature extractor: the backbone serving as feature extractor fpn_channels: number of channels each extracted feature maps is mapped to bin_thresh: threshold for binarization @@ -184,7 +180,6 @@ def compute_loss( and a list of masks for each image. 
From there it computes the loss with the model output Args: - ---- out_map: output feature map of the model of shape (N, H, W, C) thresh_map: threshold map of shape (N, H, W, C) target: list of dictionary where each dict has a `boxes` and a `flags` entry @@ -193,7 +188,6 @@ def compute_loss( eps: epsilon factor in dice loss Returns: - ------- A loss tensor """ if gamma < 0: @@ -379,12 +373,10 @@ def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text detection dataset **kwargs: keyword arguments of the DBNet architecture Returns: - ------- text detection architecture """ return _db_resnet( @@ -407,12 +399,10 @@ def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text detection dataset **kwargs: keyword arguments of the DBNet architecture Returns: - ------- text detection architecture """ return _db_mobilenet( diff --git a/doctr/models/detection/fast/__init__.py b/doctr/models/detection/fast/__init__.py index c7110f5669..5bc3a74b47 100644 --- a/doctr/models/detection/fast/__init__.py +++ b/doctr/models/detection/fast/__init__.py @@ -1,6 +1,6 @@ from doctr.file_utils import is_tf_available, is_torch_available -if is_tf_available(): - from .tensorflow import * -elif is_torch_available(): - from .pytorch import * # type: ignore[assignment] +if is_torch_available(): + from .pytorch import * +elif is_tf_available(): + from .tensorflow import * # type: ignore[assignment] diff --git a/doctr/models/detection/fast/base.py b/doctr/models/detection/fast/base.py index 1b3a02bb29..85970fc5be 100644 --- a/doctr/models/detection/fast/base.py +++ b/doctr/models/detection/fast/base.py @@ -23,7 +23,6 @@ class FASTPostProcessor(DetectionPostProcessor): """Implements a post processor for FAST model. Args: - ---- bin_thresh: threshold used to binzarized p_map at inference time box_thresh: minimal objectness score to consider a box assume_straight_pages: whether the inputs were expected to have horizontal text elements @@ -45,11 +44,9 @@ def polygon_to_box( """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon Args: - ---- points: The first parameter. Returns: - ------- a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle) """ if not self.assume_straight_pages: @@ -94,14 +91,12 @@ def bitmap_to_boxes( """Compute boxes from a bitmap/pred_map: find connected components then filter boxes Args: - ---- pred: Pred map from differentiable linknet output bitmap: Bitmap map computed from pred (binarized) angle_tol: Comparison tolerance of the angle with the median angle across the page ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop Returns: - ------- np tensor boxes for the bitmap, each box is a 6-element list containing x, y, w, h, alpha, score for the box """ @@ -165,13 +160,11 @@ def build_target( """Build the target, and it's mask to be used from loss computation. 
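The `bin_thresh`/`box_thresh` knobs documented for the post-processor above can also be set at construction time; a sketch with hypothetical values, assuming the factory helper forwards them to the model:

>>> from doctr.models import fast_tiny
>>> model = fast_tiny(pretrained=False, bin_thresh=0.3, box_thresh=0.1)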
Args: - ---- target: target coming from dataset output_shape: shape of the output of the model without batch_size channels_last: whether channels are last or not Returns: - ------- the new formatted target, mask and shrunken text kernel """ if any(t.dtype != np.float32 for tgt in target for t in tgt.values()): diff --git a/doctr/models/detection/fast/pytorch.py b/doctr/models/detection/fast/pytorch.py index 5ac44b1825..c7fd98b098 100644 --- a/doctr/models/detection/fast/pytorch.py +++ b/doctr/models/detection/fast/pytorch.py @@ -47,7 +47,6 @@ class FastNeck(nn.Module): """Neck of the FAST architecture, composed of a series of 3x3 convolutions and upsampling layers. Args: - ---- in_channels: number of input channels out_channels: number of output channels """ @@ -77,7 +76,6 @@ class FastHead(nn.Sequential): """Head of the FAST architecture Args: - ---- in_channels: number of input channels num_classes: number of output classes out_channels: number of output channels @@ -104,7 +102,6 @@ class FAST(_FAST, nn.Module): `_. Args: - ---- feat extractor: the backbone serving as feature extractor bin_thresh: threshold for binarization box_thresh: minimal objectness score to consider a box @@ -219,13 +216,11 @@ def compute_loss( """Compute fast loss, 2 x Dice loss where the text kernel loss is scaled by 0.5. Args: - ---- out_map: output feature map of the model of shape (N, num_classes, H, W) target: list of dictionary where each dict has a `boxes` and a `flags` entry eps: epsilon factor in dice loss Returns: - ------- A loss tensor """ targets = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type] @@ -282,12 +277,10 @@ def ohem_sample(score: torch.Tensor, gt: torch.Tensor, mask: torch.Tensor) -> to def reparameterize(model: Union[FAST, nn.Module]) -> FAST: """Fuse batchnorm and conv layers and reparameterize the model - args: - ---- + Args: model: the FAST model to reparameterize Returns: - ------- the reparameterized model """ last_conv = None @@ -366,12 +359,10 @@ def fast_tiny(pretrained: bool = False, **kwargs: Any) -> FAST: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text detection dataset **kwargs: keyword arguments of the DBNet architecture Returns: - ------- text detection architecture """ return _fast( @@ -395,12 +386,10 @@ def fast_small(pretrained: bool = False, **kwargs: Any) -> FAST: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text detection dataset **kwargs: keyword arguments of the DBNet architecture Returns: - ------- text detection architecture """ return _fast( @@ -424,12 +413,10 @@ def fast_base(pretrained: bool = False, **kwargs: Any) -> FAST: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text detection dataset **kwargs: keyword arguments of the DBNet architecture Returns: - ------- text detection architecture """ return _fast( diff --git a/doctr/models/detection/fast/tensorflow.py b/doctr/models/detection/fast/tensorflow.py index b0043494ed..231f0d8dd3 100644 --- a/doctr/models/detection/fast/tensorflow.py +++ b/doctr/models/detection/fast/tensorflow.py @@ -49,7 +49,6 @@ class FastNeck(layers.Layer, NestedObject): """Neck of the FAST architecture, composed of a series of 3x3 convolutions and upsampling layer. 
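The `reparameterize` helper shown above (PyTorch version) folds batchnorm parameters into the preceding convolutions; an inference-time sketch:

>>> from doctr.models import fast_tiny
>>> from doctr.models.detection.fast.pytorch import reparameterize
>>> model = reparameterize(fast_tiny(pretrained=False))  # fused conv+BN, same outputs with fewer ops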
Args: - ---- in_channels: number of input channels out_channels: number of output channels """ @@ -77,7 +76,6 @@ class FastHead(Sequential): """Head of the FAST architecture Args: - ---- in_channels: number of input channels num_classes: number of output classes out_channels: number of output channels @@ -104,7 +102,6 @@ class FAST(_FAST, Model, NestedObject): `_. Args: - ---- feature extractor: the backbone serving as feature extractor bin_thresh: threshold for binarization box_thresh: minimal objectness score to consider a box @@ -165,13 +162,11 @@ def compute_loss( """Compute fast loss, 2 x Dice loss where the text kernel loss is scaled by 0.5. Args: - ---- out_map: output feature map of the model of shape (N, num_classes, H, W) target: list of dictionary where each dict has a `boxes` and a `flags` entry eps: epsilon factor in dice loss Returns: - ------- A loss tensor """ targets = self.build_target(target, out_map.shape[1:], True) @@ -259,11 +254,10 @@ def reparameterize(model: Union[FAST, layers.Layer]) -> FAST: """Fuse batchnorm and conv layers and reparameterize the model args: - ---- + model: the FAST model to reparameterize Returns: - ------- the reparameterized model """ last_conv = None @@ -358,12 +352,10 @@ def fast_tiny(pretrained: bool = False, **kwargs: Any) -> FAST: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text detection dataset **kwargs: keyword arguments of the DBNet architecture Returns: - ------- text detection architecture """ return _fast( @@ -386,12 +378,10 @@ def fast_small(pretrained: bool = False, **kwargs: Any) -> FAST: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text detection dataset **kwargs: keyword arguments of the DBNet architecture Returns: - ------- text detection architecture """ return _fast( @@ -414,12 +404,10 @@ def fast_base(pretrained: bool = False, **kwargs: Any) -> FAST: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text detection dataset **kwargs: keyword arguments of the DBNet architecture Returns: - ------- text detection architecture """ return _fast( diff --git a/doctr/models/detection/linknet/__init__.py b/doctr/models/detection/linknet/__init__.py index c7110f5669..5bc3a74b47 100644 --- a/doctr/models/detection/linknet/__init__.py +++ b/doctr/models/detection/linknet/__init__.py @@ -1,6 +1,6 @@ from doctr.file_utils import is_tf_available, is_torch_available -if is_tf_available(): - from .tensorflow import * -elif is_torch_available(): - from .pytorch import * # type: ignore[assignment] +if is_torch_available(): + from .pytorch import * +elif is_tf_available(): + from .tensorflow import * # type: ignore[assignment] diff --git a/doctr/models/detection/linknet/base.py b/doctr/models/detection/linknet/base.py index 9aeb543fe3..782d688f1b 100644 --- a/doctr/models/detection/linknet/base.py +++ b/doctr/models/detection/linknet/base.py @@ -23,7 +23,6 @@ class LinkNetPostProcessor(DetectionPostProcessor): """Implements a post processor for LinkNet model. Args: - ---- bin_thresh: threshold used to binzarized p_map at inference time box_thresh: minimal objectness score to consider a box assume_straight_pages: whether the inputs were expected to have horizontal text elements @@ -45,11 +44,9 @@ def polygon_to_box( """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon Args: - ---- points: The first parameter. 
Returns: - ------- a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle) """ if not self.assume_straight_pages: @@ -94,14 +91,12 @@ def bitmap_to_boxes( """Compute boxes from a bitmap/pred_map: find connected components then filter boxes Args: - ---- pred: Pred map from differentiable linknet output bitmap: Bitmap map computed from pred (binarized) angle_tol: Comparison tolerance of the angle with the median angle across the page ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop Returns: - ------- np tensor boxes for the bitmap, each box is a 6-element list containing x, y, w, h, alpha, score for the box """ @@ -152,7 +147,6 @@ class _LinkNet(BaseModel): `_. Args: - ---- out_chan: number of channels for the output """ @@ -169,13 +163,11 @@ def build_target( """Build the target, and it's mask to be used from loss computation. Args: - ---- target: target coming from dataset output_shape: shape of the output of the model without batch_size channels_last: whether channels are last or not Returns: - ------- the new formatted target and the mask """ if any(t.dtype != np.float32 for tgt in target for t in tgt.values()): diff --git a/doctr/models/detection/linknet/pytorch.py b/doctr/models/detection/linknet/pytorch.py index 537fd57256..321c894d54 100644 --- a/doctr/models/detection/linknet/pytorch.py +++ b/doctr/models/detection/linknet/pytorch.py @@ -89,7 +89,6 @@ class LinkNet(nn.Module, _LinkNet): `_. Args: - ---- feature extractor: the backbone serving as feature extractor bin_thresh: threshold for binarization of the output feature map box_thresh: minimal objectness score to consider a box @@ -207,7 +206,6 @@ def compute_loss( `_. Args: - ---- out_map: output feature map of the model of shape (N, num_classes, H, W) target: list of dictionary where each dict has a `boxes` and a `flags` entry gamma: modulating factor in the focal loss formula @@ -215,7 +213,6 @@ def compute_loss( eps: epsilon factor in dice loss Returns: - ------- A loss tensor """ _target, _mask = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type] @@ -295,12 +292,10 @@ def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text detection dataset **kwargs: keyword arguments of the LinkNet architecture Returns: - ------- text detection architecture """ return _linknet( @@ -327,12 +322,10 @@ def linknet_resnet34(pretrained: bool = False, **kwargs: Any) -> LinkNet: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text detection dataset **kwargs: keyword arguments of the LinkNet architecture Returns: - ------- text detection architecture """ return _linknet( @@ -359,12 +352,10 @@ def linknet_resnet50(pretrained: bool = False, **kwargs: Any) -> LinkNet: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text detection dataset **kwargs: keyword arguments of the LinkNet architecture Returns: - ------- text detection architecture """ return _linknet( diff --git a/doctr/models/detection/linknet/tensorflow.py b/doctr/models/detection/linknet/tensorflow.py index 9c991c6f4c..502531b430 100644 --- a/doctr/models/detection/linknet/tensorflow.py +++ b/doctr/models/detection/linknet/tensorflow.py @@ -100,7 +100,6 @@ class LinkNet(_LinkNet, Model): `_. 
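The LinkNet loss documented in the hunks above combines a focal term (parameters `gamma` and `alpha`) with the dice term. For readers who want the formula those two parameters refer to, here is a textbook binary focal loss; names and defaults are illustrative, not the library's exact code.

    import torch
    import torch.nn.functional as F

    def focal_loss(logits: torch.Tensor, target: torch.Tensor,
                   gamma: float = 2.0, alpha: float = 0.5) -> torch.Tensor:
        # target is expected as a float tensor of 0/1 values
        bce = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
        p_t = torch.exp(-bce)  # probability assigned to the true class
        alpha_t = alpha * target + (1 - alpha) * (1 - target)
        # (1 - p_t) ** gamma down-weights already well-classified locations
        return (alpha_t * (1 - p_t) ** gamma * bce).mean()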
Args: - ---- feature extractor: the backbone serving as feature extractor fpn_channels: number of channels each extracted feature maps is mapped to bin_thresh: threshold for binarization of the output feature map @@ -176,7 +175,6 @@ def compute_loss( `_. Args: - ---- out_map: output feature map of the model of shape N x H x W x 1 target: list of dictionary where each dict has a `boxes` and a `flags` entry gamma: modulating factor in the focal loss formula @@ -184,7 +182,6 @@ def compute_loss( eps: epsilon factor in dice loss Returns: - ------- A loss tensor """ seg_target, seg_mask = self.build_target(target, out_map.shape[1:], True) @@ -305,12 +302,10 @@ def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text detection dataset **kwargs: keyword arguments of the LinkNet architecture Returns: - ------- text detection architecture """ return _linknet( @@ -333,12 +328,10 @@ def linknet_resnet34(pretrained: bool = False, **kwargs: Any) -> LinkNet: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text detection dataset **kwargs: keyword arguments of the LinkNet architecture Returns: - ------- text detection architecture """ return _linknet( @@ -361,12 +354,10 @@ def linknet_resnet50(pretrained: bool = False, **kwargs: Any) -> LinkNet: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text detection dataset **kwargs: keyword arguments of the LinkNet architecture Returns: - ------- text detection architecture """ return _linknet( diff --git a/doctr/models/detection/predictor/__init__.py b/doctr/models/detection/predictor/__init__.py index ff30c3b2e7..5bc3a74b47 100644 --- a/doctr/models/detection/predictor/__init__.py +++ b/doctr/models/detection/predictor/__init__.py @@ -1,6 +1,6 @@ -from doctr.file_utils import is_tf_available +from doctr.file_utils import is_tf_available, is_torch_available -if is_tf_available(): - from .tensorflow import * -else: - from .pytorch import * # type: ignore[assignment] +if is_torch_available(): + from .pytorch import * +elif is_tf_available(): + from .tensorflow import * # type: ignore[assignment] diff --git a/doctr/models/detection/predictor/pytorch.py b/doctr/models/detection/predictor/pytorch.py index 8bac391618..257164d4b6 100644 --- a/doctr/models/detection/predictor/pytorch.py +++ b/doctr/models/detection/predictor/pytorch.py @@ -20,7 +20,6 @@ class DetectionPredictor(nn.Module): """Implements an object able to localize text elements in a document Args: - ---- pre_processor: transform inputs for easier batched model inference model: core detection architecture """ diff --git a/doctr/models/detection/predictor/tensorflow.py b/doctr/models/detection/predictor/tensorflow.py index a3d5085847..5263a560d3 100644 --- a/doctr/models/detection/predictor/tensorflow.py +++ b/doctr/models/detection/predictor/tensorflow.py @@ -20,7 +20,6 @@ class DetectionPredictor(NestedObject): """Implements an object able to localize text elements in a document Args: - ---- pre_processor: transform inputs for easier batched model inference model: core detection architecture """ diff --git a/doctr/models/detection/zoo.py b/doctr/models/detection/zoo.py index b8dfa7636d..8c54503a41 100644 --- a/doctr/models/detection/zoo.py +++ b/doctr/models/detection/zoo.py @@ -93,7 +93,6 @@ def detection_predictor( >>> out = model([input_page]) Args: - ---- arch: 
name of the architecture or model itself to use (e.g. 'db_resnet50') pretrained: If True, returns a model pre-trained on our text detection dataset assume_straight_pages: If True, fit straight boxes to the page @@ -104,7 +103,6 @@ def detection_predictor( **kwargs: optional keyword arguments passed to the architecture Returns: - ------- Detection predictor """ return _predictor( diff --git a/doctr/models/factory/hub.py b/doctr/models/factory/hub.py index dd9fc5d776..9ee2427105 100644 --- a/doctr/models/factory/hub.py +++ b/doctr/models/factory/hub.py @@ -61,7 +61,6 @@ def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task """Save model and config to disk for pushing to huggingface hub Args: - ---- model: TF or PyTorch model to be saved save_dir: directory to save model and config arch: architecture name @@ -97,7 +96,6 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: # >>> push_to_hf_hub(model, 'my-model', 'recognition', arch='crnn_mobilenet_v3_small') Args: - ---- model: TF or PyTorch model to be saved model_name: name of the model which is also the repository name task: task name @@ -114,9 +112,9 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: # # default readme readme = textwrap.dedent( f""" - --- + language: en - --- +

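All of the `__init__.py` hunks in this patch flip the resolution order from TF-first to torch-first using the same two helpers from `doctr.file_utils`. Downstream code that supports both backends can mirror the new preference in the same way; the sketch below is a hypothetical helper, and the dummy shapes are illustrative only.

    from doctr.file_utils import is_tf_available, is_torch_available

    def make_dummy_batch():
        # Same resolution order as the patched __init__ files:
        # torch first, TensorFlow only as a fallback
        if is_torch_available():
            import torch
            return torch.rand((1, 3, 1024, 1024))  # channels-first
        if is_tf_available():
            import tensorflow as tf
            return tf.random.uniform((1, 1024, 1024, 3))  # channels-last
        raise ModuleNotFoundError("install torch (preferred) or tensorflow")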
@@ -190,12 +188,10 @@ def from_hub(repo_id: str, **kwargs: Any): >>> model = from_hub("mindee/fasterrcnn_mobilenet_v3_large_fpn") Args: - ---- repo_id: HuggingFace model hub repo kwargs: kwargs of `hf_hub_download` or `snapshot_download` Returns: - ------- Model loaded with the checkpoint """ # Get the config diff --git a/doctr/models/kie_predictor/__init__.py b/doctr/models/kie_predictor/__init__.py index ff30c3b2e7..5bc3a74b47 100644 --- a/doctr/models/kie_predictor/__init__.py +++ b/doctr/models/kie_predictor/__init__.py @@ -1,6 +1,6 @@ -from doctr.file_utils import is_tf_available +from doctr.file_utils import is_tf_available, is_torch_available -if is_tf_available(): - from .tensorflow import * -else: - from .pytorch import * # type: ignore[assignment] +if is_torch_available(): + from .pytorch import * +elif is_tf_available(): + from .tensorflow import * # type: ignore[assignment] diff --git a/doctr/models/kie_predictor/base.py b/doctr/models/kie_predictor/base.py index 0b6cd28dc7..c8ade54579 100644 --- a/doctr/models/kie_predictor/base.py +++ b/doctr/models/kie_predictor/base.py @@ -17,7 +17,6 @@ class _KIEPredictor(_OCRPredictor): """Implements an object able to localize and identify text elements in a set of documents Args: - ---- assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages without rotated textual elements. straighten_pages: if True, estimates the page general orientation based on the median line orientation. diff --git a/doctr/models/kie_predictor/pytorch.py b/doctr/models/kie_predictor/pytorch.py index c7ffa140c5..61ab910241 100644 --- a/doctr/models/kie_predictor/pytorch.py +++ b/doctr/models/kie_predictor/pytorch.py @@ -24,7 +24,6 @@ class KIEPredictor(nn.Module, _KIEPredictor): """Implements an object able to localize and identify text elements in a set of documents Args: - ---- det_predictor: detection module reco_predictor: recognition module assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages @@ -52,8 +51,8 @@ def __init__( **kwargs: Any, ) -> None: nn.Module.__init__(self) - self.det_predictor = det_predictor.eval() # type: ignore[attr-defined] - self.reco_predictor = reco_predictor.eval() # type: ignore[attr-defined] + self.det_predictor = det_predictor.eval() + self.reco_predictor = reco_predictor.eval() _KIEPredictor.__init__( self, assume_straight_pages, diff --git a/doctr/models/kie_predictor/tensorflow.py b/doctr/models/kie_predictor/tensorflow.py index b73f651fc5..3f0d58bbfc 100644 --- a/doctr/models/kie_predictor/tensorflow.py +++ b/doctr/models/kie_predictor/tensorflow.py @@ -24,7 +24,6 @@ class KIEPredictor(NestedObject, _KIEPredictor): """Implements an object able to localize and identify text elements in a set of documents Args: - ---- det_predictor: detection module reco_predictor: recognition module assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages @@ -103,7 +102,7 @@ def __call__( origin_page_shapes = [page.shape[:2] for page in pages] # Forward again to get predictions on straight pages - loc_preds = self.det_predictor(pages, **kwargs) # type: ignore[assignment] + loc_preds = self.det_predictor(pages, **kwargs) dict_loc_preds: Dict[str, List[np.ndarray]] = invert_data_structure(loc_preds) # type: ignore diff --git a/doctr/models/modules/layers/__init__.py b/doctr/models/modules/layers/__init__.py index c7110f5669..38ab32543e 100644 --- a/doctr/models/modules/layers/__init__.py +++ 
b/doctr/models/modules/layers/__init__.py @@ -1,6 +1,6 @@ from doctr.file_utils import is_tf_available, is_torch_available -if is_tf_available(): +if is_torch_available(): + from .pytorch import * +elif is_tf_available(): from .tensorflow import * -elif is_torch_available(): - from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/modules/transformer/__init__.py b/doctr/models/modules/transformer/__init__.py index c7110f5669..38ab32543e 100644 --- a/doctr/models/modules/transformer/__init__.py +++ b/doctr/models/modules/transformer/__init__.py @@ -1,6 +1,6 @@ from doctr.file_utils import is_tf_available, is_torch_available -if is_tf_available(): +if is_torch_available(): + from .pytorch import * +elif is_tf_available(): from .tensorflow import * -elif is_torch_available(): - from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/modules/transformer/pytorch.py b/doctr/models/modules/transformer/pytorch.py index 6f1c612978..c7c1f113a3 100644 --- a/doctr/models/modules/transformer/pytorch.py +++ b/doctr/models/modules/transformer/pytorch.py @@ -33,11 +33,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass Args: - ---- x: embeddings (batch, max_len, d_model) - Returns - ------- + Returns: positional embeddings (batch, max_len, d_model) """ x = x + self.pe[:, : x.size(1)] diff --git a/doctr/models/modules/transformer/tensorflow.py b/doctr/models/modules/transformer/tensorflow.py index 50c7cef04d..79584356c1 100644 --- a/doctr/models/modules/transformer/tensorflow.py +++ b/doctr/models/modules/transformer/tensorflow.py @@ -43,12 +43,10 @@ def call( """Forward pass Args: - ---- x: embeddings (batch, max_len, d_model) **kwargs: additional arguments - Returns - ------- + Returns: positional embeddings (batch, max_len, d_model) """ if x.dtype == tf.float16: # amp fix: cast to half diff --git a/doctr/models/modules/vision_transformer/__init__.py b/doctr/models/modules/vision_transformer/__init__.py index c7110f5669..38ab32543e 100644 --- a/doctr/models/modules/vision_transformer/__init__.py +++ b/doctr/models/modules/vision_transformer/__init__.py @@ -1,6 +1,6 @@ from doctr.file_utils import is_tf_available, is_torch_available -if is_tf_available(): +if is_torch_available(): + from .pytorch import * +elif is_tf_available(): from .tensorflow import * -elif is_torch_available(): - from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/predictor/__init__.py b/doctr/models/predictor/__init__.py index ff30c3b2e7..5bc3a74b47 100644 --- a/doctr/models/predictor/__init__.py +++ b/doctr/models/predictor/__init__.py @@ -1,6 +1,6 @@ -from doctr.file_utils import is_tf_available +from doctr.file_utils import is_tf_available, is_torch_available -if is_tf_available(): - from .tensorflow import * -else: - from .pytorch import * # type: ignore[assignment] +if is_torch_available(): + from .pytorch import * +elif is_tf_available(): + from .tensorflow import * # type: ignore[assignment] diff --git a/doctr/models/predictor/base.py b/doctr/models/predictor/base.py index 530590bc61..3737ad3be0 100644 --- a/doctr/models/predictor/base.py +++ b/doctr/models/predictor/base.py @@ -21,7 +21,6 @@ class _OCRPredictor: """Implements an object able to localize and identify text elements in a set of documents Args: - ---- assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages without rotated textual elements. 
straighten_pages: if True, estimates the page general orientation based on the median line orientation. @@ -194,7 +193,6 @@ def add_hook(self, hook: Callable) -> None: """Add a hook to the predictor Args: - ---- hook: a callable that takes as input the `loc_preds` and returns the modified `loc_preds` """ self.hooks.append(hook) diff --git a/doctr/models/predictor/pytorch.py b/doctr/models/predictor/pytorch.py index b47a71449d..f9a3d47097 100644 --- a/doctr/models/predictor/pytorch.py +++ b/doctr/models/predictor/pytorch.py @@ -24,7 +24,6 @@ class OCRPredictor(nn.Module, _OCRPredictor): """Implements an object able to localize and identify text elements in a set of documents Args: - ---- det_predictor: detection module reco_predictor: recognition module assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages @@ -52,8 +51,8 @@ def __init__( **kwargs: Any, ) -> None: nn.Module.__init__(self) - self.det_predictor = det_predictor.eval() # type: ignore[attr-defined] - self.reco_predictor = reco_predictor.eval() # type: ignore[attr-defined] + self.det_predictor = det_predictor.eval() + self.reco_predictor = reco_predictor.eval() _OCRPredictor.__init__( self, assume_straight_pages, diff --git a/doctr/models/predictor/tensorflow.py b/doctr/models/predictor/tensorflow.py index 1392943bc4..07f12210ce 100644 --- a/doctr/models/predictor/tensorflow.py +++ b/doctr/models/predictor/tensorflow.py @@ -24,7 +24,6 @@ class OCRPredictor(NestedObject, _OCRPredictor): """Implements an object able to localize and identify text elements in a set of documents Args: - ---- det_predictor: detection module reco_predictor: recognition module assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages @@ -101,12 +100,12 @@ def __call__( origin_page_shapes = [page.shape[:2] for page in pages] # forward again to get predictions on straight pages - loc_preds_dict = self.det_predictor(pages, **kwargs) # type: ignore[assignment] + loc_preds_dict = self.det_predictor(pages, **kwargs) assert all(len(loc_pred) == 1 for loc_pred in loc_preds_dict), ( "Detection Model in ocr_predictor should output only one class" ) - loc_preds: List[np.ndarray] = [list(loc_pred.values())[0] for loc_pred in loc_preds_dict] # type: ignore[union-attr] + loc_preds: List[np.ndarray] = [list(loc_pred.values())[0] for loc_pred in loc_preds_dict] # Detach objectness scores from loc_preds loc_preds, objectness_scores = detach_scores(loc_preds) diff --git a/doctr/models/preprocessor/__init__.py b/doctr/models/preprocessor/__init__.py index c7110f5669..5bc3a74b47 100644 --- a/doctr/models/preprocessor/__init__.py +++ b/doctr/models/preprocessor/__init__.py @@ -1,6 +1,6 @@ from doctr.file_utils import is_tf_available, is_torch_available -if is_tf_available(): - from .tensorflow import * -elif is_torch_available(): - from .pytorch import * # type: ignore[assignment] +if is_torch_available(): + from .pytorch import * +elif is_tf_available(): + from .tensorflow import * # type: ignore[assignment] diff --git a/doctr/models/preprocessor/pytorch.py b/doctr/models/preprocessor/pytorch.py index e074fde829..b8e6ee55b5 100644 --- a/doctr/models/preprocessor/pytorch.py +++ b/doctr/models/preprocessor/pytorch.py @@ -22,7 +22,6 @@ class PreProcessor(nn.Module): """Implements an abstract preprocessor object which performs casting, resizing, batching and normalization. 
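The `add_hook` method documented above accepts any callable that maps `loc_preds` to modified `loc_preds`. A minimal sketch, assuming the predictions arrive as one relative-coordinate numpy array per page; `clamp_hook` is a hypothetical name, not a library helper.

    import numpy as np
    from doctr.models import ocr_predictor

    def clamp_hook(loc_preds):
        # A hook must return the same structure it receives; coordinates
        # are relative, so keep them inside [0, 1]
        return [page.clip(0, 1) for page in loc_preds]

    predictor = ocr_predictor(pretrained=True)
    predictor.add_hook(clamp_hook)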
Args: - ---- output_size: expected size of each page in format (H, W) batch_size: the size of page batches mean: mean value of the training distribution by channel @@ -48,11 +47,9 @@ def batch_inputs(self, samples: List[torch.Tensor]) -> List[torch.Tensor]: """Gather samples into batches for inference purposes Args: - ---- samples: list of samples of shape (C, H, W) Returns: - ------- list of batched samples (*, C, H, W) """ num_batches = int(math.ceil(len(samples) / self.batch_size)) @@ -86,11 +83,9 @@ def __call__(self, x: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, n """Prepare document data for model forwarding Args: - ---- x: list of images (np.array) or tensors (already resized and batched) Returns: - ------- list of page batches """ # Input type check diff --git a/doctr/models/preprocessor/tensorflow.py b/doctr/models/preprocessor/tensorflow.py index 15f8be5ac3..31ff667286 100644 --- a/doctr/models/preprocessor/tensorflow.py +++ b/doctr/models/preprocessor/tensorflow.py @@ -20,7 +20,6 @@ class PreProcessor(NestedObject): """Implements an abstract preprocessor object which performs casting, resizing, batching and normalization. Args: - ---- output_size: expected size of each page in format (H, W) batch_size: the size of page batches mean: mean value of the training distribution by channel @@ -48,11 +47,9 @@ def batch_inputs(self, samples: List[tf.Tensor]) -> List[tf.Tensor]: """Gather samples into batches for inference purposes Args: - ---- samples: list of samples (tf.Tensor) Returns: - ------- list of batched samples """ num_batches = int(math.ceil(len(samples) / self.batch_size)) @@ -84,11 +81,9 @@ def __call__(self, x: Union[tf.Tensor, np.ndarray, List[Union[tf.Tensor, np.ndar """Prepare document data for model forwarding Args: - ---- x: list of images (np.array) or tensors (already resized and batched) Returns: - ------- list of page batches """ # Input type check diff --git a/doctr/models/recognition/core.py b/doctr/models/recognition/core.py index ab82218cce..9f46bf3f23 100644 --- a/doctr/models/recognition/core.py +++ b/doctr/models/recognition/core.py @@ -27,11 +27,9 @@ def build_target( sequence lengths. 
Args: - ---- gts: list of ground-truth labels Returns: - ------- A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch) """ encoded = encode_sequences(sequences=gts, vocab=self.vocab, target_size=self.max_length, eos=len(self.vocab)) @@ -43,7 +41,6 @@ class RecognitionPostProcessor(NestedObject): """Abstract class to postprocess the raw output of the model Args: - ---- vocab: string containing the ordered sequence of supported characters """ diff --git a/doctr/models/recognition/crnn/__init__.py b/doctr/models/recognition/crnn/__init__.py index c7110f5669..5bc3a74b47 100644 --- a/doctr/models/recognition/crnn/__init__.py +++ b/doctr/models/recognition/crnn/__init__.py @@ -1,6 +1,6 @@ from doctr.file_utils import is_tf_available, is_torch_available -if is_tf_available(): - from .tensorflow import * -elif is_torch_available(): - from .pytorch import * # type: ignore[assignment] +if is_torch_available(): + from .pytorch import * +elif is_tf_available(): + from .tensorflow import * # type: ignore[assignment] diff --git a/doctr/models/recognition/crnn/pytorch.py b/doctr/models/recognition/crnn/pytorch.py index 4c4b891f9a..18617fb36e 100644 --- a/doctr/models/recognition/crnn/pytorch.py +++ b/doctr/models/recognition/crnn/pytorch.py @@ -48,7 +48,6 @@ class CTCPostProcessor(RecognitionPostProcessor): """Postprocess raw prediction of the model (logits) to a list of words using CTC decoding Args: - ---- vocab: string containing the ordered sequence of supported characters """ @@ -62,13 +61,11 @@ def ctc_best_path( `_. Args: - ---- logits: model output, shape: N x T x C vocab: vocabulary to use blank: index of blank label Returns: - ------- A list of tuples: (word, confidence) """ # Gather the most confident characters, and assign the smallest conf among those to the sequence prob @@ -87,11 +84,9 @@ def __call__(self, logits: torch.Tensor) -> List[Tuple[str, float]]: with label_to_idx mapping dictionnary Args: - ---- logits: raw output of the model, shape (N, C + 1, seq_len) Returns: - ------- A tuple of 2 lists: a list of str (words) and a list of float (probs) """ @@ -104,7 +99,6 @@ class CRNN(RecognitionModel, nn.Module): Sequence Recognition and Its Application to Scene Text Recognition" `_. Args: - ---- feature_extractor: the backbone serving as feature extractor vocab: vocabulary used for encoding rnn_units: number of units in the LSTM layers @@ -168,12 +162,10 @@ def compute_loss( """Compute CTC loss for the model. 
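The `ctc_best_path` docstring above compresses the whole greedy scheme into one sentence: take the argmax at every timestep, collapse consecutive repeats, drop blanks, and keep the weakest per-step probability as the sequence confidence. A self-contained sketch of that procedure, not the library function itself:

    import torch

    def greedy_ctc_decode(logits: torch.Tensor, vocab: str, blank: int):
        # logits: (N, T, C); pick the most confident class at every timestep
        probs, best = torch.softmax(logits, dim=-1).max(dim=-1)
        results = []
        for seq, conf in zip(best, probs):
            chars, prev = [], blank
            for idx in seq.tolist():
                if idx != blank and idx != prev:  # collapse repeats, drop blanks
                    chars.append(vocab[idx])
                prev = idx
            # smallest per-step confidence stands in for the sequence prob
            results.append(("".join(chars), conf.min().item()))
        return results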
Args: - ---- model_output: predicted logits of the model target: list of target strings Returns: - ------- The loss of the model on the batch """ gt, seq_len = self.build_target(target) @@ -272,12 +264,10 @@ def crnn_vgg16_bn(pretrained: bool = False, **kwargs: Any) -> CRNN: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text recognition dataset **kwargs: keyword arguments of the CRNN architecture Returns: - ------- text recognition architecture """ return _crnn("crnn_vgg16_bn", pretrained, vgg16_bn_r, ignore_keys=["linear.weight", "linear.bias"], **kwargs) @@ -294,12 +284,10 @@ def crnn_mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> CRNN: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text recognition dataset **kwargs: keyword arguments of the CRNN architecture Returns: - ------- text recognition architecture """ return _crnn( @@ -322,12 +310,10 @@ def crnn_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> CRNN: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text recognition dataset **kwargs: keyword arguments of the CRNN architecture Returns: - ------- text recognition architecture """ return _crnn( diff --git a/doctr/models/recognition/crnn/tensorflow.py b/doctr/models/recognition/crnn/tensorflow.py index 9f74882673..4bf6915b69 100644 --- a/doctr/models/recognition/crnn/tensorflow.py +++ b/doctr/models/recognition/crnn/tensorflow.py @@ -47,7 +47,6 @@ class CTCPostProcessor(RecognitionPostProcessor): """Postprocess raw prediction of the model (logits) to a list of words using CTC decoding Args: - ---- vocab: string containing the ordered sequence of supported characters ignore_case: if True, ignore case of letters ignore_accents: if True, ignore accents of letters @@ -63,13 +62,11 @@ def __call__( with label_to_idx mapping dictionnary Args: - ---- logits: raw output of the model, shape BATCH_SIZE X SEQ_LEN X NUM_CLASSES + 1 beam_width: An int scalar >= 0 (beam search beam width). top_paths: An int scalar >= 0, <= beam_width (controls output size). Returns: - ------- A list of decoded words of length BATCH_SIZE @@ -114,7 +111,6 @@ class CRNN(RecognitionModel, Model): Sequence Recognition and Its Application to Scene Text Recognition" `_. Args: - ---- feature_extractor: the backbone serving as feature extractor vocab: vocabulary used for encoding rnn_units: number of units in the LSTM layers @@ -166,12 +162,10 @@ def compute_loss( """Compute CTC loss for the model. 
Args: - ---- model_output: predicted logits of the model target: lengths of each gt word inside the batch Returns: - ------- The loss of the model on the batch """ gt, seq_len = self.build_target(target) @@ -265,12 +259,10 @@ def crnn_vgg16_bn(pretrained: bool = False, **kwargs: Any) -> CRNN: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text recognition dataset **kwargs: keyword arguments of the CRNN architecture Returns: - ------- text recognition architecture """ return _crnn("crnn_vgg16_bn", pretrained, vgg16_bn_r, **kwargs) @@ -287,12 +279,10 @@ def crnn_mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> CRNN: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text recognition dataset **kwargs: keyword arguments of the CRNN architecture Returns: - ------- text recognition architecture """ return _crnn("crnn_mobilenet_v3_small", pretrained, mobilenet_v3_small_r, **kwargs) @@ -309,12 +299,10 @@ def crnn_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> CRNN: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text recognition dataset **kwargs: keyword arguments of the CRNN architecture Returns: - ------- text recognition architecture """ return _crnn("crnn_mobilenet_v3_large", pretrained, mobilenet_v3_large_r, **kwargs) diff --git a/doctr/models/recognition/master/__init__.py b/doctr/models/recognition/master/__init__.py index c7110f5669..38ab32543e 100644 --- a/doctr/models/recognition/master/__init__.py +++ b/doctr/models/recognition/master/__init__.py @@ -1,6 +1,6 @@ from doctr.file_utils import is_tf_available, is_torch_available -if is_tf_available(): +if is_torch_available(): + from .pytorch import * +elif is_tf_available(): from .tensorflow import * -elif is_torch_available(): - from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/recognition/master/base.py b/doctr/models/recognition/master/base.py index 4d3002893e..706d91fbfe 100644 --- a/doctr/models/recognition/master/base.py +++ b/doctr/models/recognition/master/base.py @@ -23,11 +23,9 @@ def build_target( sequence lengths. Args: - ---- gts: list of ground-truth labels Returns: - ------- A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch) """ encoded = encode_sequences( @@ -46,7 +44,6 @@ class _MASTERPostProcessor(RecognitionPostProcessor): """Abstract class to postprocess the raw output of the model Args: - ---- vocab: string containing the ordered sequence of supported characters """ diff --git a/doctr/models/recognition/master/pytorch.py b/doctr/models/recognition/master/pytorch.py index 875fcbd687..d44139ab32 100644 --- a/doctr/models/recognition/master/pytorch.py +++ b/doctr/models/recognition/master/pytorch.py @@ -37,7 +37,6 @@ class MASTER(_MASTER, nn.Module): Implementation based on the official Pytorch implementation: `_. Args: - ---- feature_extractor: the backbone serving as feature extractor vocab: vocabulary, (without EOS, SOS, PAD) d_model: d parameter for the transformer decoder @@ -130,13 +129,11 @@ def compute_loss( Sequences are masked after the EOS character. 
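That masking step is the crux of these seq2seq losses: cross-entropy is computed per timestep, then every step beyond each word's EOS is zeroed out before averaging. A hedged re-derivation with illustrative names; `masked_ce_loss` is not the library's function, and the exact off-by-one handling may differ per model.

    import torch
    import torch.nn.functional as F

    def masked_ce_loss(logits: torch.Tensor, gt: torch.Tensor,
                       seq_len: torch.Tensor) -> torch.Tensor:
        # logits: (N, T, C), gt: (N, T), seq_len: (N,) word lengths
        ce = F.cross_entropy(logits.transpose(1, 2), gt, reduction="none")  # (N, T)
        steps = torch.arange(gt.shape[1], device=gt.device)[None, :]
        # keep every step up to and including EOS, mask everything after it
        mask = steps < (seq_len[:, None] + 1)
        return (ce * mask).sum() / mask.sum()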
Args: - ---- gt: the encoded tensor with gt labels model_output: predicted logits of the model seq_len: lengths of each gt word inside the batch Returns: - ------- The loss of the model on the batch """ # Input length : number of timesteps @@ -163,14 +160,12 @@ def forward( """Call function for training Args: - ---- x: images target: list of str labels return_model_output: if True, return logits return_preds: if True, decode logits Returns: - ------- A dictionnary containing eventually loss, logits and predictions. """ # Encode @@ -221,11 +216,9 @@ def decode(self, encoded: torch.Tensor) -> torch.Tensor: """Decode function for prediction Args: - ---- encoded: input tensor Returns: - ------- A Tuple of torch.Tensor: predictions, logits """ b = encoded.size(0) @@ -316,12 +309,10 @@ def master(pretrained: bool = False, **kwargs: Any) -> MASTER: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text recognition dataset **kwargs: keywoard arguments passed to the MASTER architecture Returns: - ------- text recognition architecture """ return _master( diff --git a/doctr/models/recognition/master/tensorflow.py b/doctr/models/recognition/master/tensorflow.py index e01c089012..62bc3eea98 100644 --- a/doctr/models/recognition/master/tensorflow.py +++ b/doctr/models/recognition/master/tensorflow.py @@ -35,7 +35,6 @@ class MASTER(_MASTER, Model): Implementation based on the official TF implementation: `_. Args: - ---- feature_extractor: the backbone serving as feature extractor vocab: vocabulary, (without EOS, SOS, PAD) d_model: d parameter for the transformer decoder @@ -115,13 +114,11 @@ def compute_loss( Sequences are masked after the EOS character. Args: - ---- gt: the encoded tensor with gt labels model_output: predicted logits of the model seq_len: lengths of each gt word inside the batch Returns: - ------- The loss of the model on the batch """ # Input length : number of timesteps @@ -152,7 +149,6 @@ def call( """Call function for training Args: - ---- x: images target: list of str labels return_model_output: if True, return logits @@ -160,7 +156,6 @@ def call( **kwargs: keyword arguments passed to the decoder Returns: - ------- A dictionnary containing eventually loss, logits and predictions. 
""" # Encode @@ -209,12 +204,10 @@ def decode(self, encoded: tf.Tensor, **kwargs: Any) -> tf.Tensor: """Decode function for prediction Args: - ---- encoded: encoded features **kwargs: keyword arguments passed to the decoder Returns: - ------- A Tuple of tf.Tensor: predictions, logits """ b = encoded.shape[0] @@ -247,7 +240,6 @@ class MASTERPostProcessor(_MASTERPostProcessor): """Post processor for MASTER architectures Args: - ---- vocab: string containing the ordered sequence of supported characters """ @@ -312,12 +304,10 @@ def master(pretrained: bool = False, **kwargs: Any) -> MASTER: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text recognition dataset **kwargs: keywoard arguments passed to the MASTER architecture Returns: - ------- text recognition architecture """ return _master("master", pretrained, magc_resnet31, **kwargs) diff --git a/doctr/models/recognition/parseq/__init__.py b/doctr/models/recognition/parseq/__init__.py index c7110f5669..38ab32543e 100644 --- a/doctr/models/recognition/parseq/__init__.py +++ b/doctr/models/recognition/parseq/__init__.py @@ -1,6 +1,6 @@ from doctr.file_utils import is_tf_available, is_torch_available -if is_tf_available(): +if is_torch_available(): + from .pytorch import * +elif is_tf_available(): from .tensorflow import * -elif is_torch_available(): - from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/recognition/parseq/base.py b/doctr/models/recognition/parseq/base.py index 60aa1fcfcf..4649bbaf9c 100644 --- a/doctr/models/recognition/parseq/base.py +++ b/doctr/models/recognition/parseq/base.py @@ -23,11 +23,9 @@ def build_target( sequence lengths. Args: - ---- gts: list of ground-truth labels Returns: - ------- A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch) """ encoded = encode_sequences( @@ -46,7 +44,6 @@ class _PARSeqPostProcessor(RecognitionPostProcessor): """Abstract class to postprocess the raw output of the model Args: - ---- vocab: string containing the ordered sequence of supported characters """ diff --git a/doctr/models/recognition/parseq/pytorch.py b/doctr/models/recognition/parseq/pytorch.py index 8fff062da9..8ff24f67f2 100644 --- a/doctr/models/recognition/parseq/pytorch.py +++ b/doctr/models/recognition/parseq/pytorch.py @@ -38,7 +38,6 @@ class CharEmbedding(nn.Module): """Implements the character embedding module Args: - ---- vocab_size: size of the vocabulary d_model: dimension of the model """ @@ -56,7 +55,6 @@ class PARSeqDecoder(nn.Module): """Implements decoder module of the PARSeq model Args: - ---- d_model: dimension of the model num_heads: number of attention heads ffd: dimension of the feed forward layer @@ -112,7 +110,6 @@ class PARSeq(_PARSeq, nn.Module): Slightly modified implementation based on the official Pytorch implementation: PARSeq: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text recognition dataset **kwargs: keyword arguments of the PARSeq architecture Returns: - ------- text recognition architecture """ return _parseq( diff --git a/doctr/models/recognition/parseq/tensorflow.py b/doctr/models/recognition/parseq/tensorflow.py index 0a72d33fc8..2ed07249c9 100644 --- a/doctr/models/recognition/parseq/tensorflow.py +++ b/doctr/models/recognition/parseq/tensorflow.py @@ -36,7 +36,7 @@ class CharEmbedding(layers.Layer): """Implements the character embedding module Args: - ---- + - vocab_size: size of the vocabulary d_model: 
dimension of the model """ @@ -54,7 +54,6 @@ class PARSeqDecoder(layers.Layer): """Implements decoder module of the PARSeq model Args: - ---- d_model: dimension of the model num_heads: number of attention heads ffd: dimension of the feed forward layer @@ -115,7 +114,6 @@ class PARSeq(_PARSeq, Model): Modified implementation based on the official Pytorch implementation: PARSeq: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text recognition dataset **kwargs: keyword arguments of the PARSeq architecture Returns: - ------- text recognition architecture """ return _parseq( diff --git a/doctr/models/recognition/predictor/__init__.py b/doctr/models/recognition/predictor/__init__.py index ff30c3b2e7..5bc3a74b47 100644 --- a/doctr/models/recognition/predictor/__init__.py +++ b/doctr/models/recognition/predictor/__init__.py @@ -1,6 +1,6 @@ -from doctr.file_utils import is_tf_available +from doctr.file_utils import is_tf_available, is_torch_available -if is_tf_available(): - from .tensorflow import * -else: - from .pytorch import * # type: ignore[assignment] +if is_torch_available(): + from .pytorch import * +elif is_tf_available(): + from .tensorflow import * # type: ignore[assignment] diff --git a/doctr/models/recognition/predictor/_utils.py b/doctr/models/recognition/predictor/_utils.py index ac98d41862..6618cec677 100644 --- a/doctr/models/recognition/predictor/_utils.py +++ b/doctr/models/recognition/predictor/_utils.py @@ -22,7 +22,6 @@ def split_crops( """Chunk crops horizontally to match a given aspect ratio Args: - ---- crops: list of numpy array of shape (H, W, 3) if channels_last or (3, H, W) otherwise max_ratio: the maximum aspect ratio that won't trigger the chunk target_ratio: when crops are chunked, they will be chunked to match this aspect ratio @@ -30,7 +29,6 @@ def split_crops( channels_last: whether the numpy array has dimensions in channels last order Returns: - ------- a tuple with the new crops, their mapping, and a boolean specifying whether any remap is required """ _remap_required = False diff --git a/doctr/models/recognition/predictor/pytorch.py b/doctr/models/recognition/predictor/pytorch.py index b71202f7c2..dc1f644750 100644 --- a/doctr/models/recognition/predictor/pytorch.py +++ b/doctr/models/recognition/predictor/pytorch.py @@ -21,7 +21,6 @@ class RecognitionPredictor(nn.Module): """Implements an object able to identify character sequences in images Args: - ---- pre_processor: transform inputs for easier batched model inference model: core detection architecture split_wide_crops: wether to use crop splitting for high aspect ratio crops @@ -67,7 +66,7 @@ def forward( crops = new_crops # Resize & batch them - processed_batches = self.pre_processor(crops) + processed_batches = self.pre_processor(crops) # type: ignore[arg-type] # Forward it _params = next(self.model.parameters()) diff --git a/doctr/models/recognition/predictor/tensorflow.py b/doctr/models/recognition/predictor/tensorflow.py index 409f39323a..84772b7a23 100644 --- a/doctr/models/recognition/predictor/tensorflow.py +++ b/doctr/models/recognition/predictor/tensorflow.py @@ -21,7 +21,6 @@ class RecognitionPredictor(NestedObject): """Implements an object able to identify character sequences in images Args: - ---- pre_processor: transform inputs for easier batched model inference model: core detection architecture split_wide_crops: wether to use crop splitting for high aspect ratio crops diff --git a/doctr/models/recognition/sar/__init__.py 
b/doctr/models/recognition/sar/__init__.py index c7110f5669..5bc3a74b47 100644 --- a/doctr/models/recognition/sar/__init__.py +++ b/doctr/models/recognition/sar/__init__.py @@ -1,6 +1,6 @@ from doctr.file_utils import is_tf_available, is_torch_available -if is_tf_available(): - from .tensorflow import * -elif is_torch_available(): - from .pytorch import * # type: ignore[assignment] +if is_torch_available(): + from .pytorch import * +elif is_tf_available(): + from .tensorflow import * # type: ignore[assignment] diff --git a/doctr/models/recognition/sar/pytorch.py b/doctr/models/recognition/sar/pytorch.py index a66bd32036..69f58a1a5d 100644 --- a/doctr/models/recognition/sar/pytorch.py +++ b/doctr/models/recognition/sar/pytorch.py @@ -80,7 +80,6 @@ class SARDecoder(nn.Module): """Implements decoder module of the SAR model Args: - ---- rnn_units: number of hidden units in recurrent cells max_length: maximum length of a sequence vocab_size: number of classes in the model alphabet @@ -166,7 +165,6 @@ class SAR(nn.Module, RecognitionModel): Irregular Text Recognition" `_. Args: - ---- feature_extractor: the backbone serving as feature extractor vocab: vocabulary used for encoding rnn_units: number of hidden units in both encoder and decoder LSTM @@ -281,13 +279,11 @@ def compute_loss( Sequences are masked after the EOS character. Args: - ---- model_output: predicted logits of the model gt: the encoded tensor with gt labels seq_len: lengths of each gt word inside the batch Returns: - ------- The loss of the model on the batch """ # Input length : number of timesteps @@ -308,7 +304,6 @@ class SARPostProcessor(RecognitionPostProcessor): """Post processor for SAR architectures Args: - ---- vocab: string containing the ordered sequence of supported characters """ @@ -379,12 +374,10 @@ def sar_resnet31(pretrained: bool = False, **kwargs: Any) -> SAR: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text recognition dataset **kwargs: keyword arguments of the SAR architecture Returns: - ------- text recognition architecture """ return _sar( diff --git a/doctr/models/recognition/sar/tensorflow.py b/doctr/models/recognition/sar/tensorflow.py index bcb0b207ef..3ae1b9fadf 100644 --- a/doctr/models/recognition/sar/tensorflow.py +++ b/doctr/models/recognition/sar/tensorflow.py @@ -33,7 +33,6 @@ class SAREncoder(layers.Layer, NestedObject): """Implements encoder module of the SAR model Args: - ---- rnn_units: number of hidden rnn units dropout_prob: dropout probability """ @@ -58,7 +57,6 @@ class AttentionModule(layers.Layer, NestedObject): """Implements attention module of the SAR model Args: - ---- attention_units: number of hidden attention units """ @@ -120,7 +118,6 @@ class SARDecoder(layers.Layer, NestedObject): """Implements decoder module of the SAR model Args: - ---- rnn_units: number of hidden units in recurrent cells max_length: maximum length of a sequence vocab_size: number of classes in the model alphabet @@ -210,7 +207,6 @@ class SAR(Model, RecognitionModel): Irregular Text Recognition" `_. Args: - ---- feature_extractor: the backbone serving as feature extractor vocab: vocabulary used for encoding rnn_units: number of hidden units in both encoder and decoder LSTM @@ -269,13 +265,11 @@ def compute_loss( Sequences are masked after the EOS character. 
Args: - ---- gt: the encoded tensor with gt labels model_output: predicted logits of the model seq_len: lengths of each gt word inside the batch Returns: - ------- The loss of the model on the batch """ # Input length : number of timesteps @@ -340,7 +334,6 @@ class SARPostProcessor(RecognitionPostProcessor): """Post processor for SAR architectures Args: - ---- vocab: string containing the ordered sequence of supported characters """ @@ -414,12 +407,10 @@ def sar_resnet31(pretrained: bool = False, **kwargs: Any) -> SAR: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text recognition dataset **kwargs: keyword arguments of the SAR architecture Returns: - ------- text recognition architecture """ return _sar("sar_resnet31", pretrained, resnet31, **kwargs) diff --git a/doctr/models/recognition/utils.py b/doctr/models/recognition/utils.py index 09de8b9165..b0d22c8dbf 100644 --- a/doctr/models/recognition/utils.py +++ b/doctr/models/recognition/utils.py @@ -14,14 +14,12 @@ def merge_strings(a: str, b: str, dil_factor: float) -> str: """Merges 2 character sequences in the best way to maximize the alignment of their overlapping characters. Args: - ---- a: first char seq, suffix should be similar to b's prefix. b: second char seq, prefix should be similar to a's suffix. dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is only used when the mother sequence is splitted on a character repetition Returns: - ------- A merged character sequence. Example:: @@ -65,13 +63,11 @@ def merge_multi_strings(seq_list: List[str], dil_factor: float) -> str: """Recursively merges consecutive string sequences with overlapping characters. Args: - ---- seq_list: list of sequences to merge. Sequences need to be ordered from left to right. dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is only used when the mother sequence is splitted on a character repetition Returns: - ------- A merged character sequence Example:: diff --git a/doctr/models/recognition/vitstr/__init__.py b/doctr/models/recognition/vitstr/__init__.py index c7110f5669..5bc3a74b47 100644 --- a/doctr/models/recognition/vitstr/__init__.py +++ b/doctr/models/recognition/vitstr/__init__.py @@ -1,6 +1,6 @@ from doctr.file_utils import is_tf_available, is_torch_available -if is_tf_available(): - from .tensorflow import * -elif is_torch_available(): - from .pytorch import * # type: ignore[assignment] +if is_torch_available(): + from .pytorch import * +elif is_tf_available(): + from .tensorflow import * # type: ignore[assignment] diff --git a/doctr/models/recognition/vitstr/base.py b/doctr/models/recognition/vitstr/base.py index af01dce600..3fc9a9832e 100644 --- a/doctr/models/recognition/vitstr/base.py +++ b/doctr/models/recognition/vitstr/base.py @@ -23,11 +23,9 @@ def build_target( sequence lengths. 
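All of these `build_target` helpers reduce to the same encoding: map each character to its vocabulary index, append EOS, pad to a fixed width, and report per-word lengths. A numpy sketch under those assumptions; here EOS doubles as padding and words are assumed shorter than `target_size`, which is a simplification rather than every model's behaviour.

    import numpy as np

    def encode_words(words, vocab: str, target_size: int):
        eos = len(vocab)  # EOS index sits right after the vocabulary
        encoded = np.full((len(words), target_size), eos, dtype=np.int64)
        for i, word in enumerate(words):
            # raises ValueError if a character is missing from the vocab
            encoded[i, : len(word)] = [vocab.index(c) for c in word]
        seq_len = np.array([len(w) + 1 for w in words])  # +1 for EOS
        return encoded, seq_len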
Args: - ---- gts: list of ground-truth labels Returns: - ------- A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch) """ encoded = encode_sequences( @@ -45,7 +43,6 @@ class _ViTSTRPostProcessor(RecognitionPostProcessor): """Abstract class to postprocess the raw output of the model Args: - ---- vocab: string containing the ordered sequence of supported characters """ diff --git a/doctr/models/recognition/vitstr/pytorch.py b/doctr/models/recognition/vitstr/pytorch.py index ff6644220d..1cc8a619b2 100644 --- a/doctr/models/recognition/vitstr/pytorch.py +++ b/doctr/models/recognition/vitstr/pytorch.py @@ -42,7 +42,6 @@ class ViTSTR(_ViTSTR, nn.Module): Efficient Scene Text Recognition" `_. Args: - ---- feature_extractor: the backbone serving as feature extractor vocab: vocabulary used for encoding embedding_units: number of embedding units @@ -125,13 +124,11 @@ def compute_loss( Sequences are masked after the EOS character. Args: - ---- model_output: predicted logits of the model gt: the encoded tensor with gt labels seq_len: lengths of each gt word inside the batch Returns: - ------- The loss of the model on the batch """ # Input length : number of steps @@ -153,7 +150,6 @@ class ViTSTRPostProcessor(_ViTSTRPostProcessor): """Post processor for ViTSTR architecture Args: - ---- vocab: string containing the ordered sequence of supported characters """ @@ -228,12 +224,10 @@ def vitstr_small(pretrained: bool = False, **kwargs: Any) -> ViTSTR: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text recognition dataset kwargs: keyword arguments of the ViTSTR architecture Returns: - ------- text recognition architecture """ return _vitstr( @@ -259,12 +253,10 @@ def vitstr_base(pretrained: bool = False, **kwargs: Any) -> ViTSTR: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text recognition dataset kwargs: keyword arguments of the ViTSTR architecture Returns: - ------- text recognition architecture """ return _vitstr( diff --git a/doctr/models/recognition/vitstr/tensorflow.py b/doctr/models/recognition/vitstr/tensorflow.py index 9b121171f8..b5e23880dd 100644 --- a/doctr/models/recognition/vitstr/tensorflow.py +++ b/doctr/models/recognition/vitstr/tensorflow.py @@ -40,7 +40,6 @@ class ViTSTR(_ViTSTR, Model): Efficient Scene Text Recognition" `_. Args: - ---- feature_extractor: the backbone serving as feature extractor vocab: vocabulary used for encoding embedding_units: number of embedding units @@ -85,13 +84,11 @@ def compute_loss( Sequences are masked after the EOS character. 
Args: - ---- model_output: predicted logits of the model gt: the encoded tensor with gt labels seq_len: lengths of each gt word inside the batch Returns: - ------- The loss of the model on the batch """ # Input length : number of steps @@ -158,7 +155,6 @@ class ViTSTRPostProcessor(_ViTSTRPostProcessor): """Post processor for ViTSTR architecture Args: - ---- vocab: string containing the ordered sequence of supported characters """ @@ -239,12 +235,10 @@ def vitstr_small(pretrained: bool = False, **kwargs: Any) -> ViTSTR: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text recognition dataset **kwargs: keyword arguments of the ViTSTR architecture Returns: - ------- text recognition architecture """ return _vitstr( @@ -268,12 +262,10 @@ def vitstr_base(pretrained: bool = False, **kwargs: Any) -> ViTSTR: >>> out = model(input_tensor) Args: - ---- pretrained (bool): If True, returns a model pre-trained on our text recognition dataset **kwargs: keyword arguments of the ViTSTR architecture Returns: - ------- text recognition architecture """ return _vitstr( diff --git a/doctr/models/recognition/zoo.py b/doctr/models/recognition/zoo.py index be6ca4ae44..3108e147a5 100644 --- a/doctr/models/recognition/zoo.py +++ b/doctr/models/recognition/zoo.py @@ -69,7 +69,6 @@ def recognition_predictor( >>> out = model([input_page]) Args: - ---- arch: name of the architecture or model itself to use (e.g. 'crnn_vgg16_bn') pretrained: If True, returns a model pre-trained on our text recognition dataset symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right @@ -77,7 +76,6 @@ def recognition_predictor( **kwargs: optional parameters to be passed to the architecture Returns: - ------- Recognition predictor """ return _predictor(arch=arch, pretrained=pretrained, symmetric_pad=symmetric_pad, batch_size=batch_size, **kwargs) diff --git a/doctr/models/utils/__init__.py b/doctr/models/utils/__init__.py index c7110f5669..5bc3a74b47 100644 --- a/doctr/models/utils/__init__.py +++ b/doctr/models/utils/__init__.py @@ -1,6 +1,6 @@ from doctr.file_utils import is_tf_available, is_torch_available -if is_tf_available(): - from .tensorflow import * -elif is_torch_available(): - from .pytorch import * # type: ignore[assignment] +if is_torch_available(): + from .pytorch import * +elif is_tf_available(): + from .tensorflow import * # type: ignore[assignment] diff --git a/doctr/models/utils/pytorch.py b/doctr/models/utils/pytorch.py index 998ccb7cf1..69160c5801 100644 --- a/doctr/models/utils/pytorch.py +++ b/doctr/models/utils/pytorch.py @@ -43,7 +43,6 @@ def load_pretrained_params( >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip") Args: - ---- model: the PyTorch model to be loaded url: URL of the zipped set of parameters hash_prefix: first characters of SHA256 expected hash @@ -84,7 +83,6 @@ def conv_sequence_pt( >>> module = Sequential(conv_sequence(3, 32, True, True, kernel_size=3)) Args: - ---- in_channels: number of input channels out_channels: number of output channels relu: whether ReLU should be used @@ -92,7 +90,6 @@ def conv_sequence_pt( **kwargs: additional arguments to be passed to the convolutional layer Returns: - ------- list of layers """ # No bias before Batch norm @@ -122,14 +119,12 @@ def set_device_and_dtype( >>> model, batches = set_device_and_dtype(model, batches, device="cuda", dtype=torch.float16) Args: - ---- model: the model to be set batches: the batches to be set device: the 
device to be used dtype: the dtype to be used Returns: - ------- the model and batches set """ return model.to(device=device, dtype=dtype), [batch.to(device=device, dtype=dtype) for batch in batches] @@ -145,14 +140,12 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.T >>> export_model_to_onnx(model, "my_model", dummy_input=torch.randn(1, 3, 32, 32)) Args: - ---- model: the PyTorch model to be exported model_name: the name for the exported model dummy_input: the dummy input to the model kwargs: additional arguments to be passed to torch.onnx.export Returns: - ------- the path to the exported model """ torch.onnx.export( diff --git a/doctr/models/utils/tensorflow.py b/doctr/models/utils/tensorflow.py index 792089bda4..490e9add9f 100644 --- a/doctr/models/utils/tensorflow.py +++ b/doctr/models/utils/tensorflow.py @@ -39,7 +39,6 @@ def _build_model(model: Model): """Build a model by calling it once with dummy input Args: - ---- model: the model to be built """ model(tf.zeros((1, *model.cfg["input_shape"])), training=False) @@ -58,7 +57,6 @@ def load_pretrained_params( >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.weights.h5") Args: - ---- model: the keras model to be loaded url: URL of the zipped set of parameters hash_prefix: first characters of SHA256 expected hash @@ -88,7 +86,6 @@ def conv_sequence( >>> module = Sequential(conv_sequence(32, 'relu', True, kernel_size=3, input_shape=[224, 224, 3])) Args: - ---- out_channels: number of output channels activation: activation to be used (default: no activation) bn: should a batch normalization layer be added @@ -97,7 +94,6 @@ def conv_sequence( **kwargs: additional arguments to be passed to the convolutional layer Returns: - ------- list of layers """ # No bias before Batch norm @@ -125,7 +121,6 @@ class IntermediateLayerGetter(Model): >>> feat_extractor = IntermediateLayerGetter(ResNet50(include_top=False, pooling=False), target_layers) Args: - ---- model: the model to extract feature maps from layer_names: the list of layers to retrieve the feature map from """ @@ -151,14 +146,12 @@ def export_model_to_onnx( >>> dummy_input=[tf.TensorSpec([None, 32, 32, 3], tf.float32, name="input")]) Args: - ---- model: the keras model to be exported model_name: the name for the exported model dummy_input: the dummy input to the model kwargs: additional arguments to be passed to tf2onnx Returns: - ------- the path to the exported model and a list with the output layer names """ # get the users eager mode diff --git a/doctr/models/zoo.py b/doctr/models/zoo.py index eff5fe14c4..b4d4264155 100644 --- a/doctr/models/zoo.py +++ b/doctr/models/zoo.py @@ -83,7 +83,6 @@ def ocr_predictor( >>> out = model([input_page]) Args: - ---- det_arch: name of the detection architecture or the model itself to use (e.g. 'db_resnet50', 'db_mobilenet_v3_large') reco_arch: name of the recognition architecture or the model itself to use @@ -108,7 +107,6 @@ def ocr_predictor( kwargs: keyword args of `OCRPredictor` Returns: - ------- OCR predictor """ return _predictor( @@ -197,7 +195,6 @@ def kie_predictor( >>> out = model([input_page]) Args: - ---- det_arch: name of the detection architecture or the model itself to use (e.g. 
'db_resnet50', 'db_mobilenet_v3_large') reco_arch: name of the recognition architecture or the model itself to use @@ -222,7 +219,6 @@ def kie_predictor( kwargs: keyword args of `OCRPredictor` Returns: - ------- KIE predictor """ return _kie_predictor( diff --git a/doctr/transforms/functional/__init__.py b/doctr/transforms/functional/__init__.py index 64556e403a..38ab32543e 100644 --- a/doctr/transforms/functional/__init__.py +++ b/doctr/transforms/functional/__init__.py @@ -1,6 +1,6 @@ from doctr.file_utils import is_tf_available, is_torch_available -if is_tf_available(): - from .tensorflow import * -elif is_torch_available(): +if is_torch_available(): from .pytorch import * +elif is_tf_available(): + from .tensorflow import * diff --git a/doctr/transforms/functional/base.py b/doctr/transforms/functional/base.py index 205e245f8c..b769ac9992 100644 --- a/doctr/transforms/functional/base.py +++ b/doctr/transforms/functional/base.py @@ -20,12 +20,10 @@ def crop_boxes( """Crop localization boxes Args: - ---- boxes: ndarray of shape (N, 4) in relative or abs coordinates crop_box: box (xmin, ymin, xmax, ymax) to crop the image, in the same coord format as the boxes Returns: - ------- the cropped boxes """ is_box_rel = boxes.max() <= 1 @@ -54,12 +52,10 @@ def expand_line(line: np.ndarray, target_shape: Tuple[int, int]) -> Tuple[float, the same direction until we meet one of the edges. Args: - ---- line: array of shape (2, 2) of the point supposed to be on one edge, and the shadow tip. target_shape: the desired mask shape Returns: - ------- 2D coordinates of the first point once we extended the line (on one of the edges) """ if any(coord == 0 or coord == size for coord, size in zip(line[0], target_shape[::-1])): @@ -120,14 +116,12 @@ def create_shadow_mask( """Creates a random shadow mask Args: - ---- target_shape: the target shape (H, W) min_base_width: the relative minimum shadow base width max_tip_width: the relative maximum shadow tip width max_tip_height: the relative maximum shadow tip height Returns: - ------- a numpy ndarray of shape (H, W, 1) with values in the range [0, 1] """ # Default base is top diff --git a/doctr/transforms/functional/pytorch.py b/doctr/transforms/functional/pytorch.py index 740769d99c..19699a8b4a 100644 --- a/doctr/transforms/functional/pytorch.py +++ b/doctr/transforms/functional/pytorch.py @@ -21,12 +21,10 @@ def invert_colors(img: torch.Tensor, min_val: float = 0.6) -> torch.Tensor: """Invert the colors of an image Args: - ---- img : torch.Tensor, the image to invert min_val : minimum value of the random shift Returns: - ------- the inverted image """ out = F.rgb_to_grayscale(img, num_output_channels=3) @@ -52,14 +50,12 @@ def rotate_sample( """Rotate image around the center, interpolation=NEAREST, pad with 0 (black) Args: - ---- img: image to rotate geoms: array of geometries of shape (N, 4) or (N, 4, 2) angle: angle in degrees. +: counter-clockwise, -: clockwise expand: whether the image should be padded before the rotation Returns: - ------- A tuple of rotated img (tensor), rotated geometries of shape (N, 4, 2) """ rotated_img = F.rotate(img, angle=angle, fill=0, expand=expand) # Interpolation NEAREST by default @@ -98,13 +94,11 @@ def crop_detection( """Crop an image and associated bboxes Args: - ---- img: image to crop boxes: array of boxes to clip, absolute (int) or relative (float) crop_box: box (xmin, ymin, xmax, ymax) to crop the image. Relative coords. Returns: - ------- A tuple of cropped image, cropped boxes, where the image is not resized.
""" if any(val < 0 or val > 1 for val in crop_box): @@ -123,13 +117,11 @@ def random_shadow(img: torch.Tensor, opacity_range: Tuple[float, float], **kwarg """Crop and image and associated bboxes Args: - ---- img: image to modify opacity_range: the minimum and maximum desired opacity of the shadow **kwargs: additional arguments to pass to `create_shadow_mask` Returns: - ------- shaded image """ shadow_mask = create_shadow_mask(img.shape[1:], **kwargs) # type: ignore[arg-type] diff --git a/doctr/transforms/functional/tensorflow.py b/doctr/transforms/functional/tensorflow.py index 1fbc05096e..4cee02e150 100644 --- a/doctr/transforms/functional/tensorflow.py +++ b/doctr/transforms/functional/tensorflow.py @@ -22,12 +22,10 @@ def invert_colors(img: tf.Tensor, min_val: float = 0.6) -> tf.Tensor: """Invert the colors of an image Args: - ---- img : tf.Tensor, the image to invert min_val : minimum value of the random shift Returns: - ------- the inverted image """ out = tf.image.rgb_to_grayscale(img) # Convert to gray @@ -48,13 +46,11 @@ def rotated_img_tensor(img: tf.Tensor, angle: float, expand: bool = False) -> tf """Rotate image around the center, interpolation=NEAREST, pad with 0 (black) Args: - ---- img: image to rotate angle: angle in degrees. +: counter-clockwise, -: clockwise expand: whether the image should be padded before the rotation Returns: - ------- the rotated image (tensor) """ # Compute the expanded padding @@ -107,14 +103,12 @@ def rotate_sample( """Rotate image around the center, interpolation=NEAREST, pad with 0 (black) Args: - ---- img: image to rotate geoms: array of geometries of shape (N, 4) or (N, 4, 2) angle: angle in degrees. +: counter-clockwise, -: clockwise expand: whether the image should be padded before the rotation Returns: - ------- A tuple of rotated img (tensor), rotated boxes (np array) """ # Rotated the image @@ -149,13 +143,11 @@ def crop_detection( """Crop and image and associated bboxes Args: - ---- img: image to crop boxes: array of boxes to clip, absolute (int) or relative (float) crop_box: box (xmin, ymin, xmax, ymax) to crop the image. Relative coords. Returns: - ------- A tuple of cropped image, cropped boxes, where the image is not resized. 
""" if any(val < 0 or val > 1 for val in crop_box): @@ -181,7 +173,6 @@ def _gaussian_filter( Adapted from: https://github.com/tensorflow/addons/blob/master/tensorflow_addons/image/filters.py Args: - ---- img: image to filter of shape (N, H, W, C) kernel_size: kernel size of the filter sigma: standard deviation of the Gaussian filter @@ -189,7 +180,6 @@ def _gaussian_filter( pad_value: value to pad the image with Returns: - ------- A tensor of shape (N, H, W, C) """ ksize = tf.convert_to_tensor(tf.broadcast_to(kernel_size, [2]), dtype=tf.int32) @@ -239,13 +229,11 @@ def random_shadow(img: tf.Tensor, opacity_range: Tuple[float, float], **kwargs) """Apply a random shadow to a given image Args: - ---- img: image to modify opacity_range: the minimum and maximum desired opacity of the shadow **kwargs: additional arguments to pass to `create_shadow_mask` Returns: - ------- shadowed image """ shadow_mask = create_shadow_mask(img.shape[:2], **kwargs) diff --git a/doctr/transforms/modules/__init__.py b/doctr/transforms/modules/__init__.py index 4053ff5520..8cb106c2f2 100644 --- a/doctr/transforms/modules/__init__.py +++ b/doctr/transforms/modules/__init__.py @@ -2,7 +2,7 @@ from .base import * -if is_tf_available(): - from .tensorflow import * -elif is_torch_available(): - from .pytorch import * # type: ignore[assignment] +if is_torch_available(): + from .pytorch import * +elif is_tf_available(): + from .tensorflow import * # type: ignore[assignment] diff --git a/doctr/transforms/modules/base.py b/doctr/transforms/modules/base.py index 25d15c98ef..b631b31c24 100644 --- a/doctr/transforms/modules/base.py +++ b/doctr/transforms/modules/base.py @@ -21,28 +21,27 @@ class SampleCompose(NestedObject): .. tabs:: - .. tab:: TensorFlow + .. tab:: PyTorch .. code:: python >>> import numpy as np - >>> import tensorflow as tf + >>> import torch >>> from doctr.transforms import SampleCompose, ImageTransform, ColorInversion, RandomRotate - >>> transfo = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)]) - >>> out, out_boxes = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1), np.zeros((2, 4))) + >>> transfos = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)]) + >>> out, out_boxes = transfos(torch.rand(8, 64, 64, 3), np.zeros((2, 4))) - .. tab:: PyTorch + .. tab:: TensorFlow .. code:: python >>> import numpy as np - >>> import torch + >>> import tensorflow as tf >>> from doctr.transforms import SampleCompose, ImageTransform, ColorInversion, RandomRotate - >>> transfos = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)]) - >>> out, out_boxes = transfos(torch.rand(8, 64, 64, 3), np.zeros((2, 4))) + >>> transfo = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)]) + >>> out, out_boxes = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1), np.zeros((2, 4))) Args: - ---- transforms: list of transformation modules """ @@ -63,26 +62,25 @@ class ImageTransform(NestedObject): .. tabs:: - .. tab:: TensorFlow + .. tab:: PyTorch .. code:: python - >>> import tensorflow as tf + >>> import torch >>> from doctr.transforms import ImageTransform, ColorInversion >>> transfo = ImageTransform(ColorInversion((32, 32))) - >>> out, _ = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1), None) + >>> out, _ = transfo(torch.rand(8, 64, 64, 3), None) - .. tab:: PyTorch + .. tab:: TensorFlow .. 
code:: python - >>> import torch + >>> import tensorflow as tf >>> from doctr.transforms import ImageTransform, ColorInversion >>> transfo = ImageTransform(ColorInversion((32, 32))) - >>> out, _ = transfo(torch.rand(8, 64, 64, 3), None) + >>> out, _ = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1), None) Args: - ---- transform: the image transformation module to wrap """ @@ -102,26 +100,25 @@ class ColorInversion(NestedObject): .. tabs:: - .. tab:: TensorFlow + .. tab:: PyTorch .. code:: python - >>> import tensorflow as tf + >>> import torch >>> from doctr.transforms import ColorInversion >>> transfo = ColorInversion(min_val=0.6) - >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1)) + >>> out = transfo(torch.rand(8, 64, 64, 3)) - .. tab:: PyTorch + .. tab:: TensorFlow .. code:: python - >>> import torch + >>> import tensorflow as tf >>> from doctr.transforms import ColorInversion >>> transfo = ColorInversion(min_val=0.6) - >>> out = transfo(torch.rand(8, 64, 64, 3)) + >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1)) Args: - ---- min_val: range [min_val, 1] to colorize RGB pixels """ @@ -140,26 +137,25 @@ class OneOf(NestedObject): .. tabs:: - .. tab:: TensorFlow + .. tab:: PyTorch .. code:: python - >>> import tensorflow as tf + >>> import torch >>> from doctr.transforms import OneOf >>> transfo = OneOf([JpegQuality(), Gamma()]) - >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1)) + >>> out = transfo(torch.rand(1, 64, 64, 3)) - .. tab:: PyTorch + .. tab:: TensorFlow .. code:: python - >>> import torch + >>> import tensorflow as tf >>> from doctr.transforms import OneOf >>> transfo = OneOf([JpegQuality(), Gamma()]) - >>> out = transfo(torch.rand(1, 64, 64, 3)) + >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1)) Args: - ---- transforms: list of transformations, one only will be picked """ @@ -180,26 +176,25 @@ class RandomApply(NestedObject): .. tabs:: - .. tab:: TensorFlow + .. tab:: PyTorch .. code:: python - >>> import tensorflow as tf + >>> import torch >>> from doctr.transforms import RandomApply >>> transfo = RandomApply(Gamma(), p=.5) - >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1)) + >>> out = transfo(torch.rand(1, 64, 64, 3)) - .. tab:: PyTorch + .. tab:: TensorFlow .. code:: python - >>> import torch + >>> import tensorflow as tf >>> from doctr.transforms import RandomApply >>> transfo = RandomApply(Gamma(), p=.5) - >>> out = transfo(torch.rand(1, 64, 64, 3)) + >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1)) Args: - ---- transform: transformation to apply p: probability to apply """ @@ -224,9 +219,7 @@ class RandomRotate(NestedObject): :align: center Args: - ---- - max_angle: maximum angle for rotation, in degrees. Angles will be uniformly picked in - [-max_angle, max_angle] + max_angle: maximum angle for rotation, in degrees. 
Angles will be uniformly picked in [-max_angle, max_angle] expand: whether the image should be padded before the rotation """ @@ -249,7 +242,6 @@ class RandomCrop(NestedObject): """Randomly crop a tensor image and its boxes Args: - ---- scale: tuple of floats, relative (min_area, max_area) of the crop ratio: tuple of float, relative (min_ratio, max_ratio) where ratio = h/w """ diff --git a/doctr/transforms/modules/pytorch.py b/doctr/transforms/modules/pytorch.py index 639b27e2cf..c7181719ea 100644 --- a/doctr/transforms/modules/pytorch.py +++ b/doctr/transforms/modules/pytorch.py @@ -122,7 +122,6 @@ class GaussianNoise(torch.nn.Module): >>> out = transfo(torch.rand((3, 224, 224))) Args: - ---- mean : mean of the gaussian distribution std : std of the gaussian distribution """ @@ -183,7 +182,6 @@ class RandomShadow(torch.nn.Module): >>> out = transfo(torch.rand((3, 64, 64))) Args: - ---- opacity_range : minimum and maximum opacity of the shade """ @@ -225,12 +223,11 @@ class RandomResize(torch.nn.Module): >>> out = transfo(torch.rand((3, 64, 64))) Args: - ---- scale_range: range of the resizing factor for width and height (independently) preserve_aspect_ratio: whether to preserve the aspect ratio of the image, - given a float value, the aspect ratio will be preserved with this probability + given a float value, the aspect ratio will be preserved with this probability symmetric_pad: whether to symmetrically pad the image, - given a float value, the symmetric padding will be applied with this probability + given a float value, the symmetric padding will be applied with this probability p: probability to apply the transformation """ diff --git a/doctr/transforms/modules/tensorflow.py b/doctr/transforms/modules/tensorflow.py index 2f2fb25f9c..b2c6c532c7 100644 --- a/doctr/transforms/modules/tensorflow.py +++ b/doctr/transforms/modules/tensorflow.py @@ -43,7 +43,6 @@ class Compose(NestedObject): >>> out = transfos(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1)) Args: - ---- transforms: list of transformation modules """ @@ -68,7 +67,6 @@ class Resize(NestedObject): >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1)) Args: - ---- output_size: expected output size method: interpolation method preserve_aspect_ratio: if `True`, preserve aspect ratio and pad the rest with zeros @@ -164,7 +162,6 @@ class Normalize(NestedObject): >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1)) Args: - ---- mean: average value per channel std: standard deviation per channel """ @@ -191,7 +188,6 @@ class LambdaTransformation(NestedObject): >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1)) Args: - ---- fn: the function to be applied to the input tensor """ @@ -229,7 +225,6 @@ class RandomBrightness(NestedObject): >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1)) Args: - ---- max_delta: offset to add to each pixel is randomly picked in [-max_delta, max_delta] p: probability to apply transformation """ @@ -254,7 +249,6 @@ class RandomContrast(NestedObject): >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1)) Args: - ---- delta: multiplicative factor is picked in [1-delta, 1+delta] (reduce contrast if factor<1) """ @@ -278,7 +272,6 @@ class RandomSaturation(NestedObject): >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1)) Args: - ---- delta: multiplicative factor is picked in [1-delta, 1+delta] (reduce saturation if factor<1) """ @@ -301,7 +294,6 @@ 
class RandomHue(NestedObject): >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1)) Args: - ---- max_delta: offset to add to each pixel is randomly picked in [-max_delta, max_delta] """ @@ -324,7 +316,6 @@ class RandomGamma(NestedObject): >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1)) Args: - ---- min_gamma: non-negative real number, lower bound for gamma param max_gamma: non-negative real number, upper bound for gamma min_gain: lower bound for constant multiplier @@ -362,7 +353,6 @@ class RandomJpegQuality(NestedObject): >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1)) Args: - ---- min_quality: int between [0, 100] max_quality: int between [0, 100] """ @@ -387,7 +377,6 @@ class GaussianBlur(NestedObject): >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1)) Args: - ---- kernel_shape: size of the blurring kernel std: min and max value of the standard deviation """ @@ -430,7 +419,6 @@ class GaussianNoise(NestedObject): >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1)) Args: - ---- mean : mean of the gaussian distribution std : std of the gaussian distribution """ @@ -465,7 +453,6 @@ class RandomHorizontalFlip(NestedObject): >>> out = transfo(image, target) Args: - ---- p : probability of Horizontal Flip """ @@ -495,7 +482,6 @@ class RandomShadow(NestedObject): >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1)) Args: - ---- opacity_range : minimum and maximum opacity of the shade """ @@ -530,12 +516,11 @@ class RandomResize(NestedObject): >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1)) Args: - ---- scale_range: range of the resizing factor for width and height (independently) preserve_aspect_ratio: whether to preserve the aspect ratio of the image, - given a float value, the aspect ratio will be preserved with this probability + given a float value, the aspect ratio will be preserved with this probability symmetric_pad: whether to symmetrically pad the image, - given a float value, the symmetric padding will be applied with this probability + given a float value, the symmetric padding will be applied with this probability p: probability to apply the transformation """ diff --git a/doctr/utils/data.py b/doctr/utils/data.py index aca34801bb..9702969267 100644 --- a/doctr/utils/data.py +++ b/doctr/utils/data.py @@ -56,7 +56,6 @@ def download_from_url( >>> download_from_url("https://yoursource.com/yourcheckpoint-yourhash.zip") Args: - ---- url: the URL of the file to download file_name: optional name of the file once downloaded hash_prefix: optional expected SHA256 hash of the file @@ -64,11 +63,9 @@ def download_from_url( cache_subdir: subfolder to use in the cache Returns: - ------- the location of the downloaded file Note: - ---- You can change cache directory location by using `DOCTR_CACHE_DIR` environment variable. 
""" if not isinstance(file_name, str): diff --git a/doctr/utils/fonts.py b/doctr/utils/fonts.py index ecd3e377db..618ed0b713 100644 --- a/doctr/utils/fonts.py +++ b/doctr/utils/fonts.py @@ -18,12 +18,10 @@ def get_font( """Resolves a compatible ImageFont for the system Args: - ---- font_family: the font family to use font_size: the size of the font upon rendering Returns: - ------- the Pillow font """ # Font selection diff --git a/doctr/utils/geometry.py b/doctr/utils/geometry.py index 653b9f8b9d..b513a6308e 100644 --- a/doctr/utils/geometry.py +++ b/doctr/utils/geometry.py @@ -34,11 +34,9 @@ def bbox_to_polygon(bbox: BoundingBox) -> Polygon4P: """Convert a bounding box to a polygon Args: - ---- bbox: a bounding box Returns: - ------- a polygon """ return bbox[0], (bbox[1][0], bbox[0][1]), (bbox[0][0], bbox[1][1]), bbox[1] @@ -48,11 +46,9 @@ def polygon_to_bbox(polygon: Polygon4P) -> BoundingBox: """Convert a polygon to a bounding box Args: - ---- polygon: a polygon Returns: - ------- a bounding box """ x, y = zip(*polygon) @@ -63,11 +59,9 @@ def detach_scores(boxes: List[np.ndarray]) -> Tuple[List[np.ndarray], List[np.nd """Detach the objectness scores from box predictions Args: - ---- boxes: list of arrays with boxes of shape (N, 5) or (N, 5, 2) Returns: - ------- a tuple of two lists: the first one contains the boxes without the objectness scores, the second one contains the objectness scores """ @@ -85,7 +79,6 @@ def resolve_enclosing_bbox(bboxes: Union[List[BoundingBox], np.ndarray]) -> Unio """Compute enclosing bbox either from: Args: - ---- bboxes: boxes in one of the following formats: - an array of boxes: (*, 4), where boxes have this shape: @@ -94,7 +87,6 @@ def resolve_enclosing_bbox(bboxes: Union[List[BoundingBox], np.ndarray]) -> Unio - a list of BoundingBox Returns: - ------- a (1, 4) array (enclosing boxarray), or a BoundingBox """ if isinstance(bboxes, np.ndarray): @@ -109,7 +101,6 @@ def resolve_enclosing_rbbox(rbboxes: List[np.ndarray], intermed_size: int = 1024 """Compute enclosing rotated bbox either from: Args: - ---- rbboxes: boxes in one of the following formats: - an array of boxes: (*, 4, 2), where boxes have this shape: @@ -119,7 +110,6 @@ def resolve_enclosing_rbbox(rbboxes: List[np.ndarray], intermed_size: int = 1024 intermed_size: size of the intermediate image Returns: - ------- a (4, 2) array (enclosing rotated box) """ cloud: np.ndarray = np.concatenate(rbboxes, axis=0) @@ -133,12 +123,10 @@ def rotate_abs_points(points: np.ndarray, angle: float = 0.0) -> np.ndarray: """Rotate points counter-clockwise. Args: - ---- points: array of size (N, 2) angle: angle between -90 and +90 degrees Returns: - ------- Rotated points """ angle_rad = angle * np.pi / 180.0 # compute radian angle for np functions @@ -152,12 +140,10 @@ def compute_expanded_shape(img_shape: Tuple[int, int], angle: float) -> Tuple[in """Compute the shape of an expanded rotated image Args: - ---- img_shape: the height and width of the image angle: angle between -90 and +90 degrees Returns: - ------- the height and width of the rotated image """ points: np.ndarray = np.array([ @@ -181,14 +167,12 @@ def rotate_abs_geoms( image center. 
Args: - ---- geoms: (N, 4) or (N, 4, 2) array of ABSOLUTE coordinate boxes angle: anti-clockwise rotation angle in degrees img_shape: the height and width of the image expand: whether the image should be padded to avoid information loss Returns: - ------- A batch of rotated polygons (N, 4, 2) """ # Switch to polygons @@ -220,13 +204,11 @@ def remap_boxes(loc_preds: np.ndarray, orig_shape: Tuple[int, int], dest_shape: coordinates after a resizing of the image. Args: - ---- loc_preds: (N, 4, 2) array of RELATIVE loc_preds orig_shape: shape of the origin image dest_shape: shape of the destination image Returns: - ------- A batch of rotated loc_preds (N, 4, 2) expressed in the destination referential """ if len(dest_shape) != 2: @@ -255,7 +237,6 @@ def rotate_boxes( is done to remove the padding that is created by rotate_page(expand=True) Args: - ---- loc_preds: (N, 4) or (N, 4, 2) array of RELATIVE boxes angle: angle between -90 and +90 degrees orig_shape: shape of the origin image @@ -263,7 +244,6 @@ target_shape: shape of the destination image Returns: - ------- A batch of rotated boxes (N, 4, 2): or a batch of straight bounding boxes """ # Change format of the boxes to rotated boxes @@ -310,14 +290,12 @@ def rotate_image( """Rotate an image counterclockwise by a given angle. Args: - ---- image: numpy tensor to rotate angle: rotation angle in degrees, between -90 and +90 expand: whether the image should be padded before the rotation preserve_origin_shape: if expand is set to True, resizes the final output to the original image size Returns: - ------- Rotated array, padded by 0 by default. """ # Compute the expanded padding @@ -356,11 +334,9 @@ def remove_image_padding(image: np.ndarray) -> np.ndarray: """Remove black border padding from an image Args: - ---- image: numpy tensor to remove padding from Returns: - ------- Image with padding removed """ # Find the bounding box of the non-black region @@ -394,12 +370,10 @@ def convert_to_relative_coords(geoms: np.ndarray, img_shape: Tuple[int, int]) -> """Convert a geometry to relative coordinates Args: - ---- geoms: a set of polygons of shape (N, 4, 2) or of straight boxes of shape (N, 4) img_shape: the height and width of the image Returns: - ------- the updated geometry """ # Polygon @@ -421,14 +395,12 @@ def extract_crops(img: np.ndarray, boxes: np.ndarray, channels_last: bool = True """Create cropped images from a list of bounding boxes Args: - ---- img: input image boxes: bounding boxes of shape (N, 4) where N is the number of boxes, and the relative coordinates (xmin, ymin, xmax, ymax) channels_last: whether the channel dimension is the last one instead of the first one Returns: - ------- list of cropped images """ if boxes.shape[0] == 0: @@ -457,7 +429,6 @@ def extract_rcrops( """Create cropped images from a list of rotated bounding boxes Args: - ---- img: input image polys: bounding boxes of shape (N, 4, 2) dtype: target data type of bounding boxes @@ -465,7 +436,6 @@ assume_horizontal: whether the boxes are assumed to be only horizontally oriented Returns: - ------- list of cropped images """ if polys.shape[0] == 0: diff --git a/doctr/utils/metrics.py b/doctr/utils/metrics.py index 4fe3d59ebe..ddd6f9774e 100644 --- a/doctr/utils/metrics.py +++ b/doctr/utils/metrics.py @@ -25,12 +25,10 @@ def string_match(word1: str, word2: str) -> Tuple[bool, bool, bool, bool]: """Performs string comparison with multiple levels of tolerance Args: - ---- word1: a string word2: another string Returns: - ------- a
tuple with booleans specifying respectively whether the raw strings, their lower-case counterparts, their anyascii counterparts and their lower-case anyascii counterparts match """ @@ -84,7 +82,6 @@ def update( """Update the state of the metric with new predictions Args: - ---- gt: list of ground-truth character sequences pred: list of predicted character sequences """ @@ -103,8 +100,7 @@ def update( def summary(self) -> Dict[str, float]: """Computes the aggregated metrics - Returns - ------- + Returns: a dictionary with the exact match score for the raw data, its lower-case counterpart, its anyascii counterpart and its lower-case anyascii counterpart """ @@ -130,12 +126,10 @@ def box_iou(boxes_1: np.ndarray, boxes_2: np.ndarray) -> np.ndarray: """Computes the IoU between two sets of bounding boxes Args: - ---- boxes_1: bounding boxes of shape (N, 4) in format (xmin, ymin, xmax, ymax) boxes_2: bounding boxes of shape (M, 4) in format (xmin, ymin, xmax, ymax) Returns: - ------- the IoU matrix of shape (N, M) """ iou_mat: np.ndarray = np.zeros((boxes_1.shape[0], boxes_2.shape[0]), dtype=np.float32) @@ -160,14 +154,12 @@ def polygon_iou(polys_1: np.ndarray, polys_2: np.ndarray) -> np.ndarray: """Computes the IoU between two sets of rotated bounding boxes Args: - ---- polys_1: rotated bounding boxes of shape (N, 4, 2) polys_2: rotated bounding boxes of shape (M, 4, 2) mask_shape: spatial shape of the intermediate masks use_broadcasting: if set to True, leverage broadcasting speedup by consuming more memory Returns: - ------- the IoU matrix of shape (N, M) """ if polys_1.ndim != 3 or polys_2.ndim != 3: @@ -191,12 +183,10 @@ def nms(boxes: np.ndarray, thresh: float = 0.5) -> List[int]: """Perform non-max suppression, borrowed from `_. Args: - ---- boxes: np array of straight boxes: (*, 5), (xmin, ymin, xmax, ymax, score) thresh: iou threshold to perform box suppression.
Returns: - ------- A list of box indexes to keep """ x1 = boxes[:, 0] @@ -260,7 +250,6 @@ class LocalizationConfusion: >>> metric.summary() Args: - ---- iou_thresh: minimum IoU to consider a pair of prediction and ground truth as a match use_polygons: if set to True, predictions and targets will be expected to have rotated format """ @@ -278,7 +267,6 @@ def update(self, gts: np.ndarray, preds: np.ndarray) -> None: """Updates the metric Args: - ---- gts: a set of relative bounding boxes either of shape (N, 4) or (N, 5) if they are rotated ones preds: a set of relative bounding boxes either of shape (M, 4) or (M, 5) if they are rotated ones """ @@ -301,8 +289,7 @@ def update(self, gts: np.ndarray, preds: np.ndarray) -> None: def summary(self) -> Tuple[Optional[float], Optional[float], Optional[float]]: """Computes the aggregated metrics - Returns - ------- + Returns: a tuple with the recall, precision and meanIoU scores """ # Recall @@ -360,7 +347,6 @@ class OCRMetric: >>> metric.summary() Args: - ---- iou_thresh: minimum IoU to consider a pair of prediction and ground truth as a match use_polygons: if set to True, predictions and targets will be expected to have rotated format """ @@ -384,7 +370,6 @@ def update( """Updates the metric Args: - ---- gt_boxes: a set of relative bounding boxes either of shape (N, 4) or (N, 5) if they are rotated ones pred_boxes: a set of relative bounding boxes either of shape (M, 4) or (M, 5) if they are rotated ones gt_labels: a list of N string labels @@ -421,8 +406,7 @@ def update( def summary(self) -> Tuple[Dict[str, Optional[float]], Dict[str, Optional[float]], Optional[float]]: """Computes the aggregated metrics - Returns - ------- + Returns: a tuple with the recall & precision for each string comparison and the mean IoU """ # Recall @@ -493,7 +477,6 @@ class DetectionMetric: >>> metric.summary() Args: - ---- iou_thresh: minimum IoU to consider a pair of prediction and ground truth as a match use_polygons: if set to True, predictions and targets will be expected to have rotated format """ @@ -517,7 +500,6 @@ def update( """Updates the metric Args: - ---- gt_boxes: a set of relative bounding boxes either of shape (N, 4) or (N, 5) if they are rotated ones pred_boxes: a set of relative bounding boxes either of shape (M, 4) or (M, 5) if they are rotated ones gt_labels: an array of class indices of shape (N,) @@ -549,8 +531,7 @@ def update( def summary(self) -> Tuple[Optional[float], Optional[float], Optional[float]]: """Computes the aggregated metrics - Returns - ------- + Returns: a tuple with the recall & precision for each class prediction and the mean IoU """ # Recall diff --git a/doctr/utils/multithreading.py b/doctr/utils/multithreading.py index 6450a0bfd2..f64e1aacc8 100644 --- a/doctr/utils/multithreading.py +++ b/doctr/utils/multithreading.py @@ -22,17 +22,14 @@ def multithread_exec(func: Callable[[Any], Any], seq: Iterable[Any], threads: Op >>> results = multithread_exec(lambda x: x ** 2, entries) Args: - ---- func: function to be executed on each element of the iterable seq: iterable threads: number of workers to be used for multiprocessing Returns: - ------- iterator of the function's results using the iterable as inputs Notes: - ----- This function uses ThreadPool from multiprocessing package, which uses `/dev/shm` directory for shared memory. If you do not have write permissions for this directory (if you run `doctr` on AWS Lambda for instance), you might want to disable multiprocessing. 
To achieve that, set 'DOCTR_MULTIPROCESSING_DISABLE' to 'TRUE'. diff --git a/doctr/utils/reconstitution.py b/doctr/utils/reconstitution.py index a229e9ddbc..23541f059c 100644 --- a/doctr/utils/reconstitution.py +++ b/doctr/utils/reconstitution.py @@ -121,7 +121,6 @@ def synthesize_page( """Draw the content of the element page (OCR response) on a blank page. Args: - ---- page: exported Page object to represent draw_proba: if True, draw words in colors to represent confidence. Blue: p=1, red: p=0 font_family: family of the font @@ -130,7 +129,6 @@ max_font_size: maximum font size Returns: - ------- the synthesized page """ # Draw template @@ -181,7 +179,6 @@ def synthesize_kie_page( """Draw the content of the element page (OCR response) on a blank page. Args: - ---- page: exported Page object to represent draw_proba: if True, draw words in colors to represent confidence. Blue: p=1, red: p=0 font_family: family of the font @@ -190,7 +187,6 @@ max_font_size: maximum font size Returns: - ------- the synthesized page """ # Draw template diff --git a/doctr/utils/visualization.py b/doctr/utils/visualization.py index 4e97f751fe..c0e7b75d04 100644 --- a/doctr/utils/visualization.py +++ b/doctr/utils/visualization.py @@ -30,7 +30,6 @@ def rect_patch( """Create a matplotlib rectangular patch for the element Args: - ---- geometry: bounding box of the element page_dimensions: dimensions of the Page in format (height, width) label: label to display when hovered @@ -41,7 +40,6 @@ preserve_aspect_ratio: pass True if you passed True to the predictor Returns: - ------- a rectangular Patch """ if len(geometry) != 2 or any(not isinstance(elt, tuple) or len(elt) != 2 for elt in geometry): @@ -81,7 +79,6 @@ def polygon_patch( """Create a matplotlib polygon patch for the element Args: - ---- geometry: bounding box of the element page_dimensions: dimensions of the Page in format (height, width) label: label to display when hovered @@ -92,7 +89,6 @@ preserve_aspect_ratio: pass True if you passed True to the predictor Returns: - ------- a polygon Patch """ if not geometry.shape == (4, 2): @@ -121,13 +117,11 @@ def create_obj_patch( """Create a matplotlib patch for the element Args: - ---- geometry: bounding box (straight or rotated) of the element page_dimensions: dimensions of the page in format (height, width) **kwargs: keyword arguments for the patch Returns: - ------- a matplotlib Patch """ if isinstance(geometry, tuple): @@ -144,11 +138,9 @@ def get_colors(num_colors: int) -> List[Tuple[float, float, float]]: """Generate num_colors colors for matplotlib Args: - ---- num_colors: number of colors to generate Returns: - ------- colors: list of generated colors """ colors = [] @@ -183,7 +175,6 @@ def visualize_page( >>> plt.show() Args: - ---- page: the exported Page of a Document image: np array of the page, needs to have the same shape as page['dimensions'] words_only: whether only words should be displayed @@ -194,7 +185,6 @@ **kwargs: keyword arguments for the polygon patch Returns: - ------- the matplotlib figure """ # Get proper scale and aspect ratio @@ -309,7 +299,6 @@ def visualize_kie_page( >>> plt.show() Args: - ---- page: the exported Page of a Document image: np array of the page, needs to have the same shape as page['dimensions'] words_only: whether only words should be displayed @@ -320,7 +309,6 @@ **kwargs: keyword arguments for the polygon patch Returns: - ------- the
matplotlib figure """ # Get proper scale and aspect ratio @@ -367,7 +355,6 @@ def draw_boxes(boxes: np.ndarray, image: np.ndarray, color: Optional[Tuple[int, """Draw an array of relative straight boxes on an image Args: - ---- boxes: array of relative boxes, of shape (*, 4) image: np array, float32 or uint8 color: color to use for bounding box edges diff --git a/pyproject.toml b/pyproject.toml index 763e966187..db09c839f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -192,7 +192,7 @@ select = [ "E", "W", "F", "I", "N", "Q", "C4", "T10", "LOG", "D101", "D103", "D201","D202","D207","D208","D214","D215","D300","D301","D417", "D419", "D207" # pydocstyle ] -ignore = ["E402", "E203", "F403", "E731", "N812", "N817", "C408"] +ignore = ["E402", "E203", "F403", "E731", "N812", "N817", "C408", "LOG015"] [tool.ruff.lint.isort] known-first-party = ["doctr", "app", "utils"] diff --git a/references/classification/utils.py b/references/classification/utils.py index 1b43d0b18f..a2c9ae5460 100644 --- a/references/classification/utils.py +++ b/references/classification/utils.py @@ -39,7 +39,6 @@ def plot_recorder(lr_recorder, loss_recorder, beta: float = 0.95, **kwargs) -> N Adapted from https://github.com/frgfm/Holocron/blob/master/holocron/trainer/core.py Args: - ---- lr_recorder: list of LR values loss_recorder: list of loss values beta (float, optional): smoothing factor diff --git a/references/detection/utils.py b/references/detection/utils.py index 265f63ff37..fe1b326f74 100644 --- a/references/detection/utils.py +++ b/references/detection/utils.py @@ -49,7 +49,6 @@ def plot_recorder(lr_recorder, loss_recorder, beta: float = 0.95, **kwargs) -> N Adapted from https://github.com/frgfm/Holocron/blob/master/holocron/trainer/core.py Args: - ---- lr_recorder: list of LR values loss_recorder: list of loss values beta (float, optional): smoothing factor diff --git a/references/recognition/train_pytorch_ddp.py b/references/recognition/train_pytorch_ddp.py index e7d71a30bf..8a7f1428be 100644 --- a/references/recognition/train_pytorch_ddp.py +++ b/references/recognition/train_pytorch_ddp.py @@ -106,8 +106,8 @@ def evaluate(model, device, val_loader, batch_transforms, val_metric, amp=False) def main(rank: int, world_size: int, args): - """Args: - ---- + """ + Args: rank (int): device id to put the model on world_size (int): number of processes participating in the job args: other arguments passed through the CLI diff --git a/references/recognition/utils.py b/references/recognition/utils.py index 1a50bbe0c3..d4e2c3af88 100644 --- a/references/recognition/utils.py +++ b/references/recognition/utils.py @@ -39,7 +39,6 @@ def plot_recorder(lr_recorder, loss_recorder, beta: float = 0.95, **kwargs) -> N Adapted from https://github.com/frgfm/Holocron/blob/master/holocron/trainer/core.py. Args: - ---- lr_recorder: list of LR values loss_recorder: list of loss values beta (float, optional): smoothing factor
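
The backend-resolution change repeated across the `__init__.py` hunks above follows a single pattern: probe for PyTorch first, fall back to TensorFlow. Below is a minimal, self-contained sketch of that resolution order. It uses `importlib.util.find_spec` as a simplified stand-in for the `doctr.file_utils.is_torch_available()` / `is_tf_available()` helpers (which also honour environment overrides such as `USE_TORCH`), so treat it as an illustration of the ordering rather than the library's actual implementation.

.. code:: python

    import importlib.util

    # Simplified stand-ins for the doctr.file_utils helpers; the real ones
    # also consult environment variables to force a backend.
    def is_torch_available() -> bool:
        # True if PyTorch can be imported in the current environment
        return importlib.util.find_spec("torch") is not None

    def is_tf_available() -> bool:
        # True if TensorFlow can be imported in the current environment
        return importlib.util.find_spec("tensorflow") is not None

    # The branch order mirrors this patch: PyTorch is preferred, and
    # TensorFlow is only selected when torch is absent.
    if is_torch_available():
        backend = "pytorch"
    elif is_tf_available():
        backend = "tensorflow"
    else:
        raise ImportError("Install either PyTorch or TensorFlow to use doctr.")

    print(f"Resolved backend: {backend}")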