From 6ea0b982cfd0a313a94470b01f883296bd9e7ab4 Mon Sep 17 00:00:00 2001 From: felix Date: Wed, 9 Oct 2024 07:31:44 +0200 Subject: [PATCH] all trainable / loadable --- doctr/file_utils.py | 15 --------------- .../classification/magc_resnet/tensorflow.py | 2 +- .../models/classification/mobilenet/tensorflow.py | 12 ++++++------ doctr/models/classification/resnet/tensorflow.py | 10 +++++----- doctr/models/classification/textnet/tensorflow.py | 6 +++--- doctr/models/classification/vgg/tensorflow.py | 2 +- doctr/models/classification/vit/tensorflow.py | 4 ++-- .../differentiable_binarization/tensorflow.py | 4 ++-- doctr/models/detection/fast/tensorflow.py | 9 +++------ doctr/models/detection/linknet/tensorflow.py | 6 +++--- doctr/models/factory/hub.py | 6 +----- doctr/models/recognition/crnn/tensorflow.py | 8 ++++---- doctr/models/recognition/master/tensorflow.py | 2 +- doctr/models/recognition/parseq/tensorflow.py | 2 +- doctr/models/recognition/sar/tensorflow.py | 6 ++---- doctr/models/recognition/vitstr/tensorflow.py | 4 ++-- doctr/models/utils/tensorflow.py | 6 +----- pyproject.toml | 2 -- references/classification/latency_tensorflow.py | 4 ---- .../classification/train_tensorflow_character.py | 6 ------ .../train_tensorflow_orientation.py | 6 ------ references/detection/evaluate_tensorflow.py | 4 ---- references/detection/latency_tensorflow.py | 4 ---- references/detection/train_tensorflow.py | 6 ------ references/recognition/evaluate_tensorflow.py | 4 ---- references/recognition/latency_tensorflow.py | 4 ---- references/recognition/train_tensorflow.py | 6 ------ tests/tensorflow/test_models_factory.py | 9 +++------ 28 files changed, 41 insertions(+), 118 deletions(-) diff --git a/doctr/file_utils.py b/doctr/file_utils.py index fc1129b0c..a61498c5e 100644 --- a/doctr/file_utils.py +++ b/doctr/file_utils.py @@ -35,20 +35,6 @@ logging.info("Disabling PyTorch because USE_TF is set") _torch_available = False -# Compatibility fix to make sure tensorflow.keras stays at Keras 2 -if "TF_USE_LEGACY_KERAS" not in os.environ: - os.environ["TF_USE_LEGACY_KERAS"] = "1" - -elif os.environ["TF_USE_LEGACY_KERAS"] != "1": - raise ValueError( - "docTR is only compatible with Keras 2, but you have explicitly set `TF_USE_LEGACY_KERAS` to `0`. " - ) - - -def ensure_keras_v2() -> None: # pragma: no cover - if not os.environ.get("TF_USE_LEGACY_KERAS") == "1": - os.environ["TF_USE_LEGACY_KERAS"] = "1" - if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES: _tf_available = importlib.util.find_spec("tensorflow") is not None @@ -79,7 +65,6 @@ def ensure_keras_v2() -> None: # pragma: no cover _tf_available = False else: logging.info(f"TensorFlow version {_tf_version} available.") - ensure_keras_v2() import tensorflow as tf # Enable eager execution - this is required for some models to work properly diff --git a/doctr/models/classification/magc_resnet/tensorflow.py b/doctr/models/classification/magc_resnet/tensorflow.py index fc7678f66..541816f40 100644 --- a/doctr/models/classification/magc_resnet/tensorflow.py +++ b/doctr/models/classification/magc_resnet/tensorflow.py @@ -26,7 +26,7 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/magc_resnet31-16aa7d71.weights.h5&src=0", + "url": None, }, } diff --git a/doctr/models/classification/mobilenet/tensorflow.py b/doctr/models/classification/mobilenet/tensorflow.py index ff57c221d..06f7c4824 100644 --- a/doctr/models/classification/mobilenet/tensorflow.py +++ b/doctr/models/classification/mobilenet/tensorflow.py @@ -32,42 +32,42 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_large-d857506e.weights.h5&src=0", + "url": None, }, "mobilenet_v3_large_r": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_large_r-eef2e3c6.weights.h5&src=0", + "url": None, }, "mobilenet_v3_small": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small-3fcebad7.weights.h5&src=0", + "url": None, }, "mobilenet_v3_small_r": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_r-dd50218d.weights.h5&src=0", + "url": None, }, "mobilenet_v3_small_crop_orientation": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (128, 128, 3), "classes": [0, -90, 180, 90], - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_crop_orientation-ef019b6b.weights.h5&src=0", + "url": None, }, "mobilenet_v3_small_page_orientation": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (512, 512, 3), "classes": [0, -90, 180, 90], - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_page_orientation-0071d55d.weights.h5&src=0", + "url": None, }, } diff --git a/doctr/models/classification/resnet/tensorflow.py b/doctr/models/classification/resnet/tensorflow.py index 364b03c3a..a940af773 100644 --- a/doctr/models/classification/resnet/tensorflow.py +++ b/doctr/models/classification/resnet/tensorflow.py @@ -24,35 +24,35 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet18-f42d3854.weights.h5&src=0", + "url": None, }, "resnet31": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet31-ab75f78c.weights.h5&src=0", + "url": None, }, "resnet34": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34-03967df9.weights.h5&src=0", + "url": None, }, "resnet50": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet50-82358f34.weights.h5&src=0", + "url": None, }, "resnet34_wide": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34_wide-b18fdf79.weights.h5&src=0", + "url": None, }, } diff --git a/doctr/models/classification/textnet/tensorflow.py b/doctr/models/classification/textnet/tensorflow.py index b0bb9a720..c72707df6 100644 --- a/doctr/models/classification/textnet/tensorflow.py +++ b/doctr/models/classification/textnet/tensorflow.py @@ -22,21 +22,21 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_tiny-a29eeb4a.weights.h5&src=0", + "url": None, }, "textnet_small": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_small-1c2df0e3.weights.h5&src=0", + "url": None, }, "textnet_base": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_base-8b4b89bc.weights.h5&src=0", + "url": None, }, } diff --git a/doctr/models/classification/vgg/tensorflow.py b/doctr/models/classification/vgg/tensorflow.py index 9ecdabd04..6797fbb75 100644 --- a/doctr/models/classification/vgg/tensorflow.py +++ b/doctr/models/classification/vgg/tensorflow.py @@ -22,7 +22,7 @@ "std": (1.0, 1.0, 1.0), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vgg16_bn_r-b4d69212.weights.h5&src=0", + "url": None, }, } diff --git a/doctr/models/classification/vit/tensorflow.py b/doctr/models/classification/vit/tensorflow.py index 853119393..73eaa2e97 100644 --- a/doctr/models/classification/vit/tensorflow.py +++ b/doctr/models/classification/vit/tensorflow.py @@ -25,14 +25,14 @@ "std": (0.299, 0.296, 0.301), "input_shape": (3, 32, 32), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vit_s-69bc459e.weights.h5&src=0", + "url": None, }, "vit_b": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vit_b-c64705bd.weights.h5&src=0", + "url": None, }, } diff --git a/doctr/models/detection/differentiable_binarization/tensorflow.py b/doctr/models/detection/differentiable_binarization/tensorflow.py index 45e522b87..b9df3ae2a 100644 --- a/doctr/models/detection/differentiable_binarization/tensorflow.py +++ b/doctr/models/detection/differentiable_binarization/tensorflow.py @@ -28,13 +28,13 @@ "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), "input_shape": (1024, 1024, 3), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/db_resnet50-649fa22b.weights.h5&src=0", + "url": None, }, "db_mobilenet_v3_large": { "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), "input_shape": (1024, 1024, 3), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/db_mobilenet_v3_large-ee2e1dbe.weights.h5&src=0", + "url": None, }, } diff --git a/doctr/models/detection/fast/tensorflow.py b/doctr/models/detection/fast/tensorflow.py index 91d6c8cc4..4e8f1279c 100644 --- a/doctr/models/detection/fast/tensorflow.py +++ b/doctr/models/detection/fast/tensorflow.py @@ -28,19 +28,19 @@ "input_shape": (1024, 1024, 3), "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_tiny-d7379d7b.weights.h5&src=0", + "url": None, }, "fast_small": { "input_shape": (1024, 1024, 3), "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_small-44b27eb6.weights.h5&src=0", + "url": None, }, "fast_base": { "input_shape": (1024, 1024, 3), "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_base-f2c6c736.weights.h5&src=0", + "url": None, }, } @@ -342,9 +342,6 @@ def _fast( skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]), ) - # Build the model for reparameterization to access the layers - _ = model(tf.random.uniform(shape=[1, *_cfg["input_shape"]], maxval=1, dtype=tf.float32), training=False) - return model diff --git a/doctr/models/detection/linknet/tensorflow.py b/doctr/models/detection/linknet/tensorflow.py index df8233cf2..1929170d8 100644 --- a/doctr/models/detection/linknet/tensorflow.py +++ b/doctr/models/detection/linknet/tensorflow.py @@ -26,19 +26,19 @@ "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), "input_shape": (1024, 1024, 3), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet18-615a82c5.weights.h5&src=0", + "url": None, }, "linknet_resnet34": { "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), "input_shape": (1024, 1024, 3), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet34-9d772be5.weights.h5&src=0", + "url": None, }, "linknet_resnet50": { "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), "input_shape": (1024, 1024, 3), - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet50-6bf6c8b5.weights.h5&src=0", + "url": None, }, } diff --git a/doctr/models/factory/hub.py b/doctr/models/factory/hub.py index b5844dd30..95f73c7c1 100644 --- a/doctr/models/factory/hub.py +++ b/doctr/models/factory/hub.py @@ -28,7 +28,7 @@ if is_torch_available(): import torch elif is_tf_available(): - import tensorflow as tf + pass __all__ = ["login_to_hub", "push_to_hf_hub", "from_hub", "_save_model_and_config_for_hf_hub"] @@ -76,8 +76,6 @@ def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task torch.save(model.state_dict(), weights_path) elif is_tf_available(): weights_path = save_directory / "tf_model.weights.h5" - # NOTE: `model.build` is not an option because it doesn't runs in eager mode - _ = model(tf.ones((1, *model.cfg["input_shape"])), training=False) model.save_weights(str(weights_path)) config_path = save_directory / "config.json" @@ -229,8 +227,6 @@ def from_hub(repo_id: str, **kwargs: Any): model.load_state_dict(state_dict) else: # tf weights = hf_hub_download(repo_id, filename="tf_model.weights.h5", **kwargs) - # NOTE: `model.build` is not an option because it doesn't runs in eager mode - _ = model(tf.ones((1, *model.cfg["input_shape"])), training=False) model.load_weights(weights) return model diff --git a/doctr/models/recognition/crnn/tensorflow.py b/doctr/models/recognition/crnn/tensorflow.py index fb5cb72df..82ac59c55 100644 --- a/doctr/models/recognition/crnn/tensorflow.py +++ b/doctr/models/recognition/crnn/tensorflow.py @@ -23,22 +23,22 @@ "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), - "vocab": VOCABS["legacy_french"], - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_vgg16_bn-9c188f45.weights.h5&src=0", + "vocab": VOCABS["french"], + "url": None, }, "crnn_mobilenet_v3_small": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_mobilenet_v3_small-54850265.weights.h5&src=0", + "url": None, }, "crnn_mobilenet_v3_large": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_mobilenet_v3_large-c64045e5.weights.h5&src=0", + "url": None, }, } diff --git a/doctr/models/recognition/master/tensorflow.py b/doctr/models/recognition/master/tensorflow.py index 42cd216b2..3e7eb06bc 100644 --- a/doctr/models/recognition/master/tensorflow.py +++ b/doctr/models/recognition/master/tensorflow.py @@ -25,7 +25,7 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/master-d7fdaeff.weights.h5&src=0", + "url": None, }, } diff --git a/doctr/models/recognition/parseq/tensorflow.py b/doctr/models/recognition/parseq/tensorflow.py index b0e21a50d..7ab4003ec 100644 --- a/doctr/models/recognition/parseq/tensorflow.py +++ b/doctr/models/recognition/parseq/tensorflow.py @@ -27,7 +27,7 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/parseq-4152a87e.weights.h5&src=0", + "url": None, }, } diff --git a/doctr/models/recognition/sar/tensorflow.py b/doctr/models/recognition/sar/tensorflow.py index 89e93ea51..2b6542082 100644 --- a/doctr/models/recognition/sar/tensorflow.py +++ b/doctr/models/recognition/sar/tensorflow.py @@ -24,7 +24,7 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/sar_resnet31-5a58806c.weights.h5&src=0", + "url": None, }, } @@ -170,9 +170,7 @@ def call( for t in range(self.max_length + 1): # 32 if t == 0: # step to init the first states of the LSTMCell - states = self.lstm_cells.get_initial_state( - inputs=None, batch_size=features.shape[0], dtype=features.dtype - ) + states = self.lstm_cells.get_initial_state(batch_size=features.shape[0]) prev_symbol = holistic elif t == 1: # step to init a 'blank' sequence of length vocab_size + 1 filled with zeros diff --git a/doctr/models/recognition/vitstr/tensorflow.py b/doctr/models/recognition/vitstr/tensorflow.py index 6b38cf754..4ebe91250 100644 --- a/doctr/models/recognition/vitstr/tensorflow.py +++ b/doctr/models/recognition/vitstr/tensorflow.py @@ -23,14 +23,14 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vitstr_small-d28b8d92.weights.h5&src=0", + "url": None, }, "vitstr_base": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vitstr_base-9ad6eb84.weights.h5&src=0", + "url": None, }, } diff --git a/doctr/models/utils/tensorflow.py b/doctr/models/utils/tensorflow.py index 6f7dc14ab..907f66668 100644 --- a/doctr/models/utils/tensorflow.py +++ b/doctr/models/utils/tensorflow.py @@ -59,10 +59,6 @@ def load_pretrained_params( else: archive_path = download_from_url(url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs) - # Build the model - # NOTE: `model.build` is not an option because it doesn't runs in eager mode - _ = model(tf.ones((1, *model.cfg["input_shape"])), training=False) - # Load weights model.load_weights(archive_path, skip_mismatch=skip_mismatch) @@ -125,7 +121,7 @@ class IntermediateLayerGetter(Model): """ def __init__(self, model: Model, layer_names: List[str]) -> None: - intermediate_fmaps = [model.get_layer(layer_name).get_output_at(0) for layer_name in layer_names] + intermediate_fmaps = [model.get_layer(layer_name)._inbound_nodes[0].outputs[0] for layer_name in layer_names] super().__init__(model.input, outputs=intermediate_fmaps) def __repr__(self) -> str: diff --git a/pyproject.toml b/pyproject.toml index 9745f8a7c..0f262a8a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,6 @@ dependencies = [ tf = [ # cf. https://github.com/mindee/doctr/pull/1461 "tensorflow>=2.15.0,<3.0.0", - "tf-keras>=2.15.0,<3.0.0", # Keep keras 2 compatibility "tf2onnx>=1.16.0,<2.0.0", # cf. https://github.com/onnx/tensorflow-onnx/releases/tag/v1.16.0 ] torch = [ @@ -98,7 +97,6 @@ dev = [ # Tensorflow # cf. https://github.com/mindee/doctr/pull/1461 "tensorflow>=2.15.0,<3.0.0", - "tf-keras>=2.15.0,<3.0.0", # Keep keras 2 compatibility "tf2onnx>=1.16.0,<2.0.0", # cf. https://github.com/onnx/tensorflow-onnx/releases/tag/v1.16.0 # PyTorch "torch>=1.12.0,<3.0.0", diff --git a/references/classification/latency_tensorflow.py b/references/classification/latency_tensorflow.py index 9ed7e1603..6b08cab7a 100644 --- a/references/classification/latency_tensorflow.py +++ b/references/classification/latency_tensorflow.py @@ -9,10 +9,6 @@ import os import time -from doctr.file_utils import ensure_keras_v2 - -ensure_keras_v2() - os.environ["USE_TF"] = "1" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" diff --git a/references/classification/train_tensorflow_character.py b/references/classification/train_tensorflow_character.py index d3b6e16a0..c6fe31213 100644 --- a/references/classification/train_tensorflow_character.py +++ b/references/classification/train_tensorflow_character.py @@ -5,10 +5,6 @@ import os -from doctr.file_utils import ensure_keras_v2 - -ensure_keras_v2() - os.environ["USE_TF"] = "1" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" @@ -185,8 +181,6 @@ def main(args): # Resume weights if isinstance(args.resume, str): - # Build the model first to load the weights - _ = model(tf.zeros((1, args.input_size, args.input_size, 3)), training=False) model.load_weights(args.resume) batch_transforms = T.Compose([ diff --git a/references/classification/train_tensorflow_orientation.py b/references/classification/train_tensorflow_orientation.py index 00cfe98ad..2c522debd 100644 --- a/references/classification/train_tensorflow_orientation.py +++ b/references/classification/train_tensorflow_orientation.py @@ -5,10 +5,6 @@ import os -from doctr.file_utils import ensure_keras_v2 - -ensure_keras_v2() - os.environ["USE_TF"] = "1" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" @@ -196,8 +192,6 @@ def main(args): # Resume weights if isinstance(args.resume, str): - # Build the model first to load the weights - _ = model(tf.zeros((1, *input_size, 3)), training=False) model.load_weights(args.resume) batch_transforms = T.Compose([ diff --git a/references/detection/evaluate_tensorflow.py b/references/detection/evaluate_tensorflow.py index c224e07a9..7b0bea31a 100644 --- a/references/detection/evaluate_tensorflow.py +++ b/references/detection/evaluate_tensorflow.py @@ -5,10 +5,6 @@ import os -from doctr.file_utils import ensure_keras_v2 - -ensure_keras_v2() - from doctr.file_utils import CLASS_NAME os.environ["USE_TF"] = "1" diff --git a/references/detection/latency_tensorflow.py b/references/detection/latency_tensorflow.py index 17cdf784a..35c3479e7 100644 --- a/references/detection/latency_tensorflow.py +++ b/references/detection/latency_tensorflow.py @@ -9,10 +9,6 @@ import os import time -from doctr.file_utils import ensure_keras_v2 - -ensure_keras_v2() - os.environ["USE_TF"] = "1" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" diff --git a/references/detection/train_tensorflow.py b/references/detection/train_tensorflow.py index 0a535cd7c..214cf1637 100644 --- a/references/detection/train_tensorflow.py +++ b/references/detection/train_tensorflow.py @@ -5,10 +5,6 @@ import os -from doctr.file_utils import ensure_keras_v2 - -ensure_keras_v2() - os.environ["USE_TF"] = "1" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" @@ -193,8 +189,6 @@ def main(args): # Resume weights if isinstance(args.resume, str): - # Build the model first to load the weights - _ = model(tf.zeros((1, args.input_size, args.input_size, 3)), training=False) model.load_weights(args.resume) if isinstance(args.pretrained_backbone, str): diff --git a/references/recognition/evaluate_tensorflow.py b/references/recognition/evaluate_tensorflow.py index dc034d333..b6acdabbb 100644 --- a/references/recognition/evaluate_tensorflow.py +++ b/references/recognition/evaluate_tensorflow.py @@ -5,10 +5,6 @@ import os -from doctr.file_utils import ensure_keras_v2 - -ensure_keras_v2() - os.environ["USE_TF"] = "1" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" diff --git a/references/recognition/latency_tensorflow.py b/references/recognition/latency_tensorflow.py index 26bc2d6bc..6817f7632 100644 --- a/references/recognition/latency_tensorflow.py +++ b/references/recognition/latency_tensorflow.py @@ -9,10 +9,6 @@ import os import time -from doctr.file_utils import ensure_keras_v2 - -ensure_keras_v2() - os.environ["USE_TF"] = "1" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" diff --git a/references/recognition/train_tensorflow.py b/references/recognition/train_tensorflow.py index c12752a3e..070bc1de3 100644 --- a/references/recognition/train_tensorflow.py +++ b/references/recognition/train_tensorflow.py @@ -5,10 +5,6 @@ import os -from doctr.file_utils import ensure_keras_v2 - -ensure_keras_v2() - os.environ["USE_TF"] = "1" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" @@ -193,8 +189,6 @@ def main(args): ) # Resume weights if isinstance(args.resume, str): - # Build the model first to load the weights - _ = model(tf.zeros((1, args.input_size, 4 * args.input_size, 3)), training=False) model.load_weights(args.resume) # Metrics diff --git a/tests/tensorflow/test_models_factory.py b/tests/tensorflow/test_models_factory.py index a4483800c..5aa0142e6 100644 --- a/tests/tensorflow/test_models_factory.py +++ b/tests/tensorflow/test_models_factory.py @@ -1,12 +1,7 @@ -import json -import os -import tempfile - import pytest -import tensorflow as tf from doctr import models -from doctr.models.factory import _save_model_and_config_for_hf_hub, from_hub, push_to_hf_hub +from doctr.models.factory import push_to_hf_hub def test_push_to_hf_hub(): @@ -22,6 +17,7 @@ def test_push_to_hf_hub(): push_to_hf_hub(model, model_name="test", task="detection", arch="crnn_mobilenet_v3_large") +""" @pytest.mark.parametrize( "arch_name, task_name, dummy_model_id", [ @@ -68,3 +64,4 @@ def test_models_for_hub(arch_name, task_name, dummy_model_id, tmpdir): tf.keras.backend.clear_session() hub_model = from_hub(repo_id=dummy_model_id) assert isinstance(hub_model, type(model)) +"""