diff --git a/docs/source/using_doctr/custom_models_training.rst b/docs/source/using_doctr/custom_models_training.rst index ecf88d8116..13e4640a36 100644 --- a/docs/source/using_doctr/custom_models_training.rst +++ b/docs/source/using_doctr/custom_models_training.rst @@ -22,19 +22,19 @@ This section shows how you can easily load a custom trained model in docTR. # Load custom detection model det_model = db_resnet50(pretrained=False, pretrained_backbone=False) - det_model.load_weights("/weights") + det_model.load_weights("") predictor = ocr_predictor(det_arch=det_model, reco_arch="vitstr_small", pretrained=True) # Load custom recognition model reco_model = crnn_vgg16_bn(pretrained=False, pretrained_backbone=False) - reco_model.load_weights("/weights") + reco_model.load_weights("") predictor = ocr_predictor(det_arch="linknet_resnet18", reco_arch=reco_model, pretrained=True) # Load custom detection and recognition model det_model = db_resnet50(pretrained=False, pretrained_backbone=False) - det_model.load_weights("/weights") + det_model.load_weights("") reco_model = crnn_vgg16_bn(pretrained=False, pretrained_backbone=False) - reco_model.load_weights("/weights") + reco_model.load_weights("") predictor = ocr_predictor(det_arch=det_model, reco_arch=reco_model, pretrained=False) .. tab:: PyTorch @@ -77,7 +77,7 @@ Load a custom recognition model trained on another vocabulary as the default one from doctr.datasets import VOCABS reco_model = crnn_vgg16_bn(pretrained=False, pretrained_backbone=False, vocab=VOCABS["german"]) - reco_model.load_weights("/weights") + reco_model.load_weights("") predictor = ocr_predictor(det_arch='linknet_resnet18', reco_arch=reco_model, pretrained=True) @@ -106,7 +106,7 @@ Load a custom trained KIE detection model: from doctr.models import kie_predictor, db_resnet50 det_model = db_resnet50(pretrained=False, pretrained_backbone=False, class_names=['total', 'date']) - det_model.load_weights("/weights") + det_model.load_weights("") kie_predictor(det_arch=det_model, reco_arch='crnn_vgg16_bn', pretrained=True) .. tab:: PyTorch @@ -136,9 +136,9 @@ Load a model with customized Preprocessor: from doctr.models import db_resnet50, crnn_vgg16_bn det_model = db_resnet50(pretrained=False, pretrained_backbone=False) - det_model.load_weights("/weights") + det_model.load_weights("") reco_model = crnn_vgg16_bn(pretrained=False, pretrained_backbone=False) - reco_model.load_weights("/weights") + reco_model.load_weights("") det_predictor = DetectionPredictor( PreProcessor( @@ -233,9 +233,9 @@ Loading your custom trained orientation classification model from doctr.models.classification.zoo import crop_orientation_predictor, page_orientation_predictor custom_page_orientation_model = mobilenet_v3_small_page_orientation(pretrained=False) - custom_page_orientation_model.load_weights("/weights") + custom_page_orientation_model.load_weights("") custom_crop_orientation_model = mobilenet_v3_small_crop_orientation(pretrained=False) - custom_crop_orientation_model.load_weights("/weights") + custom_crop_orientation_model.load_weights("") predictor = ocr_predictor( pretrained=True, diff --git a/docs/source/using_doctr/using_model_export.rst b/docs/source/using_doctr/using_model_export.rst index c62c36169b..48f570f699 100644 --- a/docs/source/using_doctr/using_model_export.rst +++ b/docs/source/using_doctr/using_model_export.rst @@ -31,7 +31,7 @@ Advantages: .. code:: python3 import tensorflow as tf - from tensorflow.keras import mixed_precision + from keras import mixed_precision mixed_precision.set_global_policy('mixed_float16') predictor = ocr_predictor(reco_arch="crnn_mobilenet_v3_small", det_arch="linknet_resnet34", pretrained=True) diff --git a/doctr/io/image/tensorflow.py b/doctr/io/image/tensorflow.py index 28fb2fadd5..3b1f1ed0e2 100644 --- a/doctr/io/image/tensorflow.py +++ b/doctr/io/image/tensorflow.py @@ -7,8 +7,8 @@ import numpy as np import tensorflow as tf +from keras.utils import img_to_array from PIL import Image -from tensorflow.keras.utils import img_to_array from doctr.utils.common_types import AbstractPath diff --git a/doctr/models/classification/magc_resnet/tensorflow.py b/doctr/models/classification/magc_resnet/tensorflow.py index e791e661bf..12f7c6beea 100644 --- a/doctr/models/classification/magc_resnet/tensorflow.py +++ b/doctr/models/classification/magc_resnet/tensorflow.py @@ -9,8 +9,8 @@ from typing import Any, Dict, List, Optional, Tuple import tensorflow as tf -from tensorflow.keras import layers -from tensorflow.keras.models import Sequential +from keras import activations, layers +from keras.models import Sequential from doctr.datasets import VOCABS @@ -26,7 +26,7 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.6.0/magc_resnet31-addbb705.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/magc_resnet31-16aa7d71.weights.h5&src=0", }, } @@ -57,6 +57,7 @@ def __init__( self.headers = headers # h self.inplanes = inplanes # C self.attn_scale = attn_scale + self.ratio = ratio self.planes = int(inplanes * ratio) self.single_header_inplanes = int(inplanes / headers) # C / h @@ -97,7 +98,7 @@ def context_modeling(self, inputs: tf.Tensor) -> tf.Tensor: if self.attn_scale and self.headers > 1: context_mask = context_mask / math.sqrt(self.single_header_inplanes) # B*h, 1, H*W, 1 - context_mask = tf.keras.activations.softmax(context_mask, axis=2) + context_mask = activations.softmax(context_mask, axis=2) # Compute context # B*h, 1, C/h, 1 @@ -153,7 +154,11 @@ def _magc_resnet( ) # Load pretrained parameters if pretrained: - load_pretrained_params(model, default_cfgs[arch]["url"]) + # The number of classes is not the same as the number of classes in the pretrained model => + # skip the mismatching layers for fine tuning + load_pretrained_params( + model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) + ) return model diff --git a/doctr/models/classification/mobilenet/tensorflow.py b/doctr/models/classification/mobilenet/tensorflow.py index 2156cf1f50..6250abc666 100644 --- a/doctr/models/classification/mobilenet/tensorflow.py +++ b/doctr/models/classification/mobilenet/tensorflow.py @@ -9,8 +9,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union import tensorflow as tf -from tensorflow.keras import layers -from tensorflow.keras.models import Sequential +from keras import layers +from keras.models import Sequential from ....datasets import VOCABS from ...utils import conv_sequence, load_pretrained_params @@ -32,42 +32,42 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_large-47d25d7e.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_large-d857506e.weights.h5&src=0", }, "mobilenet_v3_large_r": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_large_r-a108e192.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_large_r-eef2e3c6.weights.h5&src=0", }, "mobilenet_v3_small": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small-8a32c32c.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small-3fcebad7.weights.h5&src=0", }, "mobilenet_v3_small_r": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small_r-3d61452e.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_r-dd50218d.weights.h5&src=0", }, "mobilenet_v3_small_crop_orientation": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (128, 128, 3), "classes": [0, -90, 180, 90], - "url": "https://doctr-static.mindee.com/models?id=v0.4.1/classif_mobilenet_v3_small-1ea8db03.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_crop_orientation-ef019b6b.weights.h5&src=0", }, "mobilenet_v3_small_page_orientation": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (512, 512, 3), "classes": [0, -90, 180, 90], - "url": "https://doctr-static.mindee.com/models?id=v0.8.1/mobilenet_v3_small_page_orientation-aec9553e.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_page_orientation-0071d55d.weights.h5&src=0", }, } @@ -297,7 +297,11 @@ def _mobilenet_v3(arch: str, pretrained: bool, rect_strides: bool = False, **kwa ) # Load pretrained parameters if pretrained: - load_pretrained_params(model, default_cfgs[arch]["url"]) + # The number of classes is not the same as the number of classes in the pretrained model => + # skip the mismatching layers for fine tuning + load_pretrained_params( + model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) + ) return model diff --git a/doctr/models/classification/predictor/tensorflow.py b/doctr/models/classification/predictor/tensorflow.py index e3756e6e83..ba26e1db54 100644 --- a/doctr/models/classification/predictor/tensorflow.py +++ b/doctr/models/classification/predictor/tensorflow.py @@ -7,7 +7,7 @@ import numpy as np import tensorflow as tf -from tensorflow import keras +from keras import Model from doctr.models.preprocessor import PreProcessor from doctr.utils.repr import NestedObject @@ -30,10 +30,10 @@ class OrientationPredictor(NestedObject): def __init__( self, pre_processor: Optional[PreProcessor], - model: Optional[keras.Model], + model: Optional[Model], ) -> None: self.pre_processor = pre_processor if isinstance(pre_processor, PreProcessor) else None - self.model = model if isinstance(model, keras.Model) else None + self.model = model if isinstance(model, Model) else None def __call__( self, diff --git a/doctr/models/classification/resnet/tensorflow.py b/doctr/models/classification/resnet/tensorflow.py index 7648e5f8d0..3e78ae0ae2 100644 --- a/doctr/models/classification/resnet/tensorflow.py +++ b/doctr/models/classification/resnet/tensorflow.py @@ -7,9 +7,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple import tensorflow as tf -from tensorflow.keras import layers -from tensorflow.keras.applications import ResNet50 -from tensorflow.keras.models import Sequential +from keras import layers +from keras.applications import ResNet50 +from keras.models import Sequential from doctr.datasets import VOCABS @@ -24,35 +24,35 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.4.1/resnet18-d4634669.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet18-f42d3854.weights.h5&src=0", }, "resnet31": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet31-5a47a60b.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet31-ab75f78c.weights.h5&src=0", }, "resnet34": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet34-5dcc97ca.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34-03967df9.weights.h5&src=0", }, "resnet50": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet50-e75e4cdf.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet50-82358f34.weights.h5&src=0", }, "resnet34_wide": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet34_wide-c1271816.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34_wide-b18fdf79.weights.h5&src=0", }, } @@ -212,7 +212,11 @@ def _resnet( ) # Load pretrained parameters if pretrained: - load_pretrained_params(model, default_cfgs[arch]["url"]) + # The number of classes is not the same as the number of classes in the pretrained model => + # skip the mismatching layers for fine tuning + load_pretrained_params( + model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) + ) return model @@ -357,7 +361,13 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet: # Load pretrained parameters if pretrained: - load_pretrained_params(model, default_cfgs["resnet50"]["url"]) + # The number of classes is not the same as the number of classes in the pretrained model => + # skip the mismatching layers for fine tuning + load_pretrained_params( + model, + default_cfgs["resnet50"]["url"], + skip_mismatch=kwargs["num_classes"] != len(default_cfgs["resnet50"]["classes"]), + ) return model diff --git a/doctr/models/classification/textnet/tensorflow.py b/doctr/models/classification/textnet/tensorflow.py index f30d5d823c..3d79b15f09 100644 --- a/doctr/models/classification/textnet/tensorflow.py +++ b/doctr/models/classification/textnet/tensorflow.py @@ -7,7 +7,7 @@ from copy import deepcopy from typing import Any, Dict, List, Optional, Tuple -from tensorflow.keras import Sequential, layers +from keras import Sequential, layers from doctr.datasets import VOCABS @@ -22,21 +22,21 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_tiny-fe9cc245.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_tiny-a29eeb4a.weights.h5&src=0", }, "textnet_small": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_small-29c39c82.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_small-1c2df0e3.weights.h5&src=0", }, "textnet_base": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_base-168aa82c.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_base-8b4b89bc.weights.h5&src=0", }, } @@ -113,7 +113,11 @@ def _textnet( model = TextNet(cfg=_cfg, **kwargs) # Load pretrained parameters if pretrained: - load_pretrained_params(model, default_cfgs[arch]["url"]) + # The number of classes is not the same as the number of classes in the pretrained model => + # skip the mismatching layers for fine tuning + load_pretrained_params( + model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) + ) return model diff --git a/doctr/models/classification/vgg/tensorflow.py b/doctr/models/classification/vgg/tensorflow.py index 259ed9f888..d9e7bb374b 100644 --- a/doctr/models/classification/vgg/tensorflow.py +++ b/doctr/models/classification/vgg/tensorflow.py @@ -6,8 +6,8 @@ from copy import deepcopy from typing import Any, Dict, List, Optional, Tuple -from tensorflow.keras import layers -from tensorflow.keras.models import Sequential +from keras import layers +from keras.models import Sequential from doctr.datasets import VOCABS @@ -22,7 +22,7 @@ "std": (1.0, 1.0, 1.0), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.4.1/vgg16_bn_r-c5836cea.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vgg16_bn_r-b4d69212.weights.h5&src=0", }, } @@ -83,7 +83,11 @@ def _vgg( model = VGG(num_blocks, planes, rect_pools, cfg=_cfg, **kwargs) # Load pretrained parameters if pretrained: - load_pretrained_params(model, default_cfgs[arch]["url"]) + # The number of classes is not the same as the number of classes in the pretrained model => + # skip the mismatching layers for fine tuning + load_pretrained_params( + model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) + ) return model diff --git a/doctr/models/classification/vit/tensorflow.py b/doctr/models/classification/vit/tensorflow.py index 4b73b49ac9..28ff2e244e 100644 --- a/doctr/models/classification/vit/tensorflow.py +++ b/doctr/models/classification/vit/tensorflow.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Tuple import tensorflow as tf -from tensorflow.keras import Sequential, layers +from keras import Sequential, layers from doctr.datasets import VOCABS from doctr.models.modules.transformer import EncoderBlock @@ -25,14 +25,14 @@ "std": (0.299, 0.296, 0.301), "input_shape": (3, 32, 32), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vit_s-6300fcc9.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vit_s-69bc459e.weights.h5&src=0", }, "vit_b": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 32, 3), "classes": list(VOCABS["french"]), - "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vit_b-57158446.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vit_b-c64705bd.weights.h5&src=0", }, } @@ -123,7 +123,11 @@ def _vit( model = VisionTransformer(cfg=_cfg, **kwargs) # Load pretrained parameters if pretrained: - load_pretrained_params(model, default_cfgs[arch]["url"]) + # The number of classes is not the same as the number of classes in the pretrained model => + # skip the mismatching layers for fine tuning + load_pretrained_params( + model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) + ) return model diff --git a/doctr/models/detection/differentiable_binarization/tensorflow.py b/doctr/models/detection/differentiable_binarization/tensorflow.py index df9935b042..7fdbd43ce0 100644 --- a/doctr/models/detection/differentiable_binarization/tensorflow.py +++ b/doctr/models/detection/differentiable_binarization/tensorflow.py @@ -10,9 +10,8 @@ import numpy as np import tensorflow as tf -from tensorflow import keras -from tensorflow.keras import layers -from tensorflow.keras.applications import ResNet50 +from keras import Model, Sequential, layers, losses +from keras.applications import ResNet50 from doctr.file_utils import CLASS_NAME from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, conv_sequence, load_pretrained_params @@ -29,13 +28,13 @@ "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), "input_shape": (1024, 1024, 3), - "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_resnet50-84171458.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/db_resnet50-649fa22b.weights.h5&src=0", }, "db_mobilenet_v3_large": { "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), "input_shape": (1024, 1024, 3), - "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_mobilenet_v3_large-da524564.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/db_mobilenet_v3_large-ee2e1dbe.weights.h5&src=0", }, } @@ -81,7 +80,7 @@ def build_upsampling( if dilation_factor > 1: _layers.append(layers.UpSampling2D(size=(dilation_factor, dilation_factor), interpolation="nearest")) - module = keras.Sequential(_layers) + module = Sequential(_layers) return module @@ -104,7 +103,7 @@ def call( return layers.concatenate(results) -class DBNet(_DBNet, keras.Model, NestedObject): +class DBNet(_DBNet, Model, NestedObject): """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization" `_. @@ -147,14 +146,14 @@ def __init__( _inputs = [layers.Input(shape=in_shape[1:]) for in_shape in self.feat_extractor.output_shape] output_shape = tuple(self.fpn(_inputs).shape) - self.probability_head = keras.Sequential([ + self.probability_head = Sequential([ *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]), layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"), layers.BatchNormalization(), layers.Activation("relu"), layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer="he_normal"), ]) - self.threshold_head = keras.Sequential([ + self.threshold_head = Sequential([ *conv_sequence(64, "relu", True, kernel_size=3, input_shape=output_shape[1:]), layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer="he_normal"), layers.BatchNormalization(), @@ -206,7 +205,7 @@ def compute_loss( # Focal loss focal_scale = 10.0 - bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True) + bce_loss = losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True) # Convert logits to prob, compute gamma factor p_t = (seg_target * prob_map) + ((1 - seg_target) * (1 - prob_map)) @@ -307,7 +306,12 @@ def _db_resnet( model = DBNet(feat_extractor, cfg=_cfg, **kwargs) # Load pretrained parameters if pretrained: - load_pretrained_params(model, _cfg["url"]) + # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning + load_pretrained_params( + model, + _cfg["url"], + skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]), + ) return model @@ -326,6 +330,10 @@ def _db_mobilenet( # Patch the config _cfg = deepcopy(default_cfgs[arch]) _cfg["input_shape"] = input_shape or _cfg["input_shape"] + if not kwargs.get("class_names", None): + kwargs["class_names"] = default_cfgs[arch].get("class_names", [CLASS_NAME]) + else: + kwargs["class_names"] = sorted(kwargs["class_names"]) # Feature extractor feat_extractor = IntermediateLayerGetter( @@ -341,7 +349,12 @@ def _db_mobilenet( model = DBNet(feat_extractor, cfg=_cfg, **kwargs) # Load pretrained parameters if pretrained: - load_pretrained_params(model, _cfg["url"]) + # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning + load_pretrained_params( + model, + _cfg["url"], + skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]), + ) return model diff --git a/doctr/models/detection/fast/tensorflow.py b/doctr/models/detection/fast/tensorflow.py index 69998a2303..80fc31fea3 100644 --- a/doctr/models/detection/fast/tensorflow.py +++ b/doctr/models/detection/fast/tensorflow.py @@ -10,8 +10,7 @@ import numpy as np import tensorflow as tf -from tensorflow import keras -from tensorflow.keras import Sequential, layers +from keras import Model, Sequential, layers from doctr.file_utils import CLASS_NAME from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, load_pretrained_params @@ -29,19 +28,19 @@ "input_shape": (1024, 1024, 3), "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), - "url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_tiny-959daecb.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_tiny-d7379d7b.weights.h5&src=0", }, "fast_small": { "input_shape": (1024, 1024, 3), "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), - "url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_small-f1617503.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_small-44b27eb6.weights.h5&src=0", }, "fast_base": { "input_shape": (1024, 1024, 3), "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), - "url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_base-255e2ac3.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_base-f2c6c736.weights.h5&src=0", }, } @@ -100,7 +99,7 @@ def __init__( super().__init__(_layers) -class FAST(_FAST, keras.Model, NestedObject): +class FAST(_FAST, Model, NestedObject): """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation" `_. @@ -336,7 +335,12 @@ def _fast( model = FAST(feat_extractor, cfg=_cfg, **kwargs) # Load pretrained parameters if pretrained: - load_pretrained_params(model, _cfg["url"]) + # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning + load_pretrained_params( + model, + _cfg["url"], + skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]), + ) # Build the model for reparameterization to access the layers _ = model(tf.random.uniform(shape=[1, *_cfg["input_shape"]], maxval=1, dtype=tf.float32), training=False) diff --git a/doctr/models/detection/linknet/tensorflow.py b/doctr/models/detection/linknet/tensorflow.py index ff11dbe477..683c49373a 100644 --- a/doctr/models/detection/linknet/tensorflow.py +++ b/doctr/models/detection/linknet/tensorflow.py @@ -10,8 +10,7 @@ import numpy as np import tensorflow as tf -from tensorflow import keras -from tensorflow.keras import Model, Sequential, layers +from keras import Model, Sequential, layers, losses from doctr.file_utils import CLASS_NAME from doctr.models.classification import resnet18, resnet34, resnet50 @@ -27,19 +26,19 @@ "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), "input_shape": (1024, 1024, 3), - "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet18-b9ee56e6.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet18-615a82c5.weights.h5&src=0", }, "linknet_resnet34": { "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), "input_shape": (1024, 1024, 3), - "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet34-51909c56.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet34-9d772be5.weights.h5&src=0", }, "linknet_resnet50": { "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), "input_shape": (1024, 1024, 3), - "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet50-ac9f3829.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet50-6bf6c8b5.weights.h5&src=0", }, } @@ -90,7 +89,7 @@ def extra_repr(self) -> str: return f"out_chans={self.out_chans}" -class LinkNet(_LinkNet, keras.Model): +class LinkNet(_LinkNet, Model): """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation" `_. @@ -187,7 +186,7 @@ def compute_loss( seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool) seg_mask = tf.cast(seg_mask, tf.float32) - bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True) + bce_loss = losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True) proba_map = tf.sigmoid(out_map) # Focal loss @@ -277,7 +276,12 @@ def _linknet( model = LinkNet(feat_extractor, cfg=_cfg, **kwargs) # Load pretrained parameters if pretrained: - load_pretrained_params(model, _cfg["url"]) + # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning + load_pretrained_params( + model, + _cfg["url"], + skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]), + ) return model diff --git a/doctr/models/detection/predictor/tensorflow.py b/doctr/models/detection/predictor/tensorflow.py index 14f38172df..a7ccd4a9ac 100644 --- a/doctr/models/detection/predictor/tensorflow.py +++ b/doctr/models/detection/predictor/tensorflow.py @@ -7,7 +7,7 @@ import numpy as np import tensorflow as tf -from tensorflow import keras +from keras import Model from doctr.models.detection._utils import _remove_padding from doctr.models.preprocessor import PreProcessor @@ -30,7 +30,7 @@ class DetectionPredictor(NestedObject): def __init__( self, pre_processor: PreProcessor, - model: keras.Model, + model: Model, ) -> None: self.pre_processor = pre_processor self.model = model diff --git a/doctr/models/factory/hub.py b/doctr/models/factory/hub.py index 41cd91579a..b5844dd30b 100644 --- a/doctr/models/factory/hub.py +++ b/doctr/models/factory/hub.py @@ -20,7 +20,6 @@ get_token_permission, hf_hub_download, login, - snapshot_download, ) from doctr import models @@ -28,6 +27,8 @@ if is_torch_available(): import torch +elif is_tf_available(): + import tensorflow as tf __all__ = ["login_to_hub", "push_to_hf_hub", "from_hub", "_save_model_and_config_for_hf_hub"] @@ -74,7 +75,9 @@ def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task weights_path = save_directory / "pytorch_model.bin" torch.save(model.state_dict(), weights_path) elif is_tf_available(): - weights_path = save_directory / "tf_model" / "weights" + weights_path = save_directory / "tf_model.weights.h5" + # NOTE: `model.build` is not an option because it doesn't runs in eager mode + _ = model(tf.ones((1, *model.cfg["input_shape"])), training=False) model.save_weights(str(weights_path)) config_path = save_directory / "config.json" @@ -225,7 +228,9 @@ def from_hub(repo_id: str, **kwargs: Any): state_dict = torch.load(hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs), map_location="cpu") model.load_state_dict(state_dict) else: # tf - repo_path = snapshot_download(repo_id, **kwargs) - model.load_weights(os.path.join(repo_path, "tf_model", "weights")) + weights = hf_hub_download(repo_id, filename="tf_model.weights.h5", **kwargs) + # NOTE: `model.build` is not an option because it doesn't runs in eager mode + _ = model(tf.ones((1, *model.cfg["input_shape"])), training=False) + model.load_weights(weights) return model diff --git a/doctr/models/modules/layers/tensorflow.py b/doctr/models/modules/layers/tensorflow.py index 68849fbf6e..b1019be778 100644 --- a/doctr/models/modules/layers/tensorflow.py +++ b/doctr/models/modules/layers/tensorflow.py @@ -7,7 +7,7 @@ import numpy as np import tensorflow as tf -from tensorflow.keras import layers +from keras import layers from doctr.utils.repr import NestedObject diff --git a/doctr/models/modules/transformer/tensorflow.py b/doctr/models/modules/transformer/tensorflow.py index 403f99117d..eef4f3dbea 100644 --- a/doctr/models/modules/transformer/tensorflow.py +++ b/doctr/models/modules/transformer/tensorflow.py @@ -7,7 +7,7 @@ from typing import Any, Callable, Optional, Tuple import tensorflow as tf -from tensorflow.keras import layers +from keras import layers from doctr.utils.repr import NestedObject diff --git a/doctr/models/modules/vision_transformer/tensorflow.py b/doctr/models/modules/vision_transformer/tensorflow.py index 8386172eb1..a73aa4c706 100644 --- a/doctr/models/modules/vision_transformer/tensorflow.py +++ b/doctr/models/modules/vision_transformer/tensorflow.py @@ -7,7 +7,7 @@ from typing import Any, Tuple import tensorflow as tf -from tensorflow.keras import layers +from keras import layers from doctr.utils.repr import NestedObject diff --git a/doctr/models/recognition/crnn/tensorflow.py b/doctr/models/recognition/crnn/tensorflow.py index 5ec48c4f0e..d366bfc14b 100644 --- a/doctr/models/recognition/crnn/tensorflow.py +++ b/doctr/models/recognition/crnn/tensorflow.py @@ -7,8 +7,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union import tensorflow as tf -from tensorflow.keras import layers -from tensorflow.keras.models import Model, Sequential +from keras import layers +from keras.models import Model, Sequential from doctr.datasets import VOCABS @@ -24,21 +24,21 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["legacy_french"], - "url": "https://doctr-static.mindee.com/models?id=v0.3.0/crnn_vgg16_bn-76b7f2c6.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_vgg16_bn-9c188f45.weights.h5&src=0", }, "crnn_mobilenet_v3_small": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": "https://doctr-static.mindee.com/models?id=v0.3.1/crnn_mobilenet_v3_small-7f36edec.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_mobilenet_v3_small-54850265.weights.h5&src=0", }, "crnn_mobilenet_v3_large": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": "https://doctr-static.mindee.com/models?id=v0.6.0/crnn_mobilenet_v3_large-cccc50b1.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_mobilenet_v3_large-c64045e5.weights.h5&src=0", }, } @@ -128,7 +128,7 @@ class CRNN(RecognitionModel, Model): def __init__( self, - feature_extractor: tf.keras.Model, + feature_extractor: Model, vocab: str, rnn_units: int = 128, exportable: bool = False, @@ -247,7 +247,8 @@ def _crnn( model = CRNN(feat_extractor, cfg=_cfg, **kwargs) # Load pretrained parameters if pretrained: - load_pretrained_params(model, _cfg["url"]) + # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning + load_pretrained_params(model, _cfg["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]) return model diff --git a/doctr/models/recognition/master/tensorflow.py b/doctr/models/recognition/master/tensorflow.py index a3ecadcc15..5b8192dee6 100644 --- a/doctr/models/recognition/master/tensorflow.py +++ b/doctr/models/recognition/master/tensorflow.py @@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional, Tuple import tensorflow as tf -from tensorflow.keras import Model, layers +from keras import Model, layers from doctr.datasets import VOCABS from doctr.models.classification import magc_resnet31 @@ -25,7 +25,7 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": "https://doctr-static.mindee.com/models?id=v0.6.0/master-a8232e9f.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/master-d7fdaeff.weights.h5&src=0", }, } @@ -51,7 +51,7 @@ class MASTER(_MASTER, Model): def __init__( self, - feature_extractor: tf.keras.Model, + feature_extractor: Model, vocab: str, d_model: int = 512, dff: int = 2048, @@ -292,7 +292,10 @@ def _master(arch: str, pretrained: bool, backbone_fn, pretrained_backbone: bool ) # Load pretrained parameters if pretrained: - load_pretrained_params(model, default_cfgs[arch]["url"]) + # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning + load_pretrained_params( + model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"] + ) return model diff --git a/doctr/models/recognition/parseq/tensorflow.py b/doctr/models/recognition/parseq/tensorflow.py index 1365a6ac12..bca7806903 100644 --- a/doctr/models/recognition/parseq/tensorflow.py +++ b/doctr/models/recognition/parseq/tensorflow.py @@ -10,7 +10,7 @@ import numpy as np import tensorflow as tf -from tensorflow.keras import Model, layers +from keras import Model, layers from doctr.datasets import VOCABS from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward @@ -27,7 +27,7 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": "https://doctr-static.mindee.com/models?id=v0.6.0/parseq-24cf693e.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/parseq-4152a87e.weights.h5&src=0", }, } @@ -43,7 +43,7 @@ class CharEmbedding(layers.Layer): def __init__(self, vocab_size: int, d_model: int): super(CharEmbedding, self).__init__() - self.embedding = tf.keras.layers.Embedding(vocab_size, d_model) + self.embedding = layers.Embedding(vocab_size, d_model) self.d_model = d_model def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor: @@ -238,7 +238,7 @@ def generate_permutations_attention_masks(self, permutation: tf.Tensor) -> Tuple def decode( self, target: tf.Tensor, - memory: tf, + memory: tf.Tensor, target_mask: Optional[tf.Tensor] = None, target_query: Optional[tf.Tensor] = None, **kwargs: Any, @@ -478,7 +478,10 @@ def _parseq( model = PARSeq(feat_extractor, cfg=_cfg, **kwargs) # Load pretrained parameters if pretrained: - load_pretrained_params(model, default_cfgs[arch]["url"]) + # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning + load_pretrained_params( + model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"] + ) return model diff --git a/doctr/models/recognition/sar/tensorflow.py b/doctr/models/recognition/sar/tensorflow.py index e5e557c232..0776414c7a 100644 --- a/doctr/models/recognition/sar/tensorflow.py +++ b/doctr/models/recognition/sar/tensorflow.py @@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional, Tuple import tensorflow as tf -from tensorflow.keras import Model, Sequential, layers +from keras import Model, Sequential, layers from doctr.datasets import VOCABS from doctr.utils.repr import NestedObject @@ -24,7 +24,7 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": "https://doctr-static.mindee.com/models?id=v0.6.0/sar_resnet31-c41e32a5.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/sar_resnet31-5a58806c.weights.h5&src=0", }, } @@ -394,7 +394,10 @@ def _sar( model = SAR(feat_extractor, cfg=_cfg, **kwargs) # Load pretrained parameters if pretrained: - load_pretrained_params(model, default_cfgs[arch]["url"]) + # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning + load_pretrained_params( + model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"] + ) return model diff --git a/doctr/models/recognition/vitstr/tensorflow.py b/doctr/models/recognition/vitstr/tensorflow.py index 9c5359dde2..985f49a470 100644 --- a/doctr/models/recognition/vitstr/tensorflow.py +++ b/doctr/models/recognition/vitstr/tensorflow.py @@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional, Tuple import tensorflow as tf -from tensorflow.keras import Model, layers +from keras import Model, layers from doctr.datasets import VOCABS @@ -23,14 +23,14 @@ "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vitstr_small-358fab2e.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vitstr_small-d28b8d92.weights.h5&src=0", }, "vitstr_base": { "mean": (0.694, 0.695, 0.693), "std": (0.299, 0.296, 0.301), "input_shape": (32, 128, 3), "vocab": VOCABS["french"], - "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vitstr_base-2889159a.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vitstr_base-9ad6eb84.weights.h5&src=0", }, } @@ -218,7 +218,10 @@ def _vitstr( model = ViTSTR(feat_extractor, cfg=_cfg, **kwargs) # Load pretrained parameters if pretrained: - load_pretrained_params(model, default_cfgs[arch]["url"]) + # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning + load_pretrained_params( + model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"] + ) return model diff --git a/doctr/models/utils/tensorflow.py b/doctr/models/utils/tensorflow.py index 4c6f02c2a3..51a2bc69a5 100644 --- a/doctr/models/utils/tensorflow.py +++ b/doctr/models/utils/tensorflow.py @@ -4,13 +4,11 @@ # See LICENSE or go to for full license details. import logging -import os from typing import Any, Callable, List, Optional, Tuple, Union -from zipfile import ZipFile import tensorflow as tf import tf2onnx -from tensorflow.keras import Model, layers +from keras import Model, layers from doctr.utils.data import download_from_url @@ -40,22 +38,20 @@ def load_pretrained_params( model: Model, url: Optional[str] = None, hash_prefix: Optional[str] = None, - overwrite: bool = False, - internal_name: str = "weights", + skip_mismatch: bool = False, **kwargs: Any, ) -> None: """Load a set of parameters onto a model >>> from doctr.models import load_pretrained_params - >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip") + >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.weights.h5") Args: ---- model: the keras model to be loaded url: URL of the zipped set of parameters hash_prefix: first characters of SHA256 expected hash - overwrite: should the zip extraction be enforced if the archive has already been extracted - internal_name: name of the ckpt files + skip_mismatch: skip loading layers with mismatched shapes **kwargs: additional arguments to be passed to `doctr.utils.data.download_from_url` """ if url is None: @@ -63,14 +59,12 @@ def load_pretrained_params( else: archive_path = download_from_url(url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs) - # Unzip the archive - params_path = archive_path.parent.joinpath(archive_path.stem) - if not params_path.is_dir() or overwrite: - with ZipFile(archive_path, "r") as f: - f.extractall(path=params_path) + # Build the model + # NOTE: `model.build` is not an option because it doesn't runs in eager mode + _ = model(tf.ones((1, *model.cfg["input_shape"])), training=False) # Load weights - model.load_weights(f"{params_path}{os.sep}{internal_name}") + model.load_weights(archive_path, skip_mismatch=skip_mismatch) def conv_sequence( @@ -83,7 +77,7 @@ def conv_sequence( ) -> List[layers.Layer]: """Builds a convolutional-based layer sequence - >>> from tensorflow.keras import Sequential + >>> from keras import Sequential >>> from doctr.models import conv_sequence >>> module = Sequential(conv_sequence(32, 'relu', True, kernel_size=3, input_shape=[224, 224, 3])) @@ -119,7 +113,7 @@ def conv_sequence( class IntermediateLayerGetter(Model): """Implements an intermediate layer getter - >>> from tensorflow.keras.applications import ResNet50 + >>> from keras.applications import ResNet50 >>> from doctr.models import IntermediateLayerGetter >>> target_layers = ["conv2_block3_out", "conv3_block4_out", "conv4_block6_out", "conv5_block3_out"] >>> feat_extractor = IntermediateLayerGetter(ResNet50(include_top=False, pooling=False), target_layers) diff --git a/pyproject.toml b/pyproject.toml index c0b209f535..aa0e02f98e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,8 @@ tf = [ # cf. https://github.com/mindee/doctr/pull/1461 "tensorflow>=2.11.0,<2.16.0", "tf2onnx>=1.16.0,<2.0.0", # cf. https://github.com/onnx/tensorflow-onnx/releases/tag/v1.16.0 + # TODO: This is a temporary fix until we can upgrade to a newer version of tensorflow + "numpy>=1.16.0,<2.0.0", ] torch = [ "torch>=1.12.0,<3.0.0", @@ -158,6 +160,7 @@ implicit_reexport = false module = [ "anyascii.*", "tensorflow.*", + "keras.*", "torchvision.*", "onnxruntime.*", "PIL.*", @@ -195,7 +198,7 @@ ignore = ["E402", "E203", "F403", "E731", "N812", "N817", "C408"] [tool.ruff.lint.isort] known-first-party = ["doctr", "app", "utils"] -known-third-party = ["tensorflow", "torch", "torchvision", "wandb", "tqdm", "fastapi", "onnxruntime", "cv2"] +known-third-party = ["tensorflow", "keras", "torch", "torchvision", "wandb", "tqdm", "fastapi", "onnxruntime", "cv2"] [tool.ruff.lint.per-file-ignores] "doctr/models/**.py" = ["N806", "F841"] diff --git a/references/classification/train_tensorflow_character.py b/references/classification/train_tensorflow_character.py index 580cf6fb1b..b2d24f2dbf 100644 --- a/references/classification/train_tensorflow_character.py +++ b/references/classification/train_tensorflow_character.py @@ -13,7 +13,7 @@ import numpy as np import tensorflow as tf -from tensorflow.keras import mixed_precision +from keras import Model, mixed_precision, optimizers from tqdm.auto import tqdm from doctr.models import login_to_hub, push_to_hf_hub @@ -30,7 +30,7 @@ def record_lr( - model: tf.keras.Model, + model: Model, train_loader: DataLoader, batch_transforms, optimizer, @@ -176,6 +176,8 @@ def main(args): # Resume weights if isinstance(args.resume, str): + # Build the model first to load the weights + _ = model(tf.zeros((1, args.input_size, args.input_size, 3)), training=False) model.load_weights(args.resume) batch_transforms = T.Compose([ @@ -227,14 +229,14 @@ def main(args): return # Optimizer - scheduler = tf.keras.optimizers.schedules.ExponentialDecay( + scheduler = optimizers.schedules.ExponentialDecay( args.lr, decay_steps=args.epochs * len(train_loader), decay_rate=1 / (1e3), # final lr as a fraction of initial lr staircase=False, name="ExponentialDecay", ) - optimizer = tf.keras.optimizers.Adam( + optimizer = optimizers.Adam( learning_rate=scheduler, beta_1=0.95, beta_2=0.99, @@ -291,7 +293,7 @@ def main(args): val_loss, acc = evaluate(model, val_loader, batch_transforms) if val_loss < min_loss: print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...") - model.save_weights(f"./{exp_name}/weights") + model.save_weights(f"./{exp_name}.weights.h5") min_loss = val_loss print(f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} (Acc: {acc:.2%})") # W&B diff --git a/references/classification/train_tensorflow_orientation.py b/references/classification/train_tensorflow_orientation.py index ad25713df7..e063174944 100644 --- a/references/classification/train_tensorflow_orientation.py +++ b/references/classification/train_tensorflow_orientation.py @@ -13,7 +13,7 @@ import numpy as np import tensorflow as tf -from tensorflow.keras import mixed_precision +from keras import Model, mixed_precision, optimizers from tqdm.auto import tqdm from doctr.models import login_to_hub, push_to_hf_hub @@ -44,7 +44,7 @@ def rnd_rotate(img: tf.Tensor, target): def record_lr( - model: tf.keras.Model, + model: Model, train_loader: DataLoader, batch_transforms, optimizer, @@ -187,6 +187,8 @@ def main(args): # Resume weights if isinstance(args.resume, str): + # Build the model first to load the weights + _ = model(tf.zeros((1, *input_size, 3)), training=False) model.load_weights(args.resume) batch_transforms = T.Compose([ @@ -237,14 +239,14 @@ def main(args): return # Optimizer - scheduler = tf.keras.optimizers.schedules.ExponentialDecay( + scheduler = optimizers.schedules.ExponentialDecay( args.lr, decay_steps=args.epochs * len(train_loader), decay_rate=1 / (1e3), # final lr as a fraction of initial lr staircase=False, name="ExponentialDecay", ) - optimizer = tf.keras.optimizers.Adam( + optimizer = optimizers.Adam( learning_rate=scheduler, beta_1=0.95, beta_2=0.99, @@ -301,7 +303,7 @@ def main(args): val_loss, acc = evaluate(model, val_loader, batch_transforms) if val_loss < min_loss: print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...") - model.save_weights(f"./{exp_name}/weights") + model.save_weights(f"./{exp_name}.weights.h5") min_loss = val_loss print(f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} (Acc: {acc:.2%})") # W&B diff --git a/references/detection/evaluate_tensorflow.py b/references/detection/evaluate_tensorflow.py index 139932f2c4..abf012ed83 100644 --- a/references/detection/evaluate_tensorflow.py +++ b/references/detection/evaluate_tensorflow.py @@ -14,7 +14,7 @@ from pathlib import Path import tensorflow as tf -from tensorflow.keras import mixed_precision +from keras import mixed_precision from tqdm import tqdm gpu_devices = tf.config.experimental.list_physical_devices("GPU") diff --git a/references/detection/train_tensorflow.py b/references/detection/train_tensorflow.py index 1312a6ea13..b9c14494ad 100644 --- a/references/detection/train_tensorflow.py +++ b/references/detection/train_tensorflow.py @@ -14,7 +14,7 @@ import numpy as np import tensorflow as tf -from tensorflow.keras import mixed_precision +from keras import Model, mixed_precision, optimizers from tqdm.auto import tqdm from doctr.models import login_to_hub, push_to_hf_hub @@ -31,7 +31,7 @@ def record_lr( - model: tf.keras.Model, + model: Model, train_loader: DataLoader, batch_transforms, optimizer, @@ -58,7 +58,7 @@ def record_lr( # Forward, Backward & update with tf.GradientTape() as tape: - train_loss = model(images, targets, training=True)["loss"] + train_loss = model(images, target=targets, training=True)["loss"] grads = tape.gradient(train_loss, model.trainable_weights) if amp: @@ -90,7 +90,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, amp=False): images = batch_transforms(images) with tf.GradientTape() as tape: - train_loss = model(images, targets, training=True)["loss"] + train_loss = model(images, target=targets, training=True)["loss"] grads = tape.gradient(train_loss, model.trainable_weights) if amp: grads = optimizer.get_unscaled_gradients(grads) @@ -107,7 +107,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric): val_iter = iter(val_loader) for images, targets in tqdm(val_iter): images = batch_transforms(images) - out = model(images, targets, training=False, return_preds=True) + out = model(images, target=targets, training=False, return_preds=True) # Compute metric loc_preds = out["preds"] for target, loc_pred in zip(targets, loc_preds): @@ -184,6 +184,8 @@ def main(args): # Resume weights if isinstance(args.resume, str): + # Build the model first to load the weights + _ = model(tf.zeros((1, args.input_size, args.input_size, 3)), training=False) model.load_weights(args.resume) if isinstance(args.pretrained_backbone, str): @@ -278,7 +280,7 @@ def main(args): # Scheduler if args.sched == "exponential": - scheduler = tf.keras.optimizers.schedules.ExponentialDecay( + scheduler = optimizers.schedules.ExponentialDecay( args.lr, decay_steps=args.epochs * len(train_loader), decay_rate=1 / (25e4), # final lr as a fraction of initial lr @@ -286,7 +288,7 @@ def main(args): name="ExponentialDecay", ) elif args.sched == "poly": - scheduler = tf.keras.optimizers.schedules.PolynomialDecay( + scheduler = optimizers.schedules.PolynomialDecay( args.lr, decay_steps=args.epochs * len(train_loader), end_learning_rate=1e-7, @@ -295,7 +297,7 @@ def main(args): name="PolynomialDecay", ) # Optimizer - optimizer = tf.keras.optimizers.Adam(learning_rate=scheduler, beta_1=0.95, beta_2=0.99, epsilon=1e-6, clipnorm=5) + optimizer = optimizers.Adam(learning_rate=scheduler, beta_1=0.95, beta_2=0.99, epsilon=1e-6, clipnorm=5) if args.amp: optimizer = mixed_precision.LossScaleOptimizer(optimizer) # LR Finder @@ -351,11 +353,11 @@ def main(args): val_loss, recall, precision, mean_iou = evaluate(model, val_loader, batch_transforms, val_metric) if val_loss < min_loss: print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...") - model.save_weights(f"./{exp_name}/weights") + model.save_weights(f"./{exp_name}.weights.h5") min_loss = val_loss if args.save_interval_epoch: print(f"Saving state at epoch: {epoch + 1}") - model.save_weights(f"./{exp_name}_{epoch + 1}/weights") + model.save_weights(f"./{exp_name}_{epoch + 1}.weights.h5") log_msg = f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} " if any(val is None for val in (recall, precision, mean_iou)): log_msg += "(Undefined metric value, caused by empty GTs or predictions)" diff --git a/references/recognition/evaluate_tensorflow.py b/references/recognition/evaluate_tensorflow.py index 62651245c4..4c9d125285 100644 --- a/references/recognition/evaluate_tensorflow.py +++ b/references/recognition/evaluate_tensorflow.py @@ -11,7 +11,7 @@ import time import tensorflow as tf -from tensorflow.keras import mixed_precision +from keras import mixed_precision from tqdm import tqdm gpu_devices = tf.config.experimental.list_physical_devices("GPU") diff --git a/references/recognition/train_tensorflow.py b/references/recognition/train_tensorflow.py index 7f55142859..c76355a2f2 100644 --- a/references/recognition/train_tensorflow.py +++ b/references/recognition/train_tensorflow.py @@ -15,7 +15,7 @@ import numpy as np import tensorflow as tf -from tensorflow.keras import mixed_precision +from keras import Model, mixed_precision, optimizers from tqdm.auto import tqdm from doctr.models import login_to_hub, push_to_hf_hub @@ -32,7 +32,7 @@ def record_lr( - model: tf.keras.Model, + model: Model, train_loader: DataLoader, batch_transforms, optimizer, @@ -59,7 +59,7 @@ def record_lr( # Forward, Backward & update with tf.GradientTape() as tape: - train_loss = model(images, targets, training=True)["loss"] + train_loss = model(images, target=targets, training=True)["loss"] grads = tape.gradient(train_loss, model.trainable_weights) if amp: @@ -91,7 +91,7 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, amp=False): images = batch_transforms(images) with tf.GradientTape() as tape: - train_loss = model(images, targets, training=True)["loss"] + train_loss = model(images, target=targets, training=True)["loss"] grads = tape.gradient(train_loss, model.trainable_weights) if amp: grads = optimizer.get_unscaled_gradients(grads) @@ -108,7 +108,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric): val_iter = iter(val_loader) for images, targets in tqdm(val_iter): images = batch_transforms(images) - out = model(images, targets, return_preds=True, training=False) + out = model(images, target=targets, return_preds=True, training=False) # Compute metric if len(out["preds"]): words, _ = zip(*out["preds"]) @@ -184,6 +184,8 @@ def main(args): ) # Resume weights if isinstance(args.resume, str): + # Build the model first to load the weights + _ = model(tf.zeros((1, args.input_size, 4 * args.input_size, 3)), training=False) model.load_weights(args.resume) # Metrics @@ -275,14 +277,14 @@ def main(args): return # Optimizer - scheduler = tf.keras.optimizers.schedules.ExponentialDecay( + scheduler = optimizers.schedules.ExponentialDecay( args.lr, decay_steps=args.epochs * len(train_loader), decay_rate=1 / (25e4), # final lr as a fraction of initial lr staircase=False, name="ExponentialDecay", ) - optimizer = tf.keras.optimizers.Adam(learning_rate=scheduler, beta_1=0.95, beta_2=0.99, epsilon=1e-6, clipnorm=5) + optimizer = optimizers.Adam(learning_rate=scheduler, beta_1=0.95, beta_2=0.99, epsilon=1e-6, clipnorm=5) if args.amp: optimizer = mixed_precision.LossScaleOptimizer(optimizer) # LR Finder @@ -343,7 +345,7 @@ def main(args): val_loss, exact_match, partial_match = evaluate(model, val_loader, batch_transforms, val_metric) if val_loss < min_loss: print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...") - model.save_weights(f"./{exp_name}/weights") + model.save_weights(f"./{exp_name}.weights.h5") min_loss = val_loss print( f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} " diff --git a/tests/pytorch/test_models_classification_pt.py b/tests/pytorch/test_models_classification_pt.py index 4c0b571da9..b3d25af173 100644 --- a/tests/pytorch/test_models_classification_pt.py +++ b/tests/pytorch/test_models_classification_pt.py @@ -54,7 +54,7 @@ def test_classification_architectures(arch_name, input_shape, output_size): model = classification.__dict__[arch_name](pretrained=True).eval() _test_classification(model, input_shape, output_size) # Check that you can pretrained everything up until the last layer - classification.__dict__[arch_name](pretrained=True, num_classes=10) + assert classification.__dict__[arch_name](pretrained=True, num_classes=10) @pytest.mark.parametrize( diff --git a/tests/tensorflow/test_models_classification_tf.py b/tests/tensorflow/test_models_classification_tf.py index 731e4dbd8b..11f4ea4114 100644 --- a/tests/tensorflow/test_models_classification_tf.py +++ b/tests/tensorflow/test_models_classification_tf.py @@ -2,6 +2,7 @@ import tempfile import cv2 +import keras import numpy as np import onnxruntime import psutil @@ -37,7 +38,7 @@ def test_classification_architectures(arch_name, input_shape, output_size): # Model batch_size = 2 - tf.keras.backend.clear_session() + keras.backend.clear_session() model = classification.__dict__[arch_name](pretrained=True, include_top=True, input_shape=input_shape) # Forward out = model(tf.random.uniform(shape=[batch_size, *input_shape], maxval=1, dtype=tf.float32)) @@ -45,6 +46,11 @@ def test_classification_architectures(arch_name, input_shape, output_size): assert isinstance(out, tf.Tensor) assert out.dtype == tf.float32 assert out.numpy().shape == (batch_size, *output_size) + # Check that you can load pretrained up to the classification layer with differing number of classes to fine-tune + keras.backend.clear_session() + assert classification.__dict__[arch_name]( + pretrained=True, include_top=True, input_shape=input_shape, num_classes=10 + ) @pytest.mark.parametrize( @@ -57,7 +63,7 @@ def test_classification_architectures(arch_name, input_shape, output_size): def test_classification_models(arch_name, input_shape): batch_size = 8 reco_model = classification.__dict__[arch_name](pretrained=True, input_shape=input_shape) - assert isinstance(reco_model, tf.keras.Model) + assert isinstance(reco_model, keras.Model) input_tensor = tf.random.uniform(shape=[batch_size, *input_shape], minval=0, maxval=1) out = reco_model(input_tensor) @@ -226,7 +232,7 @@ def test_page_orientation_model(mock_payslip): def test_models_onnx_export(arch_name, input_shape, output_size): # Model batch_size = 2 - tf.keras.backend.clear_session() + keras.backend.clear_session() if "orientation" in arch_name: model = classification.__dict__[arch_name](pretrained=True, input_shape=input_shape) else: diff --git a/tests/tensorflow/test_models_detection_tf.py b/tests/tensorflow/test_models_detection_tf.py index 2e627b9e4d..ba5f50542b 100644 --- a/tests/tensorflow/test_models_detection_tf.py +++ b/tests/tensorflow/test_models_detection_tf.py @@ -2,6 +2,7 @@ import os import tempfile +import keras import numpy as np import onnxruntime import psutil @@ -37,13 +38,13 @@ ) def test_detection_models(arch_name, input_shape, output_size, out_prob, train_mode): batch_size = 2 - tf.keras.backend.clear_session() + keras.backend.clear_session() if arch_name == "fast_tiny_rep": model = reparameterize(detection.fast_tiny(pretrained=True, input_shape=input_shape)) train_mode = False # Reparameterized model is not trainable else: model = detection.__dict__[arch_name](pretrained=True, input_shape=input_shape) - assert isinstance(model, tf.keras.Model) + assert isinstance(model, keras.Model) input_tensor = tf.random.uniform(shape=[batch_size, *input_shape], minval=0, maxval=1) target = [ {CLASS_NAME: np.array([[0.5, 0.5, 1, 1], [0.5, 0.5, 0.8, 0.8]], dtype=np.float32)}, @@ -152,7 +153,7 @@ def test_rotated_detectionpredictor(mock_pdf): ) def test_detection_zoo(arch_name): # Model - tf.keras.backend.clear_session() + keras.backend.clear_session() predictor = detection.zoo.detection_predictor(arch_name, pretrained=False) # object check assert isinstance(predictor, DetectionPredictor) @@ -177,7 +178,7 @@ def test_fast_reparameterization(): base_model_params = np.sum([np.prod(v.shape) for v in base_model.trainable_variables]) assert math.isclose(base_model_params, 13535296) # base model params base_out = base_model(dummy_input, training=False)["logits"] - tf.keras.backend.clear_session() + keras.backend.clear_session() rep_model = reparameterize(base_model) rep_model_params = np.sum([np.prod(v.shape) for v in base_model.trainable_variables]) assert math.isclose(rep_model_params, 8520256) # reparameterized model params @@ -241,7 +242,7 @@ def test_dilate(): def test_models_onnx_export(arch_name, input_shape, output_size): # Model batch_size = 2 - tf.keras.backend.clear_session() + keras.backend.clear_session() if arch_name == "fast_tiny_rep": model = reparameterize(detection.fast_tiny(pretrained=True, exportable=True, input_shape=input_shape)) else: diff --git a/tests/tensorflow/test_models_factory.py b/tests/tensorflow/test_models_factory.py index 9b1ad2e166..0860d8612c 100644 --- a/tests/tensorflow/test_models_factory.py +++ b/tests/tensorflow/test_models_factory.py @@ -2,8 +2,8 @@ import os import tempfile +import keras import pytest -import tensorflow as tf from doctr import models from doctr.models.factory import _save_model_and_config_for_hf_hub, from_hub, push_to_hf_hub @@ -25,40 +25,39 @@ def test_push_to_hf_hub(): @pytest.mark.parametrize( "arch_name, task_name, dummy_model_id", [ - ["vgg16_bn_r", "classification", "Felix92/doctr-dummy-tf-vgg16-bn-r"], - ["resnet18", "classification", "Felix92/doctr-dummy-tf-resnet18"], - ["resnet31", "classification", "Felix92/doctr-dummy-tf-resnet31"], - ["resnet34", "classification", "Felix92/doctr-dummy-tf-resnet34"], - ["resnet34_wide", "classification", "Felix92/doctr-dummy-tf-resnet34-wide"], - ["resnet50", "classification", "Felix92/doctr-dummy-tf-resnet50"], - ["magc_resnet31", "classification", "Felix92/doctr-dummy-tf-magc-resnet31"], - ["mobilenet_v3_large", "classification", "Felix92/doctr-dummy-tf-mobilenet-v3-large"], - ["vit_b", "classification", "Felix92/doctr-dummy-tf-vit-b"], - ["textnet_tiny", "classification", "Felix92/doctr-dummy-tf-textnet-tiny"], - ["db_resnet50", "detection", "Felix92/doctr-dummy-tf-db-resnet50"], - ["db_mobilenet_v3_large", "detection", "Felix92/doctr-dummy-tf-db-mobilenet-v3-large"], - ["linknet_resnet18", "detection", "Felix92/doctr-dummy-tf-linknet-resnet18"], - ["linknet_resnet34", "detection", "Felix92/doctr-dummy-tf-linknet-resnet34"], - ["linknet_resnet50", "detection", "Felix92/doctr-dummy-tf-linknet-resnet50"], - ["crnn_vgg16_bn", "recognition", "Felix92/doctr-dummy-tf-crnn-vgg16-bn"], - ["crnn_mobilenet_v3_large", "recognition", "Felix92/doctr-dummy-tf-crnn-mobilenet-v3-large"], - ["sar_resnet31", "recognition", "Felix92/doctr-dummy-tf-sar-resnet31"], - ["master", "recognition", "Felix92/doctr-dummy-tf-master"], - ["vitstr_small", "recognition", "Felix92/doctr-dummy-tf-vitstr-small"], - ["parseq", "recognition", "Felix92/doctr-dummy-tf-parseq"], + ["vgg16_bn_r", "classification", "Felix92/doctr-dummy-tf-vgg16-bn-r-v2"], + ["resnet18", "classification", "Felix92/doctr-dummy-tf-resnet18-v2"], + ["resnet31", "classification", "Felix92/doctr-dummy-tf-resnet31-v2"], + ["resnet34", "classification", "Felix92/doctr-dummy-tf-resnet34-v2"], + ["resnet34_wide", "classification", "Felix92/doctr-dummy-tf-resnet34-wide-v2"], + ["resnet50", "classification", "Felix92/doctr-dummy-tf-resnet50-v2"], + ["magc_resnet31", "classification", "Felix92/doctr-dummy-tf-magc-resnet31-v2"], + ["mobilenet_v3_large", "classification", "Felix92/doctr-dummy-tf-mobilenet-v3-large-v2"], + ["vit_b", "classification", "Felix92/doctr-dummy-tf-vit-b-v2"], + ["textnet_tiny", "classification", "Felix92/doctr-dummy-tf-textnet-tiny-v2"], + ["db_resnet50", "detection", "Felix92/doctr-dummy-tf-db-resnet50-v2"], + ["db_mobilenet_v3_large", "detection", "Felix92/doctr-dummy-tf-db-mobilenet-v3-large-v2"], + ["linknet_resnet18", "detection", "Felix92/doctr-dummy-tf-linknet-resnet18-v2"], + ["linknet_resnet50", "detection", "Felix92/doctr-dummy-tf-linknet-resnet50-v2"], + ["linknet_resnet34", "detection", "Felix92/doctr-dummy-tf-linknet-resnet34-v2"], + ["crnn_vgg16_bn", "recognition", "Felix92/doctr-dummy-tf-crnn-vgg16-bn-v2"], + ["crnn_mobilenet_v3_large", "recognition", "Felix92/doctr-dummy-tf-crnn-mobilenet-v3-large-v2"], + ["sar_resnet31", "recognition", "Felix92/doctr-dummy-tf-sar-resnet31-v2"], + ["master", "recognition", "Felix92/doctr-dummy-tf-master-v2"], + ["vitstr_small", "recognition", "Felix92/doctr-dummy-tf-vitstr-small-v2"], + ["parseq", "recognition", "Felix92/doctr-dummy-tf-parseq-v2"], ], ) def test_models_for_hub(arch_name, task_name, dummy_model_id, tmpdir): with tempfile.TemporaryDirectory() as tmp_dir: - tf.keras.backend.clear_session() + keras.backend.clear_session() model = models.__dict__[task_name].__dict__[arch_name](pretrained=True) _save_model_and_config_for_hf_hub(model, arch=arch_name, task=task_name, save_dir=tmp_dir) assert hasattr(model, "cfg") assert len(os.listdir(tmp_dir)) == 2 - assert os.path.exists(tmp_dir + "/tf_model") - assert len(os.listdir(tmp_dir + "/tf_model")) == 3 + assert os.path.exists(tmp_dir + "/tf_model.weights.h5") assert os.path.exists(tmp_dir + "/config.json") tmp_config = json.load(open(tmp_dir + "/config.json")) assert arch_name == tmp_config["arch"] @@ -66,6 +65,6 @@ def test_models_for_hub(arch_name, task_name, dummy_model_id, tmpdir): assert all(key in model.cfg.keys() for key in tmp_config.keys()) # test from hub - tf.keras.backend.clear_session() + keras.backend.clear_session() hub_model = from_hub(repo_id=dummy_model_id) assert isinstance(hub_model, type(model)) diff --git a/tests/tensorflow/test_models_recognition_tf.py b/tests/tensorflow/test_models_recognition_tf.py index b58272d1de..7da1cb534a 100644 --- a/tests/tensorflow/test_models_recognition_tf.py +++ b/tests/tensorflow/test_models_recognition_tf.py @@ -2,6 +2,7 @@ import shutil import tempfile +import keras import numpy as np import onnxruntime import psutil @@ -37,10 +38,10 @@ ["parseq", (32, 128, 3)], ], ) -def test_recognition_models(arch_name, input_shape, train_mode): +def test_recognition_models(arch_name, input_shape, train_mode, mock_vocab): batch_size = 4 - reco_model = recognition.__dict__[arch_name](pretrained=True, input_shape=input_shape) - assert isinstance(reco_model, tf.keras.Model) + reco_model = recognition.__dict__[arch_name](vocab=mock_vocab, pretrained=True, input_shape=input_shape) + assert isinstance(reco_model, keras.Model) input_tensor = tf.random.uniform(shape=[batch_size, *input_shape], minval=0, maxval=1) target = ["i", "am", "a", "jedi"] @@ -194,7 +195,7 @@ def test_recognition_zoo_error(): def test_models_onnx_export(arch_name, input_shape): # Model batch_size = 2 - tf.keras.backend.clear_session() + keras.backend.clear_session() model = recognition.__dict__[arch_name](pretrained=True, exportable=True, input_shape=input_shape) # SAR, MASTER, ViTSTR export currently only available with constant batch size if arch_name in ["sar_resnet31", "master", "vitstr_small", "parseq"]: diff --git a/tests/tensorflow/test_models_utils_tf.py b/tests/tensorflow/test_models_utils_tf.py index b83e60c0ee..b57b41b14b 100644 --- a/tests/tensorflow/test_models_utils_tf.py +++ b/tests/tensorflow/test_models_utils_tf.py @@ -2,9 +2,9 @@ import pytest import tensorflow as tf -from tensorflow.keras import Sequential, layers -from tensorflow.keras.applications import ResNet50 +from keras.applications import ResNet50 +from doctr.models.classification import mobilenet_v3_small from doctr.models.utils import ( IntermediateLayerGetter, _bf16_to_float32, @@ -27,20 +27,18 @@ def test_bf16_to_float32(): def test_load_pretrained_params(tmpdir_factory): - model = Sequential([layers.Dense(8, activation="relu", input_shape=(4,)), layers.Dense(4)]) + model = mobilenet_v3_small(pretrained=False) # Retrieve this URL - url = "https://doctr-static.mindee.com/models?id=v0.1-models/tmp_checkpoint-4a98e492.zip&src=0" + url = "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small-3fcebad7.weights.h5&src=0" # Temp cache dir cache_dir = tmpdir_factory.mktemp("cache") # Pass an incorrect hash with pytest.raises(ValueError): - load_pretrained_params(model, url, "mywronghash", cache_dir=str(cache_dir), internal_name="") + load_pretrained_params(model, url, "mywronghash", cache_dir=str(cache_dir)) # Let tit resolve the hash from the file name - load_pretrained_params(model, url, cache_dir=str(cache_dir), internal_name="") - # Check that the file was downloaded & the archive extracted - assert os.path.exists(cache_dir.join("models").join("tmp_checkpoint-4a98e492")) - # Check that archive was deleted - assert os.path.exists(cache_dir.join("models").join("tmp_checkpoint-4a98e492.zip")) + load_pretrained_params(model, url, cache_dir=str(cache_dir)) + # Check that the file was downloaded + assert os.path.exists(cache_dir.join("models").join("mobilenet_v3_small-3fcebad7.weights.h5")) def test_conv_sequence():