Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TF] First changes on the road to Keras v3 #1724

Merged
merged 7 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions docs/source/using_doctr/custom_models_training.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,19 @@ This section shows how you can easily load a custom trained model in docTR.

# Load custom detection model
det_model = db_resnet50(pretrained=False, pretrained_backbone=False)
det_model.load_weights("<path_to_checkpoint>/weights")
det_model.load_weights("<path_to_checkpoint>")
predictor = ocr_predictor(det_arch=det_model, reco_arch="vitstr_small", pretrained=True)

# Load custom recognition model
reco_model = crnn_vgg16_bn(pretrained=False, pretrained_backbone=False)
reco_model.load_weights("<path_to_checkpoint>/weights")
reco_model.load_weights("<path_to_checkpoint>")
predictor = ocr_predictor(det_arch="linknet_resnet18", reco_arch=reco_model, pretrained=True)

# Load custom detection and recognition model
det_model = db_resnet50(pretrained=False, pretrained_backbone=False)
det_model.load_weights("<path_to_checkpoint>/weights")
det_model.load_weights("<path_to_checkpoint>")
reco_model = crnn_vgg16_bn(pretrained=False, pretrained_backbone=False)
reco_model.load_weights("<path_to_checkpoint>/weights")
reco_model.load_weights("<path_to_checkpoint>")
predictor = ocr_predictor(det_arch=det_model, reco_arch=reco_model, pretrained=False)

.. tab:: PyTorch
Expand Down Expand Up @@ -77,7 +77,7 @@ Load a custom recognition model trained on another vocabulary as the default one
from doctr.datasets import VOCABS

reco_model = crnn_vgg16_bn(pretrained=False, pretrained_backbone=False, vocab=VOCABS["german"])
reco_model.load_weights("<path_to_checkpoint>/weights")
reco_model.load_weights("<path_to_checkpoint>")

predictor = ocr_predictor(det_arch='linknet_resnet18', reco_arch=reco_model, pretrained=True)

Expand Down Expand Up @@ -106,7 +106,7 @@ Load a custom trained KIE detection model:
from doctr.models import kie_predictor, db_resnet50

det_model = db_resnet50(pretrained=False, pretrained_backbone=False, class_names=['total', 'date'])
det_model.load_weights("<path_to_checkpoint>/weights")
det_model.load_weights("<path_to_checkpoint>")
kie_predictor(det_arch=det_model, reco_arch='crnn_vgg16_bn', pretrained=True)

.. tab:: PyTorch
Expand Down Expand Up @@ -136,9 +136,9 @@ Load a model with customized Preprocessor:
from doctr.models import db_resnet50, crnn_vgg16_bn

det_model = db_resnet50(pretrained=False, pretrained_backbone=False)
det_model.load_weights("<path_to_checkpoint>/weights")
det_model.load_weights("<path_to_checkpoint>")
reco_model = crnn_vgg16_bn(pretrained=False, pretrained_backbone=False)
reco_model.load_weights("<path_to_checkpoint>/weights")
reco_model.load_weights("<path_to_checkpoint>")

