From 93af653d2eeebe5fa45febe2c77b612fa38f5534 Mon Sep 17 00:00:00 2001
From: felix
Date: Wed, 9 Oct 2024 07:31:44 +0200
Subject: [PATCH 1/3] all trainable / loadable

---
 doctr/file_utils.py | 15 ---------------
 .../classification/magc_resnet/tensorflow.py | 2 +-
 .../models/classification/mobilenet/tensorflow.py | 12 ++++++------
 doctr/models/classification/resnet/tensorflow.py | 10 +++++-----
 doctr/models/classification/textnet/tensorflow.py | 6 +++---
 doctr/models/classification/vgg/tensorflow.py | 2 +-
 doctr/models/classification/vit/tensorflow.py | 4 ++--
 .../differentiable_binarization/tensorflow.py | 4 ++--
 doctr/models/detection/fast/tensorflow.py | 6 +++---
 doctr/models/detection/linknet/tensorflow.py | 6 +++---
 doctr/models/recognition/crnn/tensorflow.py | 8 ++++----
 doctr/models/recognition/master/tensorflow.py | 2 +-
 doctr/models/recognition/parseq/tensorflow.py | 2 +-
 doctr/models/recognition/sar/tensorflow.py | 6 ++----
 doctr/models/recognition/vitstr/tensorflow.py | 4 ++--
 doctr/models/utils/tensorflow.py | 2 +-
 pyproject.toml | 2 --
 references/classification/latency_tensorflow.py | 4 ----
 .../classification/train_tensorflow_character.py | 4 ----
 .../train_tensorflow_orientation.py | 4 ----
 references/detection/evaluate_tensorflow.py | 4 ----
 references/detection/latency_tensorflow.py | 4 ----
 references/detection/train_tensorflow.py | 4 ----
 references/recognition/evaluate_tensorflow.py | 4 ----
 references/recognition/latency_tensorflow.py | 4 ----
 references/recognition/train_tensorflow.py | 4 ----
 tests/tensorflow/test_models_factory.py | 9 +++------
 27 files changed, 40 insertions(+), 98 deletions(-)

diff --git a/doctr/file_utils.py b/doctr/file_utils.py
index fc1129b0c..a61498c5e 100644
--- a/doctr/file_utils.py
+++ b/doctr/file_utils.py
@@ -35,20 +35,6 @@
     logging.info("Disabling PyTorch because USE_TF is set")
     _torch_available = False

-# Compatibility fix to make sure tensorflow.keras stays at Keras 2
-if "TF_USE_LEGACY_KERAS" not in os.environ:
-    os.environ["TF_USE_LEGACY_KERAS"] = "1"
-
-elif os.environ["TF_USE_LEGACY_KERAS"] != "1":
-    raise ValueError(
-        "docTR is only compatible with Keras 2, but you have explicitly set `TF_USE_LEGACY_KERAS` to `0`. "
-    )
-
-
-def ensure_keras_v2() -> None:  # pragma: no cover
-    if not os.environ.get("TF_USE_LEGACY_KERAS") == "1":
-        os.environ["TF_USE_LEGACY_KERAS"] = "1"
-

 if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
     _tf_available = importlib.util.find_spec("tensorflow") is not None
@@ -79,7 +65,6 @@ def ensure_keras_v2() -> None:  # pragma: no cover
         _tf_available = False
     else:
         logging.info(f"TensorFlow version {_tf_version} available.")
-        ensure_keras_v2()
     import tensorflow as tf

     # Enable eager execution - this is required for some models to work properly
diff --git a/doctr/models/classification/magc_resnet/tensorflow.py b/doctr/models/classification/magc_resnet/tensorflow.py
index d920ca44a..0aecdff3c 100644
--- a/doctr/models/classification/magc_resnet/tensorflow.py
+++ b/doctr/models/classification/magc_resnet/tensorflow.py
@@ -26,7 +26,7 @@
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/magc_resnet31-16aa7d71.weights.h5&src=0",
+        "url": None,
     },
 }

diff --git a/doctr/models/classification/mobilenet/tensorflow.py b/doctr/models/classification/mobilenet/tensorflow.py
index ae3535d94..dec338afe 100644
--- a/doctr/models/classification/mobilenet/tensorflow.py
+++ b/doctr/models/classification/mobilenet/tensorflow.py
@@ -32,42 +32,42 @@
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_large-d857506e.weights.h5&src=0",
+        "url": None,
     },
     "mobilenet_v3_large_r": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_large_r-eef2e3c6.weights.h5&src=0",
+        "url": None,
     },
     "mobilenet_v3_small": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small-3fcebad7.weights.h5&src=0",
+        "url": None,
     },
     "mobilenet_v3_small_r": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_r-dd50218d.weights.h5&src=0",
+        "url": None,
     },
     "mobilenet_v3_small_crop_orientation": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (128, 128, 3),
         "classes": [0, -90, 180, 90],
-        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_crop_orientation-ef019b6b.weights.h5&src=0",
+        "url": None,
     },
     "mobilenet_v3_small_page_orientation": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (512, 512, 3),
         "classes": [0, -90, 180, 90],
-        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_page_orientation-0071d55d.weights.h5&src=0",
+        "url": None,
     },
 }

diff --git a/doctr/models/classification/resnet/tensorflow.py b/doctr/models/classification/resnet/tensorflow.py
index 662a43c3a..21d75a13e 100644
--- a/doctr/models/classification/resnet/tensorflow.py
+++ b/doctr/models/classification/resnet/tensorflow.py
@@ -24,35 +24,35 @@
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 32, 3),
         "classes": list(VOCABS["french"]),
-        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet18-f42d3854.weights.h5&src=0",
+        "url": None,
     },
     "resnet31": {
         "mean": (0.694,
0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet31-ab75f78c.weights.h5&src=0", + "url": None, }, "resnet34": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34-03967df9.weights.h5&src=0", + "url": None, }, "resnet50": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet50-82358f34.weights.h5&src=0", + "url": None, }, "resnet34_wide": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34_wide-b18fdf79.weights.h5&src=0", + "url": None, }, } diff --git a/doctr/models/classification/textnet/tensorflow.py b/doctr/models/classification/textnet/tensorflow.py index e5e6105a7..df6e60bad 100644 --- a/doctr/models/classification/textnet/tensorflow.py +++ b/doctr/models/classification/textnet/tensorflow.py @@ -22,21 +22,21 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_tiny-a29eeb4a.weights.h5&src=0", + "url": None, }, "textnet_small": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_small-1c2df0e3.weights.h5&src=0", + "url": None, }, "textnet_base": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_base-8b4b89bc.weights.h5&src=0", + "url": None, }, } diff --git a/doctr/models/classification/vgg/tensorflow.py b/doctr/models/classification/vgg/tensorflow.py index c42e369bc..22d57272b 100644 --- a/doctr/models/classification/vgg/tensorflow.py +++ b/doctr/models/classification/vgg/tensorflow.py @@ -22,7 +22,7 @@ "std": (1.0, 1.0, 1.0), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vgg16_bn_r-b4d69212.weights.h5&src=0", + "url": None, }, } diff --git a/doctr/models/classification/vit/tensorflow.py b/doctr/models/classification/vit/tensorflow.py index 386065bca..26e5a64db 100644 --- a/doctr/models/classification/vit/tensorflow.py +++ b/doctr/models/classification/vit/tensorflow.py @@ -25,14 +25,14 @@ "std": (0.299, 0.296, 0.301), "input_shape": (3, 32, 32), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vit_s-69bc459e.weights.h5&src=0", + "url": None, }, "vit_b": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vit_b-c64705bd.weights.h5&src=0", + "url": None, }, } diff --git a/doctr/models/detection/differentiable_binarization/tensorflow.py b/doctr/models/detection/differentiable_binarization/tensorflow.py index b0ca1f08e..805ee0bf5 100644 --- a/doctr/models/detection/differentiable_binarization/tensorflow.py +++ b/doctr/models/detection/differentiable_binarization/tensorflow.py @@ -34,13 +34,13 @@ "mean": (0.798, 0.785, 
0.772), "std": (0.264, 0.2749, 0.287), "input_shape": (1024, 1024, 3), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/db_resnet50-649fa22b.weights.h5&src=0", + "url": None, }, "db_mobilenet_v3_large": { "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), "input_shape": (1024, 1024, 3), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/db_mobilenet_v3_large-ee2e1dbe.weights.h5&src=0", + "url": None, }, } diff --git a/doctr/models/detection/fast/tensorflow.py b/doctr/models/detection/fast/tensorflow.py index b0043494e..607b8d25b 100644 --- a/doctr/models/detection/fast/tensorflow.py +++ b/doctr/models/detection/fast/tensorflow.py @@ -28,19 +28,19 @@ "input_shape": (1024, 1024, 3), "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_tiny-d7379d7b.weights.h5&src=0", + "url": None, }, "fast_small": { "input_shape": (1024, 1024, 3), "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_small-44b27eb6.weights.h5&src=0", + "url": None, }, "fast_base": { "input_shape": (1024, 1024, 3), "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_base-f2c6c736.weights.h5&src=0", + "url": None, }, } diff --git a/doctr/models/detection/linknet/tensorflow.py b/doctr/models/detection/linknet/tensorflow.py index 9c991c6f4..667508ef9 100644 --- a/doctr/models/detection/linknet/tensorflow.py +++ b/doctr/models/detection/linknet/tensorflow.py @@ -32,19 +32,19 @@ "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), "input_shape": (1024, 1024, 3), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet18-615a82c5.weights.h5&src=0", + "url": None, }, "linknet_resnet34": { "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), "input_shape": (1024, 1024, 3), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet34-9d772be5.weights.h5&src=0", + "url": None, }, "linknet_resnet50": { "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), "input_shape": (1024, 1024, 3), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet50-6bf6c8b5.weights.h5&src=0", + "url": None, }, } diff --git a/doctr/models/recognition/crnn/tensorflow.py b/doctr/models/recognition/crnn/tensorflow.py index 9f7488267..c73e5bf36 100644 --- a/doctr/models/recognition/crnn/tensorflow.py +++ b/doctr/models/recognition/crnn/tensorflow.py @@ -23,22 +23,22 @@ "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), - "vocab": VOCABS["legacy_french"], - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_vgg16_bn-9c188f45.weights.h5&src=0", + "vocab": VOCABS["french"], + "url": None, }, "crnn_mobilenet_v3_small": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_mobilenet_v3_small-54850265.weights.h5&src=0", + "url": None, }, "crnn_mobilenet_v3_large": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_mobilenet_v3_large-c64045e5.weights.h5&src=0", + "url": None, }, } diff --git a/doctr/models/recognition/master/tensorflow.py b/doctr/models/recognition/master/tensorflow.py index e01c08901..8574a6f6d 100644 --- 
a/doctr/models/recognition/master/tensorflow.py +++ b/doctr/models/recognition/master/tensorflow.py @@ -25,7 +25,7 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/master-d7fdaeff.weights.h5&src=0", + "url": None, }, } diff --git a/doctr/models/recognition/parseq/tensorflow.py b/doctr/models/recognition/parseq/tensorflow.py index d8c54527b..80bcb8550 100644 --- a/doctr/models/recognition/parseq/tensorflow.py +++ b/doctr/models/recognition/parseq/tensorflow.py @@ -27,7 +27,7 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/parseq-4152a87e.weights.h5&src=0", + "url": None, }, } diff --git a/doctr/models/recognition/sar/tensorflow.py b/doctr/models/recognition/sar/tensorflow.py index bcb0b207e..1e3c6f7a2 100644 --- a/doctr/models/recognition/sar/tensorflow.py +++ b/doctr/models/recognition/sar/tensorflow.py @@ -24,7 +24,7 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/sar_resnet31-5a58806c.weights.h5&src=0", + "url": None, }, } @@ -170,9 +170,7 @@ def call( for t in range(self.max_length + 1): # 32 if t == 0: # step to init the first states of the LSTMCell - states = self.lstm_cells.get_initial_state( - inputs=None, batch_size=features.shape[0], dtype=features.dtype - ) + states = self.lstm_cells.get_initial_state(batch_size=features.shape[0]) prev_symbol = holistic elif t == 1: # step to init a 'blank' sequence of length vocab_size + 1 filled with zeros diff --git a/doctr/models/recognition/vitstr/tensorflow.py b/doctr/models/recognition/vitstr/tensorflow.py index 9b121171f..222f056d8 100644 --- a/doctr/models/recognition/vitstr/tensorflow.py +++ b/doctr/models/recognition/vitstr/tensorflow.py @@ -23,14 +23,14 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vitstr_small-d28b8d92.weights.h5&src=0", + "url": None, }, "vitstr_base": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vitstr_base-9ad6eb84.weights.h5&src=0", + "url": None, }, } diff --git a/doctr/models/utils/tensorflow.py b/doctr/models/utils/tensorflow.py index c04a4b289..6c2137cd4 100644 --- a/doctr/models/utils/tensorflow.py +++ b/doctr/models/utils/tensorflow.py @@ -131,7 +131,7 @@ class IntermediateLayerGetter(Model): """ def __init__(self, model: Model, layer_names: List[str]) -> None: - intermediate_fmaps = [model.get_layer(layer_name).get_output_at(0) for layer_name in layer_names] + intermediate_fmaps = [model.get_layer(layer_name)._inbound_nodes[0].outputs[0] for layer_name in layer_names] super().__init__(model.input, outputs=intermediate_fmaps) def __repr__(self) -> str: diff --git a/pyproject.toml b/pyproject.toml index 613eb512e..9836b3e17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,6 @@ dependencies = [ tf = [ # cf. https://github.com/mindee/doctr/pull/1461 "tensorflow>=2.15.0,<3.0.0", - "tf-keras>=2.15.0,<3.0.0", # Keep keras 2 compatibility "tf2onnx>=1.16.0,<2.0.0", # cf. https://github.com/onnx/tensorflow-onnx/releases/tag/v1.16.0 ] torch = [ @@ -98,7 +97,6 @@ dev = [ # Tensorflow # cf. 
https://github.com/mindee/doctr/pull/1461 "tensorflow>=2.15.0,<3.0.0", - "tf-keras>=2.15.0,<3.0.0", # Keep keras 2 compatibility "tf2onnx>=1.16.0,<2.0.0", # cf. https://github.com/onnx/tensorflow-onnx/releases/tag/v1.16.0 # PyTorch "torch>=2.0.0,<3.0.0", diff --git a/references/classification/latency_tensorflow.py b/references/classification/latency_tensorflow.py index 9ed7e1603..6b08cab7a 100644 --- a/references/classification/latency_tensorflow.py +++ b/references/classification/latency_tensorflow.py @@ -9,10 +9,6 @@ import os import time -from doctr.file_utils import ensure_keras_v2 - -ensure_keras_v2() - os.environ["USE_TF"] = "1" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" diff --git a/references/classification/train_tensorflow_character.py b/references/classification/train_tensorflow_character.py index 0b1b648d9..c6fe31213 100644 --- a/references/classification/train_tensorflow_character.py +++ b/references/classification/train_tensorflow_character.py @@ -5,10 +5,6 @@ import os -from doctr.file_utils import ensure_keras_v2 - -ensure_keras_v2() - os.environ["USE_TF"] = "1" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" diff --git a/references/classification/train_tensorflow_orientation.py b/references/classification/train_tensorflow_orientation.py index 297a5674f..6d2e91dbf 100644 --- a/references/classification/train_tensorflow_orientation.py +++ b/references/classification/train_tensorflow_orientation.py @@ -5,10 +5,6 @@ import os -from doctr.file_utils import ensure_keras_v2 - -ensure_keras_v2() - os.environ["USE_TF"] = "1" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" diff --git a/references/detection/evaluate_tensorflow.py b/references/detection/evaluate_tensorflow.py index a2c5bbe49..c0b2e8b2c 100644 --- a/references/detection/evaluate_tensorflow.py +++ b/references/detection/evaluate_tensorflow.py @@ -5,10 +5,6 @@ import os -from doctr.file_utils import ensure_keras_v2 - -ensure_keras_v2() - from doctr.file_utils import CLASS_NAME os.environ["USE_TF"] = "1" diff --git a/references/detection/latency_tensorflow.py b/references/detection/latency_tensorflow.py index 17cdf784a..35c3479e7 100644 --- a/references/detection/latency_tensorflow.py +++ b/references/detection/latency_tensorflow.py @@ -9,10 +9,6 @@ import os import time -from doctr.file_utils import ensure_keras_v2 - -ensure_keras_v2() - os.environ["USE_TF"] = "1" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" diff --git a/references/detection/train_tensorflow.py b/references/detection/train_tensorflow.py index f054879e8..ba2fd3903 100644 --- a/references/detection/train_tensorflow.py +++ b/references/detection/train_tensorflow.py @@ -5,10 +5,6 @@ import os -from doctr.file_utils import ensure_keras_v2 - -ensure_keras_v2() - os.environ["USE_TF"] = "1" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" diff --git a/references/recognition/evaluate_tensorflow.py b/references/recognition/evaluate_tensorflow.py index b6ca50b51..805dafa53 100644 --- a/references/recognition/evaluate_tensorflow.py +++ b/references/recognition/evaluate_tensorflow.py @@ -5,10 +5,6 @@ import os -from doctr.file_utils import ensure_keras_v2 - -ensure_keras_v2() - os.environ["USE_TF"] = "1" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" diff --git a/references/recognition/latency_tensorflow.py b/references/recognition/latency_tensorflow.py index 26bc2d6bc..6817f7632 100644 --- a/references/recognition/latency_tensorflow.py +++ b/references/recognition/latency_tensorflow.py @@ -9,10 +9,6 @@ import os import time -from doctr.file_utils import ensure_keras_v2 - -ensure_keras_v2() - 
os.environ["USE_TF"] = "1" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" diff --git a/references/recognition/train_tensorflow.py b/references/recognition/train_tensorflow.py index 348f3a386..070bc1de3 100644 --- a/references/recognition/train_tensorflow.py +++ b/references/recognition/train_tensorflow.py @@ -5,10 +5,6 @@ import os -from doctr.file_utils import ensure_keras_v2 - -ensure_keras_v2() - os.environ["USE_TF"] = "1" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" diff --git a/tests/tensorflow/test_models_factory.py b/tests/tensorflow/test_models_factory.py index a4483800c..5aa0142e6 100644 --- a/tests/tensorflow/test_models_factory.py +++ b/tests/tensorflow/test_models_factory.py @@ -1,12 +1,7 @@ -import json -import os -import tempfile - import pytest -import tensorflow as tf from doctr import models -from doctr.models.factory import _save_model_and_config_for_hf_hub, from_hub, push_to_hf_hub +from doctr.models.factory import push_to_hf_hub def test_push_to_hf_hub(): @@ -22,6 +17,7 @@ def test_push_to_hf_hub(): push_to_hf_hub(model, model_name="test", task="detection", arch="crnn_mobilenet_v3_large") +""" @pytest.mark.parametrize( "arch_name, task_name, dummy_model_id", [ @@ -68,3 +64,4 @@ def test_models_for_hub(arch_name, task_name, dummy_model_id, tmpdir): tf.keras.backend.clear_session() hub_model = from_hub(repo_id=dummy_model_id) assert isinstance(hub_model, type(model)) +""" From 9dca02f18bb9caac104f009f7b0aee2f8b792a87 Mon Sep 17 00:00:00 2001 From: felix Date: Thu, 24 Oct 2024 14:54:28 +0200 Subject: [PATCH 2/3] update --- .../source/using_doctr/using_model_export.rst | 2 +- doctr/io/image/tensorflow.py | 2 +- .../classification/magc_resnet/tensorflow.py | 8 +++----- .../classification/mobilenet/tensorflow.py | 18 ++++++++--------- .../classification/predictor/tensorflow.py | 2 +- .../classification/resnet/tensorflow.py | 20 ++++++++----------- .../classification/textnet/tensorflow.py | 11 +++++----- doctr/models/classification/vgg/tensorflow.py | 8 +++----- doctr/models/classification/vit/tensorflow.py | 9 ++++----- .../differentiable_binarization/tensorflow.py | 11 ++++------ doctr/models/detection/fast/tensorflow.py | 5 ++--- doctr/models/detection/linknet/tensorflow.py | 6 ++---- .../models/detection/predictor/tensorflow.py | 2 +- doctr/models/modules/layers/tensorflow.py | 2 +- .../models/modules/transformer/tensorflow.py | 2 +- .../modules/vision_transformer/tensorflow.py | 2 +- doctr/models/recognition/crnn/tensorflow.py | 13 ++++++------ doctr/models/recognition/master/tensorflow.py | 7 +++---- doctr/models/recognition/parseq/tensorflow.py | 7 +++---- doctr/models/recognition/sar/tensorflow.py | 8 ++++---- doctr/models/recognition/vitstr/tensorflow.py | 9 ++++----- doctr/models/utils/tensorflow.py | 19 ++++-------------- .../train_tensorflow_character.py | 2 +- .../train_tensorflow_orientation.py | 2 +- references/detection/evaluate_tensorflow.py | 2 +- references/detection/train_tensorflow.py | 2 +- references/recognition/evaluate_tensorflow.py | 2 +- references/recognition/train_tensorflow.py | 2 +- tests/tensorflow/test_models_detection_tf.py | 11 +++++----- tests/tensorflow/test_models_factory.py | 4 ++-- .../tensorflow/test_models_recognition_tf.py | 5 +++-- tests/tensorflow/test_models_utils_tf.py | 4 ++-- 32 files changed, 89 insertions(+), 120 deletions(-) diff --git a/docs/source/using_doctr/using_model_export.rst b/docs/source/using_doctr/using_model_export.rst index c62c36169..48f570f69 100644 --- a/docs/source/using_doctr/using_model_export.rst +++ 
b/docs/source/using_doctr/using_model_export.rst @@ -31,7 +31,7 @@ Advantages: .. code:: python3 import tensorflow as tf - from tensorflow.keras import mixed_precision + from keras import mixed_precision mixed_precision.set_global_policy('mixed_float16') predictor = ocr_predictor(reco_arch="crnn_mobilenet_v3_small", det_arch="linknet_resnet34", pretrained=True) diff --git a/doctr/io/image/tensorflow.py b/doctr/io/image/tensorflow.py index 28fb2fadd..3b1f1ed0e 100644 --- a/doctr/io/image/tensorflow.py +++ b/doctr/io/image/tensorflow.py @@ -7,8 +7,8 @@ import numpy as np import tensorflow as tf +from keras.utils import img_to_array from PIL import Image -from tensorflow.keras.utils import img_to_array from doctr.utils.common_types import AbstractPath diff --git a/doctr/models/classification/magc_resnet/tensorflow.py b/doctr/models/classification/magc_resnet/tensorflow.py index 0aecdff3c..8732ec09a 100644 --- a/doctr/models/classification/magc_resnet/tensorflow.py +++ b/doctr/models/classification/magc_resnet/tensorflow.py @@ -9,12 +9,11 @@ from typing import Any, Dict, List, Optional, Tuple import tensorflow as tf -from tensorflow.keras import activations, layers -from tensorflow.keras.models import Sequential +from keras import Sequential, activations, layers from doctr.datasets import VOCABS -from ...utils import _build_model, load_pretrained_params +from ...utils import load_pretrained_params from ..resnet.tensorflow import ResNet __all__ = ["magc_resnet31"] @@ -26,7 +25,7 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": None, + "url": "https://github.com/mindee/doctr/releases/download/v0.10.0/magc_resnet31-6c266055.weights.h5", }, } @@ -152,7 +151,6 @@ def _magc_resnet( cfg=_cfg, **kwargs, ) - _build_model(model) # Load pretrained parameters if pretrained: diff --git a/doctr/models/classification/mobilenet/tensorflow.py b/doctr/models/classification/mobilenet/tensorflow.py index dec338afe..6aece9d1f 100644 --- a/doctr/models/classification/mobilenet/tensorflow.py +++ b/doctr/models/classification/mobilenet/tensorflow.py @@ -9,11 +9,10 @@ from typing import Any, Dict, List, Optional, Tuple, Union import tensorflow as tf -from tensorflow.keras import layers -from tensorflow.keras.models import Sequential +from keras import Sequential, layers from ....datasets import VOCABS -from ...utils import _build_model, conv_sequence, load_pretrained_params +from ...utils import conv_sequence, load_pretrained_params __all__ = [ "MobileNetV3", @@ -32,42 +31,42 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": None, + "url": "https://github.com/mindee/doctr/releases/download/v0.10.0/mobilenet_v3_large-d857506e.weights.h5", }, "mobilenet_v3_large_r": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": None, + "url": "https://github.com/mindee/doctr/releases/download/v0.10.0/mobilenet_v3_large_r-eef2e3c6.weights.h5", }, "mobilenet_v3_small": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": None, + "url": "https://github.com/mindee/doctr/releases/download/v0.10.0/mobilenet_v3_small-3fcebad7.weights.h5", }, "mobilenet_v3_small_r": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": None, + "url": 
"https://github.com/mindee/doctr/releases/download/v0.10.0/mobilenet_v3_small_r-dd50218d.weights.h5", }, "mobilenet_v3_small_crop_orientation": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (128, 128, 3), "classes": [0, -90, 180, 90], - "url": None, + "url": "https://github.com/mindee/doctr/releases/download/v0.10.0/mobilenet_v3_small_crop_orientation-ef019b6b.weights.h5", }, "mobilenet_v3_small_page_orientation": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (512, 512, 3), "classes": [0, -90, 180, 90], - "url": None, + "url": "https://github.com/mindee/doctr/releases/download/v0.10.0/mobilenet_v3_small_page_orientation-0071d55d.weights.h5", }, } @@ -295,7 +294,6 @@ def _mobilenet_v3(arch: str, pretrained: bool, rect_strides: bool = False, **kwa cfg=_cfg, **kwargs, ) - _build_model(model) # Load pretrained parameters if pretrained: diff --git a/doctr/models/classification/predictor/tensorflow.py b/doctr/models/classification/predictor/tensorflow.py index 23efbf657..ba26e1db5 100644 --- a/doctr/models/classification/predictor/tensorflow.py +++ b/doctr/models/classification/predictor/tensorflow.py @@ -7,7 +7,7 @@ import numpy as np import tensorflow as tf -from tensorflow.keras import Model +from keras import Model from doctr.models.preprocessor import PreProcessor from doctr.utils.repr import NestedObject diff --git a/doctr/models/classification/resnet/tensorflow.py b/doctr/models/classification/resnet/tensorflow.py index 21d75a13e..ae0188cb4 100644 --- a/doctr/models/classification/resnet/tensorflow.py +++ b/doctr/models/classification/resnet/tensorflow.py @@ -7,13 +7,11 @@ from typing import Any, Callable, Dict, List, Optional, Tuple import tensorflow as tf -from tensorflow.keras import layers -from tensorflow.keras.applications import ResNet50 -from tensorflow.keras.models import Sequential +from keras import Sequential, applications, layers from doctr.datasets import VOCABS -from ...utils import _build_model, conv_sequence, load_pretrained_params +from ...utils import conv_sequence, load_pretrained_params __all__ = ["ResNet", "resnet18", "resnet31", "resnet34", "resnet50", "resnet34_wide"] @@ -24,35 +22,35 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": None, + "url": "https://github.com/mindee/doctr/releases/download/v0.10.0/resnet18-4138682e.weights.h5", }, "resnet31": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": None, + "url": "https://github.com/mindee/doctr/releases/download/v0.10.0/resnet31-61808f41.weights.h5", }, "resnet34": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": None, + "url": "https://github.com/mindee/doctr/releases/download/v0.10.0/resnet34-2288ee52.weights.h5", }, "resnet50": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": None, + "url": "https://github.com/mindee/doctr/releases/download/v0.10.0/resnet50-82358f34.weights.h5", }, "resnet34_wide": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": None, + "url": "https://github.com/mindee/doctr/releases/download/v0.10.0/resnet34_wide-4c788e90.weights.h5", }, } @@ -210,7 +208,6 @@ def _resnet( model = ResNet( num_blocks, output_channels, stage_downsample, 
stage_conv, stage_pooling, origin_stem, cfg=_cfg, **kwargs ) - _build_model(model) # Load pretrained parameters if pretrained: @@ -350,7 +347,7 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet: _cfg["input_shape"] = kwargs["input_shape"] kwargs.pop("classes") - model = ResNet50( + model = applications.ResNet50( weights=None, include_top=True, pooling=True, @@ -360,7 +357,6 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet: ) model.cfg = _cfg - _build_model(model) # Load pretrained parameters if pretrained: diff --git a/doctr/models/classification/textnet/tensorflow.py b/doctr/models/classification/textnet/tensorflow.py index df6e60bad..b7d900647 100644 --- a/doctr/models/classification/textnet/tensorflow.py +++ b/doctr/models/classification/textnet/tensorflow.py @@ -7,12 +7,12 @@ from copy import deepcopy from typing import Any, Dict, List, Optional, Tuple -from tensorflow.keras import Sequential, layers +from keras import Sequential, layers from doctr.datasets import VOCABS from ...modules.layers.tensorflow import FASTConvLayer -from ...utils import _build_model, conv_sequence, load_pretrained_params +from ...utils import conv_sequence, load_pretrained_params __all__ = ["textnet_tiny", "textnet_small", "textnet_base"] @@ -22,21 +22,21 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": None, + "url": "https://github.com/mindee/doctr/releases/download/v0.10.0/textnet_tiny-99fb9158.weights.h5", }, "textnet_small": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": None, + "url": "https://github.com/mindee/doctr/releases/download/v0.10.0/textnet_small-44072f65.weights.h5", }, "textnet_base": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": None, + "url": "https://github.com/mindee/doctr/releases/download/v0.10.0/textnet_base-a92df1c0.weights.h5", }, } @@ -111,7 +111,6 @@ def _textnet( # Build the model model = TextNet(cfg=_cfg, **kwargs) - _build_model(model) # Load pretrained parameters if pretrained: diff --git a/doctr/models/classification/vgg/tensorflow.py b/doctr/models/classification/vgg/tensorflow.py index 22d57272b..d0e2af404 100644 --- a/doctr/models/classification/vgg/tensorflow.py +++ b/doctr/models/classification/vgg/tensorflow.py @@ -6,12 +6,11 @@ from copy import deepcopy from typing import Any, Dict, List, Optional, Tuple -from tensorflow.keras import layers -from tensorflow.keras.models import Sequential +from keras import Sequential, layers from doctr.datasets import VOCABS -from ...utils import _build_model, conv_sequence, load_pretrained_params +from ...utils import conv_sequence, load_pretrained_params __all__ = ["VGG", "vgg16_bn_r"] @@ -22,7 +21,7 @@ "std": (1.0, 1.0, 1.0), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": None, + "url": "https://github.com/mindee/doctr/releases/download/v0.10.0/vgg16_bn_r-b4d69212.weights.h5", }, } @@ -81,7 +80,6 @@ def _vgg( # Build the model model = VGG(num_blocks, planes, rect_pools, cfg=_cfg, **kwargs) - _build_model(model) # Load pretrained parameters if pretrained: diff --git a/doctr/models/classification/vit/tensorflow.py b/doctr/models/classification/vit/tensorflow.py index 26e5a64db..1ddf76572 100644 --- a/doctr/models/classification/vit/tensorflow.py +++ b/doctr/models/classification/vit/tensorflow.py @@ -7,14 +7,14 @@ from typing import Any, Dict, 
Optional, Tuple import tensorflow as tf -from tensorflow.keras import Sequential, layers +from keras import Sequential, layers from doctr.datasets import VOCABS from doctr.models.modules.transformer import EncoderBlock from doctr.models.modules.vision_transformer.tensorflow import PatchEmbedding from doctr.utils.repr import NestedObject -from ...utils import _build_model, load_pretrained_params +from ...utils import load_pretrained_params __all__ = ["vit_s", "vit_b"] @@ -25,14 +25,14 @@ "std": (0.299, 0.296, 0.301), "input_shape": (3, 32, 32), "classes": list(VOCABS["french"]), - "url": None, + "url": "https://github.com/mindee/doctr/releases/download/v0.10.0/vit_s-d68b3d5b.weights.h5", }, "vit_b": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": None, + "url": "https://github.com/mindee/doctr/releases/download/v0.10.0/vit_b-f01181f0.weights.h5", }, } @@ -121,7 +121,6 @@ def _vit( # Build the model model = VisionTransformer(cfg=_cfg, **kwargs) - _build_model(model) # Load pretrained parameters if pretrained: diff --git a/doctr/models/detection/differentiable_binarization/tensorflow.py b/doctr/models/detection/differentiable_binarization/tensorflow.py index 805ee0bf5..a4c4f342c 100644 --- a/doctr/models/detection/differentiable_binarization/tensorflow.py +++ b/doctr/models/detection/differentiable_binarization/tensorflow.py @@ -10,14 +10,12 @@ import numpy as np import tensorflow as tf -from tensorflow.keras import Model, Sequential, layers, losses -from tensorflow.keras.applications import ResNet50 +from keras import Model, Sequential, applications, layers, losses from doctr.file_utils import CLASS_NAME from doctr.models.utils import ( IntermediateLayerGetter, _bf16_to_float32, - _build_model, conv_sequence, load_pretrained_params, ) @@ -34,7 +32,7 @@ "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), "input_shape": (1024, 1024, 3), - "url": None, + "url": "https://github.com/mindee/doctr/releases/download/v0.10.0/db_resnet50-fe92475b.weights.h5", }, "db_mobilenet_v3_large": { "mean": (0.798, 0.785, 0.772), @@ -310,7 +308,6 @@ def _db_resnet( # Build the model model = DBNet(feat_extractor, cfg=_cfg, **kwargs) - _build_model(model) # Load pretrained parameters if pretrained: @@ -355,7 +352,7 @@ def _db_mobilenet( # Build the model model = DBNet(feat_extractor, cfg=_cfg, **kwargs) - _build_model(model) + # Load pretrained parameters if pretrained: # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning @@ -390,7 +387,7 @@ def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet: return _db_resnet( "db_resnet50", pretrained, - ResNet50, + applications.ResNet50, ["conv2_block3_out", "conv3_block4_out", "conv4_block6_out", "conv5_block3_out"], **kwargs, ) diff --git a/doctr/models/detection/fast/tensorflow.py b/doctr/models/detection/fast/tensorflow.py index 607b8d25b..417244841 100644 --- a/doctr/models/detection/fast/tensorflow.py +++ b/doctr/models/detection/fast/tensorflow.py @@ -10,10 +10,10 @@ import numpy as np import tensorflow as tf -from tensorflow.keras import Model, Sequential, layers +from keras import Model, Sequential, layers from doctr.file_utils import CLASS_NAME -from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, _build_model, load_pretrained_params +from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, load_pretrained_params from doctr.utils.repr import NestedObject from 
...classification import textnet_base, textnet_small, textnet_tiny @@ -333,7 +333,6 @@ def _fast( # Build the model model = FAST(feat_extractor, cfg=_cfg, **kwargs) - _build_model(model) # Load pretrained parameters if pretrained: diff --git a/doctr/models/detection/linknet/tensorflow.py b/doctr/models/detection/linknet/tensorflow.py index 667508ef9..144c25be1 100644 --- a/doctr/models/detection/linknet/tensorflow.py +++ b/doctr/models/detection/linknet/tensorflow.py @@ -10,14 +10,13 @@ import numpy as np import tensorflow as tf -from tensorflow.keras import Model, Sequential, layers, losses +from keras import Model, Sequential, layers, losses from doctr.file_utils import CLASS_NAME from doctr.models.classification import resnet18, resnet34, resnet50 from doctr.models.utils import ( IntermediateLayerGetter, _bf16_to_float32, - _build_model, conv_sequence, load_pretrained_params, ) @@ -44,7 +43,7 @@ "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), "input_shape": (1024, 1024, 3), - "url": None, + "url": "https://github.com/mindee/doctr/releases/download/v0.10.0/linknet_resnet50-fdea2b5f.weights.h5", }, } @@ -280,7 +279,6 @@ def _linknet( # Build the model model = LinkNet(feat_extractor, cfg=_cfg, **kwargs) - _build_model(model) # Load pretrained parameters if pretrained: diff --git a/doctr/models/detection/predictor/tensorflow.py b/doctr/models/detection/predictor/tensorflow.py index a3d508584..a7ccd4a9a 100644 --- a/doctr/models/detection/predictor/tensorflow.py +++ b/doctr/models/detection/predictor/tensorflow.py @@ -7,7 +7,7 @@ import numpy as np import tensorflow as tf -from tensorflow.keras import Model +from keras import Model from doctr.models.detection._utils import _remove_padding from doctr.models.preprocessor import PreProcessor diff --git a/doctr/models/modules/layers/tensorflow.py b/doctr/models/modules/layers/tensorflow.py index 68849fbf6..b1019be77 100644 --- a/doctr/models/modules/layers/tensorflow.py +++ b/doctr/models/modules/layers/tensorflow.py @@ -7,7 +7,7 @@ import numpy as np import tensorflow as tf -from tensorflow.keras import layers +from keras import layers from doctr.utils.repr import NestedObject diff --git a/doctr/models/modules/transformer/tensorflow.py b/doctr/models/modules/transformer/tensorflow.py index 50c7cef04..3fc1d17f2 100644 --- a/doctr/models/modules/transformer/tensorflow.py +++ b/doctr/models/modules/transformer/tensorflow.py @@ -7,7 +7,7 @@ from typing import Any, Callable, Optional, Tuple import tensorflow as tf -from tensorflow.keras import layers +from keras import layers from doctr.utils.repr import NestedObject diff --git a/doctr/models/modules/vision_transformer/tensorflow.py b/doctr/models/modules/vision_transformer/tensorflow.py index 8386172eb..a73aa4c70 100644 --- a/doctr/models/modules/vision_transformer/tensorflow.py +++ b/doctr/models/modules/vision_transformer/tensorflow.py @@ -7,7 +7,7 @@ from typing import Any, Tuple import tensorflow as tf -from tensorflow.keras import layers +from keras import layers from doctr.utils.repr import NestedObject diff --git a/doctr/models/recognition/crnn/tensorflow.py b/doctr/models/recognition/crnn/tensorflow.py index c73e5bf36..3787bd014 100644 --- a/doctr/models/recognition/crnn/tensorflow.py +++ b/doctr/models/recognition/crnn/tensorflow.py @@ -7,13 +7,12 @@ from typing import Any, Dict, List, Optional, Tuple, Union import tensorflow as tf -from tensorflow.keras import layers -from tensorflow.keras.models import Model, Sequential +from keras import Model, Sequential, layers from 
doctr.datasets import VOCABS from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r -from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params +from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params from ..core import RecognitionModel, RecognitionPostProcessor __all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"] @@ -24,21 +23,21 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": None, + "url": "https://github.com/mindee/doctr/releases/download/v0.10.0/crnn_vgg16_bn-41bbe57b.weights.h5", }, "crnn_mobilenet_v3_small": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": None, + "url": "https://github.com/mindee/doctr/releases/download/v0.10.0/crnn_mobilenet_v3_small-b4bb2858.weights.h5", }, "crnn_mobilenet_v3_large": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": None, + "url": "https://github.com/mindee/doctr/releases/download/v0.10.0/crnn_mobilenet_v3_large-1eac49ae.weights.h5", }, } @@ -245,7 +244,7 @@ def _crnn( # Build the model model = CRNN(feat_extractor, cfg=_cfg, **kwargs) - _build_model(model) + # Load pretrained parameters if pretrained: # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning diff --git a/doctr/models/recognition/master/tensorflow.py b/doctr/models/recognition/master/tensorflow.py index 8574a6f6d..a7a2adebb 100644 --- a/doctr/models/recognition/master/tensorflow.py +++ b/doctr/models/recognition/master/tensorflow.py @@ -7,13 +7,13 @@ from typing import Any, Dict, List, Optional, Tuple import tensorflow as tf -from tensorflow.keras import Model, layers +from keras import Model, layers from doctr.datasets import VOCABS from doctr.models.classification import magc_resnet31 from doctr.models.modules.transformer import Decoder, PositionalEncoding -from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params +from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params from .base import _MASTER, _MASTERPostProcessor __all__ = ["MASTER", "master"] @@ -25,7 +25,7 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": None, + "url": "https://github.com/mindee/doctr/releases/download/v0.10.0/master-bdcf6f40.weights.h5", }, } @@ -290,7 +290,6 @@ def _master(arch: str, pretrained: bool, backbone_fn, pretrained_backbone: bool cfg=_cfg, **kwargs, ) - _build_model(model) # Load pretrained parameters if pretrained: diff --git a/doctr/models/recognition/parseq/tensorflow.py b/doctr/models/recognition/parseq/tensorflow.py index 80bcb8550..89ae314a8 100644 --- a/doctr/models/recognition/parseq/tensorflow.py +++ b/doctr/models/recognition/parseq/tensorflow.py @@ -10,13 +10,13 @@ import numpy as np import tensorflow as tf -from tensorflow.keras import Model, layers +from keras import Model, layers from doctr.datasets import VOCABS from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward from ...classification import vit_s -from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params +from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params from .base import _PARSeq, _PARSeqPostProcessor __all__ = ["PARSeq", "parseq"] @@ -27,7 +27,7 @@ "std": (0.299, 0.296, 0.301), "input_shape": 
(32, 128, 3), "vocab": VOCABS["french"], - "url": None, + "url": "https://github.com/mindee/doctr/releases/download/v0.10.0/parseq-3a3149e7.weights.h5", }, } @@ -473,7 +473,6 @@ def _parseq( # Build the model model = PARSeq(feat_extractor, cfg=_cfg, **kwargs) - _build_model(model) # Load pretrained parameters if pretrained: diff --git a/doctr/models/recognition/sar/tensorflow.py b/doctr/models/recognition/sar/tensorflow.py index 1e3c6f7a2..34303e080 100644 --- a/doctr/models/recognition/sar/tensorflow.py +++ b/doctr/models/recognition/sar/tensorflow.py @@ -7,13 +7,13 @@ from typing import Any, Dict, List, Optional, Tuple import tensorflow as tf -from tensorflow.keras import Model, Sequential, layers +from keras import Model, Sequential, layers from doctr.datasets import VOCABS from doctr.utils.repr import NestedObject from ...classification import resnet31 -from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params +from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params from ..core import RecognitionModel, RecognitionPostProcessor __all__ = ["SAR", "sar_resnet31"] @@ -24,7 +24,7 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": None, + "url": "https://github.com/mindee/doctr/releases/download/v0.10.0/sar_resnet31-861e9563.weights.h5", }, } @@ -390,7 +390,7 @@ def _sar( # Build the model model = SAR(feat_extractor, cfg=_cfg, **kwargs) - _build_model(model) + # Load pretrained parameters if pretrained: # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning diff --git a/doctr/models/recognition/vitstr/tensorflow.py b/doctr/models/recognition/vitstr/tensorflow.py index 222f056d8..0697dd97e 100644 --- a/doctr/models/recognition/vitstr/tensorflow.py +++ b/doctr/models/recognition/vitstr/tensorflow.py @@ -7,12 +7,12 @@ from typing import Any, Dict, List, Optional, Tuple import tensorflow as tf -from tensorflow.keras import Model, layers +from keras import Model, layers from doctr.datasets import VOCABS from ...classification import vit_b, vit_s -from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params +from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params from .base import _ViTSTR, _ViTSTRPostProcessor __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"] @@ -23,14 +23,14 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": None, + "url": "https://github.com/mindee/doctr/releases/download/v0.10.0/vitstr_small-c692a250.weights.h5", }, "vitstr_base": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": None, + "url": "https://github.com/mindee/doctr/releases/download/v0.10.0/vitstr_base-636fcfcf.weights.h5", }, } @@ -216,7 +216,6 @@ def _vitstr( # Build the model model = ViTSTR(feat_extractor, cfg=_cfg, **kwargs) - _build_model(model) # Load pretrained parameters if pretrained: diff --git a/doctr/models/utils/tensorflow.py b/doctr/models/utils/tensorflow.py index 6c2137cd4..fd6c3fbb6 100644 --- a/doctr/models/utils/tensorflow.py +++ b/doctr/models/utils/tensorflow.py @@ -8,7 +8,7 @@ import tensorflow as tf import tf2onnx -from tensorflow.keras import Model, layers +from keras import Model, layers from doctr.utils.data import download_from_url @@ -17,7 +17,6 @@ __all__ = [ "load_pretrained_params", - "_build_model", "conv_sequence", "IntermediateLayerGetter", "export_model_to_onnx", @@ -35,16 +34,6 
@@ def _bf16_to_float32(x: tf.Tensor) -> tf.Tensor: return tf.cast(x, tf.float32) if x.dtype == tf.bfloat16 else x -def _build_model(model: Model): - """Build a model by calling it once with dummy input - - Args: - ---- - model: the model to be built - """ - model(tf.zeros((1, *model.cfg["input_shape"])), training=False) - - def load_pretrained_params( model: Model, url: Optional[str] = None, @@ -83,7 +72,7 @@ def conv_sequence( ) -> List[layers.Layer]: """Builds a convolutional-based layer sequence - >>> from tensorflow.keras import Sequential + >>> from keras import Sequential >>> from doctr.models import conv_sequence >>> module = Sequential(conv_sequence(32, 'relu', True, kernel_size=3, input_shape=[224, 224, 3])) @@ -119,10 +108,10 @@ def conv_sequence( class IntermediateLayerGetter(Model): """Implements an intermediate layer getter - >>> from tensorflow.keras.applications import ResNet50 + >>> from keras import applications >>> from doctr.models import IntermediateLayerGetter >>> target_layers = ["conv2_block3_out", "conv3_block4_out", "conv4_block6_out", "conv5_block3_out"] - >>> feat_extractor = IntermediateLayerGetter(ResNet50(include_top=False, pooling=False), target_layers) + >>> feat_extractor = IntermediateLayerGetter(applications.ResNet50(include_top=False, pooling=False), target_layers) Args: ---- diff --git a/references/classification/train_tensorflow_character.py b/references/classification/train_tensorflow_character.py index c6fe31213..225cd83fc 100644 --- a/references/classification/train_tensorflow_character.py +++ b/references/classification/train_tensorflow_character.py @@ -13,7 +13,7 @@ import numpy as np import tensorflow as tf -from tensorflow.keras import Model, mixed_precision, optimizers +from keras import Model, mixed_precision, optimizers from tqdm.auto import tqdm from doctr.models import login_to_hub, push_to_hf_hub diff --git a/references/classification/train_tensorflow_orientation.py b/references/classification/train_tensorflow_orientation.py index 6d2e91dbf..e53de0ef3 100644 --- a/references/classification/train_tensorflow_orientation.py +++ b/references/classification/train_tensorflow_orientation.py @@ -13,7 +13,7 @@ import numpy as np import tensorflow as tf -from tensorflow.keras import Model, mixed_precision, optimizers +from keras import Model, mixed_precision, optimizers from tqdm.auto import tqdm from doctr.models import login_to_hub, push_to_hf_hub diff --git a/references/detection/evaluate_tensorflow.py b/references/detection/evaluate_tensorflow.py index c0b2e8b2c..882e9654c 100644 --- a/references/detection/evaluate_tensorflow.py +++ b/references/detection/evaluate_tensorflow.py @@ -14,7 +14,7 @@ from pathlib import Path import tensorflow as tf -from tensorflow.keras import mixed_precision +from keras import mixed_precision from tqdm import tqdm gpu_devices = tf.config.list_physical_devices("GPU") diff --git a/references/detection/train_tensorflow.py b/references/detection/train_tensorflow.py index ba2fd3903..073d617bf 100644 --- a/references/detection/train_tensorflow.py +++ b/references/detection/train_tensorflow.py @@ -14,7 +14,7 @@ import numpy as np import tensorflow as tf -from tensorflow.keras import Model, mixed_precision, optimizers +from keras import Model, mixed_precision, optimizers from tqdm.auto import tqdm from doctr.models import login_to_hub, push_to_hf_hub diff --git a/references/recognition/evaluate_tensorflow.py b/references/recognition/evaluate_tensorflow.py index 805dafa53..2a76dc090 100644 --- 
a/references/recognition/evaluate_tensorflow.py +++ b/references/recognition/evaluate_tensorflow.py @@ -11,7 +11,7 @@ import time import tensorflow as tf -from tensorflow.keras import mixed_precision +from keras import mixed_precision from tqdm import tqdm gpu_devices = tf.config.list_physical_devices("GPU") diff --git a/references/recognition/train_tensorflow.py b/references/recognition/train_tensorflow.py index 070bc1de3..347614361 100644 --- a/references/recognition/train_tensorflow.py +++ b/references/recognition/train_tensorflow.py @@ -15,7 +15,7 @@ import numpy as np import tensorflow as tf -from tensorflow.keras import Model, mixed_precision, optimizers +from keras import Model, mixed_precision, optimizers from tqdm.auto import tqdm from doctr.models import login_to_hub, push_to_hf_hub diff --git a/tests/tensorflow/test_models_detection_tf.py b/tests/tensorflow/test_models_detection_tf.py index 7dbb090bf..c43903c2b 100644 --- a/tests/tensorflow/test_models_detection_tf.py +++ b/tests/tensorflow/test_models_detection_tf.py @@ -2,6 +2,7 @@ import os import tempfile +import keras import numpy as np import onnxruntime import psutil @@ -37,13 +38,13 @@ ) def test_detection_models(arch_name, input_shape, output_size, out_prob, train_mode): batch_size = 2 - tf.keras.backend.clear_session() + keras.backend.clear_session() if arch_name == "fast_tiny_rep": model = reparameterize(detection.fast_tiny(pretrained=True, input_shape=input_shape)) train_mode = False # Reparameterized model is not trainable else: model = detection.__dict__[arch_name](pretrained=True, input_shape=input_shape) - assert isinstance(model, tf.keras.Model) + assert isinstance(model, keras.Model) input_tensor = tf.random.uniform(shape=[batch_size, *input_shape], minval=0, maxval=1) target = [ {CLASS_NAME: np.array([[0.5, 0.5, 1, 1], [0.5, 0.5, 0.8, 0.8]], dtype=np.float32)}, @@ -152,7 +153,7 @@ def test_rotated_detectionpredictor(mock_pdf): ) def test_detection_zoo(arch_name): # Model - tf.keras.backend.clear_session() + keras.backend.clear_session() predictor = detection.zoo.detection_predictor(arch_name, pretrained=False) # object check assert isinstance(predictor, DetectionPredictor) @@ -177,7 +178,7 @@ def test_fast_reparameterization(): base_model_params = np.sum([np.prod(v.shape) for v in base_model.trainable_variables]) assert math.isclose(base_model_params, 13535296) # base model params base_out = base_model(dummy_input, training=False)["logits"] - tf.keras.backend.clear_session() + keras.backend.clear_session() rep_model = reparameterize(base_model) rep_model_params = np.sum([np.prod(v.shape) for v in base_model.trainable_variables]) assert math.isclose(rep_model_params, 8520256) # reparameterized model params @@ -241,7 +242,7 @@ def test_dilate(): def test_models_onnx_export(arch_name, input_shape, output_size): # Model batch_size = 2 - tf.keras.backend.clear_session() + keras.backend.clear_session() if arch_name == "fast_tiny_rep": model = reparameterize(detection.fast_tiny(pretrained=True, exportable=True, input_shape=input_shape)) else: diff --git a/tests/tensorflow/test_models_factory.py b/tests/tensorflow/test_models_factory.py index 5aa0142e6..3f4a5c53e 100644 --- a/tests/tensorflow/test_models_factory.py +++ b/tests/tensorflow/test_models_factory.py @@ -46,7 +46,7 @@ def test_push_to_hf_hub(): ) def test_models_for_hub(arch_name, task_name, dummy_model_id, tmpdir): with tempfile.TemporaryDirectory() as tmp_dir: - tf.keras.backend.clear_session() + keras.backend.clear_session() model = 
models.__dict__[task_name].__dict__[arch_name](pretrained=True) _save_model_and_config_for_hf_hub(model, arch=arch_name, task=task_name, save_dir=tmp_dir) @@ -61,7 +61,7 @@ def test_models_for_hub(arch_name, task_name, dummy_model_id, tmpdir): assert all(key in model.cfg.keys() for key in tmp_config.keys()) # test from hub - tf.keras.backend.clear_session() + keras.backend.clear_session() hub_model = from_hub(repo_id=dummy_model_id) assert isinstance(hub_model, type(model)) """ diff --git a/tests/tensorflow/test_models_recognition_tf.py b/tests/tensorflow/test_models_recognition_tf.py index 162c446d3..7da1cb534 100644 --- a/tests/tensorflow/test_models_recognition_tf.py +++ b/tests/tensorflow/test_models_recognition_tf.py @@ -2,6 +2,7 @@ import shutil import tempfile +import keras import numpy as np import onnxruntime import psutil @@ -40,7 +41,7 @@ def test_recognition_models(arch_name, input_shape, train_mode, mock_vocab): batch_size = 4 reco_model = recognition.__dict__[arch_name](vocab=mock_vocab, pretrained=True, input_shape=input_shape) - assert isinstance(reco_model, tf.keras.Model) + assert isinstance(reco_model, keras.Model) input_tensor = tf.random.uniform(shape=[batch_size, *input_shape], minval=0, maxval=1) target = ["i", "am", "a", "jedi"] @@ -194,7 +195,7 @@ def test_recognition_zoo_error(): def test_models_onnx_export(arch_name, input_shape): # Model batch_size = 2 - tf.keras.backend.clear_session() + keras.backend.clear_session() model = recognition.__dict__[arch_name](pretrained=True, exportable=True, input_shape=input_shape) # SAR, MASTER, ViTSTR export currently only available with constant batch size if arch_name in ["sar_resnet31", "master", "vitstr_small", "parseq"]: diff --git a/tests/tensorflow/test_models_utils_tf.py b/tests/tensorflow/test_models_utils_tf.py index 4783a09b4..d547f1c72 100644 --- a/tests/tensorflow/test_models_utils_tf.py +++ b/tests/tensorflow/test_models_utils_tf.py @@ -2,7 +2,7 @@ import pytest import tensorflow as tf -from tensorflow.keras.applications import ResNet50 +from keras import applications from doctr.models.classification import mobilenet_v3_small from doctr.models.utils import ( @@ -49,7 +49,7 @@ def test_conv_sequence(): def test_intermediate_layer_getter(): - backbone = ResNet50(include_top=False, weights=None, pooling=None) + backbone = applications.ResNet50(include_top=False, weights=None, pooling=None) feat_extractor = IntermediateLayerGetter(backbone, ["conv2_block3_out", "conv3_block4_out"]) # Check num of output features input_tensor = tf.random.uniform(shape=[1, 224, 224, 3], minval=0, maxval=1) From f04c567823249e0c233c55b11f5174f3641730f2 Mon Sep 17 00:00:00 2001 From: felix Date: Fri, 25 Oct 2024 14:58:11 +0200 Subject: [PATCH 3/3] restore build before --- doctr/models/classification/magc_resnet/tensorflow.py | 3 ++- doctr/models/classification/mobilenet/tensorflow.py | 3 ++- doctr/models/classification/resnet/tensorflow.py | 4 +++- doctr/models/classification/textnet/tensorflow.py | 3 ++- doctr/models/classification/vgg/tensorflow.py | 3 ++- doctr/models/classification/vit/pytorch.py | 3 +-- doctr/models/classification/vit/tensorflow.py | 3 ++- .../differentiable_binarization/tensorflow.py | 3 +++ doctr/models/detection/fast/tensorflow.py | 3 ++- doctr/models/detection/linknet/tensorflow.py | 2 ++ doctr/models/recognition/crnn/tensorflow.py | 3 ++- doctr/models/recognition/master/tensorflow.py | 3 ++- doctr/models/recognition/parseq/tensorflow.py | 3 ++- doctr/models/recognition/sar/tensorflow.py | 3 ++- 
doctr/models/recognition/vitstr/tensorflow.py | 3 ++- doctr/models/utils/tensorflow.py | 10 ++++++++++ tests/tensorflow/test_models_classification_tf.py | 9 +++++---- 17 files changed, 46 insertions(+), 18 deletions(-) diff --git a/doctr/models/classification/magc_resnet/tensorflow.py b/doctr/models/classification/magc_resnet/tensorflow.py index 8732ec09a..51c212ba3 100644 --- a/doctr/models/classification/magc_resnet/tensorflow.py +++ b/doctr/models/classification/magc_resnet/tensorflow.py @@ -13,7 +13,7 @@ from doctr.datasets import VOCABS -from ...utils import load_pretrained_params +from ...utils import _build_model, load_pretrained_params from ..resnet.tensorflow import ResNet __all__ = ["magc_resnet31"] @@ -151,6 +151,7 @@ def _magc_resnet( cfg=_cfg, **kwargs, ) + _build_model(model) # Load pretrained parameters if pretrained: diff --git a/doctr/models/classification/mobilenet/tensorflow.py b/doctr/models/classification/mobilenet/tensorflow.py index 6aece9d1f..2e7d03510 100644 --- a/doctr/models/classification/mobilenet/tensorflow.py +++ b/doctr/models/classification/mobilenet/tensorflow.py @@ -12,7 +12,7 @@ from keras import Sequential, layers from ....datasets import VOCABS -from ...utils import conv_sequence, load_pretrained_params +from ...utils import _build_model, conv_sequence, load_pretrained_params __all__ = [ "MobileNetV3", @@ -294,6 +294,7 @@ def _mobilenet_v3(arch: str, pretrained: bool, rect_strides: bool = False, **kwa cfg=_cfg, **kwargs, ) + _build_model(model) # Load pretrained parameters if pretrained: diff --git a/doctr/models/classification/resnet/tensorflow.py b/doctr/models/classification/resnet/tensorflow.py index ae0188cb4..c8b618f6b 100644 --- a/doctr/models/classification/resnet/tensorflow.py +++ b/doctr/models/classification/resnet/tensorflow.py @@ -11,7 +11,7 @@ from doctr.datasets import VOCABS -from ...utils import conv_sequence, load_pretrained_params +from ...utils import _build_model, conv_sequence, load_pretrained_params __all__ = ["ResNet", "resnet18", "resnet31", "resnet34", "resnet50", "resnet34_wide"] @@ -208,6 +208,7 @@ def _resnet( model = ResNet( num_blocks, output_channels, stage_downsample, stage_conv, stage_pooling, origin_stem, cfg=_cfg, **kwargs ) + _build_model(model) # Load pretrained parameters if pretrained: @@ -357,6 +358,7 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet: ) model.cfg = _cfg + _build_model(model) # Load pretrained parameters if pretrained: diff --git a/doctr/models/classification/textnet/tensorflow.py b/doctr/models/classification/textnet/tensorflow.py index b7d900647..1907bbecc 100644 --- a/doctr/models/classification/textnet/tensorflow.py +++ b/doctr/models/classification/textnet/tensorflow.py @@ -12,7 +12,7 @@ from doctr.datasets import VOCABS from ...modules.layers.tensorflow import FASTConvLayer -from ...utils import conv_sequence, load_pretrained_params +from ...utils import _build_model, conv_sequence, load_pretrained_params __all__ = ["textnet_tiny", "textnet_small", "textnet_base"] @@ -111,6 +111,7 @@ def _textnet( # Build the model model = TextNet(cfg=_cfg, **kwargs) + _build_model(model) # Load pretrained parameters if pretrained: diff --git a/doctr/models/classification/vgg/tensorflow.py b/doctr/models/classification/vgg/tensorflow.py index d0e2af404..31d612fd6 100644 --- a/doctr/models/classification/vgg/tensorflow.py +++ b/doctr/models/classification/vgg/tensorflow.py @@ -10,7 +10,7 @@ from doctr.datasets import VOCABS -from ...utils import conv_sequence, load_pretrained_params +from 
...utils import _build_model, conv_sequence, load_pretrained_params __all__ = ["VGG", "vgg16_bn_r"] @@ -80,6 +80,7 @@ def _vgg( # Build the model model = VGG(num_blocks, planes, rect_pools, cfg=_cfg, **kwargs) + _build_model(model) # Load pretrained parameters if pretrained: diff --git a/doctr/models/classification/vit/pytorch.py b/doctr/models/classification/vit/pytorch.py index 335e92559..42a4d68ce 100644 --- a/doctr/models/classification/vit/pytorch.py +++ b/doctr/models/classification/vit/pytorch.py @@ -10,8 +10,7 @@ from torch import nn from doctr.datasets import VOCABS -from doctr.models.modules.transformer import EncoderBlock -from doctr.models.modules.vision_transformer.pytorch import PatchEmbedding +from doctr.models.modules import EncoderBlock, PatchEmbedding from ...utils.pytorch import load_pretrained_params diff --git a/doctr/models/classification/vit/tensorflow.py b/doctr/models/classification/vit/tensorflow.py index 1ddf76572..60faee729 100644 --- a/doctr/models/classification/vit/tensorflow.py +++ b/doctr/models/classification/vit/tensorflow.py @@ -14,7 +14,7 @@ from doctr.models.modules.vision_transformer.tensorflow import PatchEmbedding from doctr.utils.repr import NestedObject -from ...utils import load_pretrained_params +from ...utils import _build_model, load_pretrained_params __all__ = ["vit_s", "vit_b"] @@ -121,6 +121,7 @@ def _vit( # Build the model model = VisionTransformer(cfg=_cfg, **kwargs) + _build_model(model) # Load pretrained parameters if pretrained: diff --git a/doctr/models/detection/differentiable_binarization/tensorflow.py b/doctr/models/detection/differentiable_binarization/tensorflow.py index a4c4f342c..d71050f5c 100644 --- a/doctr/models/detection/differentiable_binarization/tensorflow.py +++ b/doctr/models/detection/differentiable_binarization/tensorflow.py @@ -16,6 +16,7 @@ from doctr.models.utils import ( IntermediateLayerGetter, _bf16_to_float32, + _build_model, conv_sequence, load_pretrained_params, ) @@ -308,6 +309,7 @@ def _db_resnet( # Build the model model = DBNet(feat_extractor, cfg=_cfg, **kwargs) + _build_model(model) # Load pretrained parameters if pretrained: @@ -352,6 +354,7 @@ def _db_mobilenet( # Build the model model = DBNet(feat_extractor, cfg=_cfg, **kwargs) + _build_model(model) # Load pretrained parameters if pretrained: diff --git a/doctr/models/detection/fast/tensorflow.py b/doctr/models/detection/fast/tensorflow.py index 417244841..be9365d02 100644 --- a/doctr/models/detection/fast/tensorflow.py +++ b/doctr/models/detection/fast/tensorflow.py @@ -13,7 +13,7 @@ from keras import Model, Sequential, layers from doctr.file_utils import CLASS_NAME -from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, load_pretrained_params +from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, _build_model, load_pretrained_params from doctr.utils.repr import NestedObject from ...classification import textnet_base, textnet_small, textnet_tiny @@ -333,6 +333,7 @@ def _fast( # Build the model model = FAST(feat_extractor, cfg=_cfg, **kwargs) + _build_model(model) # Load pretrained parameters if pretrained: diff --git a/doctr/models/detection/linknet/tensorflow.py b/doctr/models/detection/linknet/tensorflow.py index 144c25be1..9b269cd48 100644 --- a/doctr/models/detection/linknet/tensorflow.py +++ b/doctr/models/detection/linknet/tensorflow.py @@ -17,6 +17,7 @@ from doctr.models.utils import ( IntermediateLayerGetter, _bf16_to_float32, + _build_model, conv_sequence, load_pretrained_params, ) @@ -279,6 +280,7 @@ 
def _linknet( # Build the model model = LinkNet(feat_extractor, cfg=_cfg, **kwargs) + _build_model(model) # Load pretrained parameters if pretrained: diff --git a/doctr/models/recognition/crnn/tensorflow.py b/doctr/models/recognition/crnn/tensorflow.py index 3787bd014..27db4b5c1 100644 --- a/doctr/models/recognition/crnn/tensorflow.py +++ b/doctr/models/recognition/crnn/tensorflow.py @@ -12,7 +12,7 @@ from doctr.datasets import VOCABS from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r -from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params +from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params from ..core import RecognitionModel, RecognitionPostProcessor __all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"] @@ -244,6 +244,7 @@ def _crnn( # Build the model model = CRNN(feat_extractor, cfg=_cfg, **kwargs) + _build_model(model) # Load pretrained parameters if pretrained: diff --git a/doctr/models/recognition/master/tensorflow.py b/doctr/models/recognition/master/tensorflow.py index a7a2adebb..34df81ea2 100644 --- a/doctr/models/recognition/master/tensorflow.py +++ b/doctr/models/recognition/master/tensorflow.py @@ -13,7 +13,7 @@ from doctr.models.classification import magc_resnet31 from doctr.models.modules.transformer import Decoder, PositionalEncoding -from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params +from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params from .base import _MASTER, _MASTERPostProcessor __all__ = ["MASTER", "master"] @@ -290,6 +290,7 @@ def _master(arch: str, pretrained: bool, backbone_fn, pretrained_backbone: bool cfg=_cfg, **kwargs, ) + _build_model(model) # Load pretrained parameters if pretrained: diff --git a/doctr/models/recognition/parseq/tensorflow.py b/doctr/models/recognition/parseq/tensorflow.py index 89ae314a8..566f8a298 100644 --- a/doctr/models/recognition/parseq/tensorflow.py +++ b/doctr/models/recognition/parseq/tensorflow.py @@ -16,7 +16,7 @@ from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward from ...classification import vit_s -from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params +from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params from .base import _PARSeq, _PARSeqPostProcessor __all__ = ["PARSeq", "parseq"] @@ -473,6 +473,7 @@ def _parseq( # Build the model model = PARSeq(feat_extractor, cfg=_cfg, **kwargs) + _build_model(model) # Load pretrained parameters if pretrained: diff --git a/doctr/models/recognition/sar/tensorflow.py b/doctr/models/recognition/sar/tensorflow.py index 34303e080..5691f4d54 100644 --- a/doctr/models/recognition/sar/tensorflow.py +++ b/doctr/models/recognition/sar/tensorflow.py @@ -13,7 +13,7 @@ from doctr.utils.repr import NestedObject from ...classification import resnet31 -from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params +from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params from ..core import RecognitionModel, RecognitionPostProcessor __all__ = ["SAR", "sar_resnet31"] @@ -390,6 +390,7 @@ def _sar( # Build the model model = SAR(feat_extractor, cfg=_cfg, **kwargs) + _build_model(model) # Load pretrained parameters if pretrained: diff --git a/doctr/models/recognition/vitstr/tensorflow.py b/doctr/models/recognition/vitstr/tensorflow.py index 0697dd97e..d32c49a15 100644 --- 
a/doctr/models/recognition/vitstr/tensorflow.py +++ b/doctr/models/recognition/vitstr/tensorflow.py @@ -12,7 +12,7 @@ from doctr.datasets import VOCABS from ...classification import vit_b, vit_s -from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params +from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params from .base import _ViTSTR, _ViTSTRPostProcessor __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"] @@ -216,6 +216,7 @@ def _vitstr( # Build the model model = ViTSTR(feat_extractor, cfg=_cfg, **kwargs) + _build_model(model) # Load pretrained parameters if pretrained: diff --git a/doctr/models/utils/tensorflow.py b/doctr/models/utils/tensorflow.py index fd6c3fbb6..2a0e4e4c5 100644 --- a/doctr/models/utils/tensorflow.py +++ b/doctr/models/utils/tensorflow.py @@ -17,6 +17,7 @@ __all__ = [ "load_pretrained_params", + "_build_model", "conv_sequence", "IntermediateLayerGetter", "export_model_to_onnx", @@ -34,6 +35,15 @@ def _bf16_to_float32(x: tf.Tensor) -> tf.Tensor: return tf.cast(x, tf.float32) if x.dtype == tf.bfloat16 else x +def _build_model(model: Model): + """Build a model by calling it once with dummy input + Args: + ---- + model: the model to be built + """ + model(tf.zeros((1, *model.cfg["input_shape"])), training=False) + + def load_pretrained_params( model: Model, url: Optional[str] = None, diff --git a/tests/tensorflow/test_models_classification_tf.py b/tests/tensorflow/test_models_classification_tf.py index 89181aace..654ba082f 100644 --- a/tests/tensorflow/test_models_classification_tf.py +++ b/tests/tensorflow/test_models_classification_tf.py @@ -2,6 +2,7 @@ import tempfile import cv2 +import keras import numpy as np import onnxruntime import psutil @@ -37,7 +38,7 @@ def test_classification_architectures(arch_name, input_shape, output_size): # Model batch_size = 2 - tf.keras.backend.clear_session() + keras.backend.clear_session() model = classification.__dict__[arch_name](pretrained=True, include_top=True, input_shape=input_shape) # Forward out = model(tf.random.uniform(shape=[batch_size, *input_shape], maxval=1, dtype=tf.float32)) @@ -46,7 +47,7 @@ def test_classification_architectures(arch_name, input_shape, output_size): assert out.dtype == tf.float32 assert out.numpy().shape == (batch_size, *output_size) # Check that you can load pretrained up to the classification layer with differing number of classes to fine-tune - tf.keras.backend.clear_session() + keras.backend.clear_session() assert classification.__dict__[arch_name]( pretrained=True, include_top=True, input_shape=input_shape, num_classes=10 ) @@ -62,7 +63,7 @@ def test_classification_architectures(arch_name, input_shape, output_size): def test_classification_models(arch_name, input_shape): batch_size = 8 reco_model = classification.__dict__[arch_name](pretrained=True, input_shape=input_shape) - assert isinstance(reco_model, tf.keras.Model) + assert isinstance(reco_model, keras.Model) input_tensor = tf.random.uniform(shape=[batch_size, *input_shape], minval=0, maxval=1) out = reco_model(input_tensor) @@ -231,7 +232,7 @@ def test_page_orientation_model(mock_payslip): def test_models_onnx_export(arch_name, input_shape, output_size): # Model batch_size = 2 - tf.keras.backend.clear_session() + keras.backend.clear_session() if "orientation" in arch_name: model = classification.__dict__[arch_name](pretrained=True, input_shape=input_shape) else:
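
Note on the pattern introduced above: every TensorFlow builder in this series now calls `_build_model` immediately after instantiating the architecture, so the Keras variables are created by one dummy forward pass before `load_pretrained_params` runs. Below is a minimal, self-contained sketch of that flow; `TinyNet`, its config values, and the `tiny_net` builder are illustrative placeholders and not code from this patch, while the local `_build_model` mirrors the helper added in `doctr/models/utils/tensorflow.py`.

    import tensorflow as tf
    from keras import Model, layers


    class TinyNet(Model):
        """Hypothetical stand-in for a docTR architecture: a Keras model that
        carries the `cfg` dict from which `_build_model` reads the input shape."""

        def __init__(self, cfg: dict):
            super().__init__()
            self.cfg = cfg
            self.conv = layers.Conv2D(8, 3, padding="same", activation="relu")
            self.pool = layers.GlobalAveragePooling2D()
            self.head = layers.Dense(len(cfg["classes"]))

        def call(self, x, **kwargs):
            return self.head(self.pool(self.conv(x)))


    def _build_model(model: Model):
        """Create the model's variables with one dummy forward pass (mirrors the patch)."""
        model(tf.zeros((1, *model.cfg["input_shape"])), training=False)


    def tiny_net(pretrained: bool = False) -> TinyNet:
        # Illustrative builder skeleton following the same ordering as _resnet, _crnn, etc.
        cfg = {"input_shape": (32, 32, 3), "classes": list("abc"), "url": None}
        model = TinyNet(cfg)
        _build_model(model)  # variables now exist, so weights can be loaded or saved
        # In the real builders, load_pretrained_params(model, ...) would run here
        # when pretrained=True.
        return model


    print(tiny_net().count_params())

The one-off dummy call avoids relying on the first real forward pass to create the variables, which is why it appears in each builder before the `if pretrained:` branch rather than after it.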