det_predictor = DetectionPredictor(
PreProcessor(
Expand Down Expand Up @@ -233,9 +233,9 @@ Loading your custom trained orientation classification model
from doctr.models.classification.zoo import crop_orientation_predictor, page_orientation_predictor

custom_page_orientation_model = mobilenet_v3_small_page_orientation(pretrained=False)
custom_page_orientation_model.load_weights("<path_to_checkpoint>/weights")
custom_page_orientation_model.load_weights("<path_to_checkpoint>")
custom_crop_orientation_model = mobilenet_v3_small_crop_orientation(pretrained=False)
custom_crop_orientation_model.load_weights("<path_to_checkpoint>/weights")
custom_crop_orientation_model.load_weights("<path_to_checkpoint>")

predictor = ocr_predictor(
pretrained=True,
Expand Down
2 changes: 1 addition & 1 deletion docs/source/using_doctr/using_model_export.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ Advantages:
.. code:: python3

import tensorflow as tf
from tensorflow.keras import mixed_precision
from keras import mixed_precision
mixed_precision.set_global_policy('mixed_float16')
predictor = ocr_predictor(reco_arch="crnn_mobilenet_v3_small", det_arch="linknet_resnet34", pretrained=True)

Expand Down
2 changes: 1 addition & 1 deletion doctr/io/image/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

import numpy as np
import tensorflow as tf
from keras.utils import img_to_array
from PIL import Image
from tensorflow.keras.utils import img_to_array

from doctr.utils.common_types import AbstractPath

Expand Down
15 changes: 10 additions & 5 deletions doctr/models/classification/magc_resnet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from typing import Any, Dict, List, Optional, Tuple

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from keras import activations, layers
from keras.models import Sequential

from doctr.datasets import VOCABS

Expand All @@ -26,7 +26,7 @@
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.6.0/magc_resnet31-addbb705.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/magc_resnet31-16aa7d71.weights.h5&src=0",
},
}

Expand Down Expand Up @@ -57,6 +57,7 @@ def __init__(
self.headers = headers # h
self.inplanes = inplanes # C
self.attn_scale = attn_scale
self.ratio = ratio
self.planes = int(inplanes * ratio)

self.single_header_inplanes = int(inplanes / headers) # C / h
Expand Down Expand Up @@ -97,7 +98,7 @@ def context_modeling(self, inputs: tf.Tensor) -> tf.Tensor:
if self.attn_scale and self.headers > 1:
context_mask = context_mask / math.sqrt(self.single_header_inplanes)
# B*h, 1, H*W, 1
context_mask = tf.keras.activations.softmax(context_mask, axis=2)
context_mask = activations.softmax(context_mask, axis=2)

# Compute context
# B*h, 1, C/h, 1
Expand Down Expand Up @@ -153,7 +154,11 @@ def _magc_resnet(
)
# Load pretrained parameters
if pretrained:
load_pretrained_params(model, default_cfgs[arch]["url"])
# The number of classes is not the same as the number of classes in the pretrained model =>
# skip the mismatching layers for fine tuning
load_pretrained_params(
model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
)

return model

Expand Down
22 changes: 13 additions & 9 deletions doctr/models/classification/mobilenet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from keras import layers
from keras.models import Sequential

from ....datasets import VOCABS
from ...utils import conv_sequence, load_pretrained_params
Expand All @@ -32,42 +32,42 @@
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_large-47d25d7e.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_large-d857506e.weights.h5&src=0",
},
"mobilenet_v3_large_r": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_large_r-a108e192.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_large_r-eef2e3c6.weights.h5&src=0",
},
"mobilenet_v3_small": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small-8a32c32c.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small-3fcebad7.weights.h5&src=0",
},
"mobilenet_v3_small_r": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small_r-3d61452e.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_r-dd50218d.weights.h5&src=0",
},
"mobilenet_v3_small_crop_orientation": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (128, 128, 3),
"classes": [0, -90, 180, 90],
"url": "https://doctr-static.mindee.com/models?id=v0.4.1/classif_mobilenet_v3_small-1ea8db03.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_crop_orientation-ef019b6b.weights.h5&src=0",
},
"mobilenet_v3_small_page_orientation": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (512, 512, 3),
"classes": [0, -90, 180, 90],
"url": "https://doctr-static.mindee.com/models?id=v0.8.1/mobilenet_v3_small_page_orientation-aec9553e.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_page_orientation-0071d55d.weights.h5&src=0",
},
}

Expand Down Expand Up @@ -297,7 +297,11 @@ def _mobilenet_v3(arch: str, pretrained: bool, rect_strides: bool = False, **kwa
)
# Load pretrained parameters
if pretrained:
load_pretrained_params(model, default_cfgs[arch]["url"])
# The number of classes is not the same as the number of classes in the pretrained model =>
# skip the mismatching layers for fine tuning
load_pretrained_params(
model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
)

return model

Expand Down
6 changes: 3 additions & 3 deletions doctr/models/classification/predictor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras import Model

from doctr.models.preprocessor import PreProcessor
from doctr.utils.repr import NestedObject
Expand All @@ -30,10 +30,10 @@ class OrientationPredictor(NestedObject):
def __init__(
self,
pre_processor: Optional[PreProcessor],
model: Optional[keras.Model],
model: Optional[Model],
) -> None:
self.pre_processor = pre_processor if isinstance(pre_processor, PreProcessor) else None
self.model = model if isinstance(model, keras.Model) else None
self.model = model if isinstance(model, Model) else None

def __call__(
self,
Expand Down
30 changes: 20 additions & 10 deletions doctr/models/classification/resnet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from typing import Any, Callable, Dict, List, Optional, Tuple

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.models import Sequential
from keras import layers
from keras.applications import ResNet50
from keras.models import Sequential

from doctr.datasets import VOCABS

Expand All @@ -24,35 +24,35 @@
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.4.1/resnet18-d4634669.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet18-f42d3854.weights.h5&src=0",
},
"resnet31": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet31-5a47a60b.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet31-ab75f78c.weights.h5&src=0",
},
"resnet34": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet34-5dcc97ca.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34-03967df9.weights.h5&src=0",
},
"resnet50": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet50-e75e4cdf.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet50-82358f34.weights.h5&src=0",
},
"resnet34_wide": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet34_wide-c1271816.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34_wide-b18fdf79.weights.h5&src=0",
},
}

Expand Down Expand Up @@ -212,7 +212,11 @@ def _resnet(
)
# Load pretrained parameters
if pretrained:
load_pretrained_params(model, default_cfgs[arch]["url"])
# The number of classes is not the same as the number of classes in the pretrained model =>
# skip the mismatching layers for fine tuning
load_pretrained_params(
model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
)

return model

Expand Down Expand Up @@ -357,7 +361,13 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet:

# Load pretrained parameters
if pretrained:
load_pretrained_params(model, default_cfgs["resnet50"]["url"])
# The number of classes is not the same as the number of classes in the pretrained model =>
# skip the mismatching layers for fine tuning
load_pretrained_params(
model,
default_cfgs["resnet50"]["url"],
skip_mismatch=kwargs["num_classes"] != len(default_cfgs["resnet50"]["classes"]),
)

return model

Expand Down
14 changes: 9 additions & 5 deletions doctr/models/classification/textnet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple

from tensorflow.keras import Sequential, layers
from keras import Sequential, layers

from doctr.datasets import VOCABS

Expand All @@ -22,21 +22,21 @@
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_tiny-fe9cc245.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_tiny-a29eeb4a.weights.h5&src=0",
},
"textnet_small": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_small-29c39c82.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_small-1c2df0e3.weights.h5&src=0",
},
"textnet_base": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_base-168aa82c.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_base-8b4b89bc.weights.h5&src=0",
},
}

Expand Down Expand Up @@ -113,7 +113,11 @@ def _textnet(
model = TextNet(cfg=_cfg, **kwargs)
# Load pretrained parameters
if pretrained:
load_pretrained_params(model, default_cfgs[arch]["url"])
# The number of classes is not the same as the number of classes in the pretrained model =>
# skip the mismatching layers for fine tuning
load_pretrained_params(
model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
)

return model

Expand Down
12 changes: 8 additions & 4 deletions doctr/models/classification/vgg/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple

from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from keras import layers
from keras.models import Sequential

from doctr.datasets import VOCABS

Expand All @@ -22,7 +22,7 @@
"std": (1.0, 1.0, 1.0),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.4.1/vgg16_bn_r-c5836cea.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/vgg16_bn_r-b4d69212.weights.h5&src=0",
},
}

Expand Down Expand Up @@ -83,7 +83,11 @@ def _vgg(
model = VGG(num_blocks, planes, rect_pools, cfg=_cfg, **kwargs)
# Load pretrained parameters
if pretrained:
load_pretrained_params(model, default_cfgs[arch]["url"])
# The number of classes is not the same as the number of classes in the pretrained model =>
# skip the mismatching layers for fine tuning
load_pretrained_params(
model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
)

return model

Expand Down
Loading
Loading