Skip to content

Commit

Permalink
[TF] First changes on the road to Keras v3 (#1724)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored Oct 1, 2024
1 parent df762ed commit dccc26b
Show file tree
Hide file tree
Showing 37 changed files with 287 additions and 207 deletions.
20 changes: 10 additions & 10 deletions docs/source/using_doctr/custom_models_training.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,19 @@ This section shows how you can easily load a custom trained model in docTR.
# Load custom detection model
det_model = db_resnet50(pretrained=False, pretrained_backbone=False)
det_model.load_weights("<path_to_checkpoint>/weights")
det_model.load_weights("<path_to_checkpoint>")
predictor = ocr_predictor(det_arch=det_model, reco_arch="vitstr_small", pretrained=True)
# Load custom recognition model
reco_model = crnn_vgg16_bn(pretrained=False, pretrained_backbone=False)
reco_model.load_weights("<path_to_checkpoint>/weights")
reco_model.load_weights("<path_to_checkpoint>")
predictor = ocr_predictor(det_arch="linknet_resnet18", reco_arch=reco_model, pretrained=True)
# Load custom detection and recognition model
det_model = db_resnet50(pretrained=False, pretrained_backbone=False)
det_model.load_weights("<path_to_checkpoint>/weights")
det_model.load_weights("<path_to_checkpoint>")
reco_model = crnn_vgg16_bn(pretrained=False, pretrained_backbone=False)
reco_model.load_weights("<path_to_checkpoint>/weights")
reco_model.load_weights("<path_to_checkpoint>")
predictor = ocr_predictor(det_arch=det_model, reco_arch=reco_model, pretrained=False)
.. tab:: PyTorch
Expand Down Expand Up @@ -77,7 +77,7 @@ Load a custom recognition model trained on another vocabulary as the default one
from doctr.datasets import VOCABS
reco_model = crnn_vgg16_bn(pretrained=False, pretrained_backbone=False, vocab=VOCABS["german"])
reco_model.load_weights("<path_to_checkpoint>/weights")
reco_model.load_weights("<path_to_checkpoint>")
predictor = ocr_predictor(det_arch='linknet_resnet18', reco_arch=reco_model, pretrained=True)
Expand Down Expand Up @@ -106,7 +106,7 @@ Load a custom trained KIE detection model:
from doctr.models import kie_predictor, db_resnet50
det_model = db_resnet50(pretrained=False, pretrained_backbone=False, class_names=['total', 'date'])
det_model.load_weights("<path_to_checkpoint>/weights")
det_model.load_weights("<path_to_checkpoint>")
kie_predictor(det_arch=det_model, reco_arch='crnn_vgg16_bn', pretrained=True)
.. tab:: PyTorch
Expand Down Expand Up @@ -136,9 +136,9 @@ Load a model with customized Preprocessor:
from doctr.models import db_resnet50, crnn_vgg16_bn
det_model = db_resnet50(pretrained=False, pretrained_backbone=False)
det_model.load_weights("<path_to_checkpoint>/weights")
det_model.load_weights("<path_to_checkpoint>")
reco_model = crnn_vgg16_bn(pretrained=False, pretrained_backbone=False)
reco_model.load_weights("<path_to_checkpoint>/weights")
reco_model.load_weights("<path_to_checkpoint>")
det_predictor = DetectionPredictor(
PreProcessor(
Expand Down Expand Up @@ -233,9 +233,9 @@ Loading your custom trained orientation classification model
from doctr.models.classification.zoo import crop_orientation_predictor, page_orientation_predictor
custom_page_orientation_model = mobilenet_v3_small_page_orientation(pretrained=False)
custom_page_orientation_model.load_weights("<path_to_checkpoint>/weights")
custom_page_orientation_model.load_weights("<path_to_checkpoint>")
custom_crop_orientation_model = mobilenet_v3_small_crop_orientation(pretrained=False)
custom_crop_orientation_model.load_weights("<path_to_checkpoint>/weights")
custom_crop_orientation_model.load_weights("<path_to_checkpoint>")
predictor = ocr_predictor(
pretrained=True,
Expand Down
2 changes: 1 addition & 1 deletion docs/source/using_doctr/using_model_export.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ Advantages:
.. code:: python3
import tensorflow as tf
from tensorflow.keras import mixed_precision
from keras import mixed_precision
mixed_precision.set_global_policy('mixed_float16')
predictor = ocr_predictor(reco_arch="crnn_mobilenet_v3_small", det_arch="linknet_resnet34", pretrained=True)
Expand Down
2 changes: 1 addition & 1 deletion doctr/io/image/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

import numpy as np
import tensorflow as tf
from keras.utils import img_to_array
from PIL import Image
from tensorflow.keras.utils import img_to_array

from doctr.utils.common_types import AbstractPath

Expand Down
15 changes: 10 additions & 5 deletions doctr/models/classification/magc_resnet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from typing import Any, Dict, List, Optional, Tuple

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from keras import activations, layers
from keras.models import Sequential

from doctr.datasets import VOCABS

Expand All @@ -26,7 +26,7 @@
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.6.0/magc_resnet31-addbb705.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/magc_resnet31-16aa7d71.weights.h5&src=0",
},
}

Expand Down Expand Up @@ -57,6 +57,7 @@ def __init__(
self.headers = headers # h
self.inplanes = inplanes # C
self.attn_scale = attn_scale
self.ratio = ratio
self.planes = int(inplanes * ratio)

self.single_header_inplanes = int(inplanes / headers) # C / h
Expand Down Expand Up @@ -97,7 +98,7 @@ def context_modeling(self, inputs: tf.Tensor) -> tf.Tensor:
if self.attn_scale and self.headers > 1:
context_mask = context_mask / math.sqrt(self.single_header_inplanes)
# B*h, 1, H*W, 1
context_mask = tf.keras.activations.softmax(context_mask, axis=2)
context_mask = activations.softmax(context_mask, axis=2)

# Compute context
# B*h, 1, C/h, 1
Expand Down Expand Up @@ -153,7 +154,11 @@ def _magc_resnet(
)
# Load pretrained parameters
if pretrained:
load_pretrained_params(model, default_cfgs[arch]["url"])
# The number of classes is not the same as the number of classes in the pretrained model =>
# skip the mismatching layers for fine tuning
load_pretrained_params(
model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
)

return model

Expand Down
22 changes: 13 additions & 9 deletions doctr/models/classification/mobilenet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from keras import layers
from keras.models import Sequential

from ....datasets import VOCABS
from ...utils import conv_sequence, load_pretrained_params
Expand All @@ -32,42 +32,42 @@
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_large-47d25d7e.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_large-d857506e.weights.h5&src=0",
},
"mobilenet_v3_large_r": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_large_r-a108e192.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_large_r-eef2e3c6.weights.h5&src=0",
},
"mobilenet_v3_small": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small-8a32c32c.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small-3fcebad7.weights.h5&src=0",
},
"mobilenet_v3_small_r": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small_r-3d61452e.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_r-dd50218d.weights.h5&src=0",
},
"mobilenet_v3_small_crop_orientation": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (128, 128, 3),
"classes": [0, -90, 180, 90],
"url": "https://doctr-static.mindee.com/models?id=v0.4.1/classif_mobilenet_v3_small-1ea8db03.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_crop_orientation-ef019b6b.weights.h5&src=0",
},
"mobilenet_v3_small_page_orientation": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (512, 512, 3),
"classes": [0, -90, 180, 90],
"url": "https://doctr-static.mindee.com/models?id=v0.8.1/mobilenet_v3_small_page_orientation-aec9553e.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_page_orientation-0071d55d.weights.h5&src=0",
},
}

Expand Down Expand Up @@ -297,7 +297,11 @@ def _mobilenet_v3(arch: str, pretrained: bool, rect_strides: bool = False, **kwa
)
# Load pretrained parameters
if pretrained:
load_pretrained_params(model, default_cfgs[arch]["url"])
# The number of classes is not the same as the number of classes in the pretrained model =>
# skip the mismatching layers for fine tuning
load_pretrained_params(
model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
)

return model

Expand Down
6 changes: 3 additions & 3 deletions doctr/models/classification/predictor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras import Model

from doctr.models.preprocessor import PreProcessor
from doctr.utils.repr import NestedObject
Expand All @@ -30,10 +30,10 @@ class OrientationPredictor(NestedObject):
def __init__(
self,
pre_processor: Optional[PreProcessor],
model: Optional[keras.Model],
model: Optional[Model],
) -> None:
self.pre_processor = pre_processor if isinstance(pre_processor, PreProcessor) else None
self.model = model if isinstance(model, keras.Model) else None
self.model = model if isinstance(model, Model) else None

def __call__(
self,
Expand Down
30 changes: 20 additions & 10 deletions doctr/models/classification/resnet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from typing import Any, Callable, Dict, List, Optional, Tuple

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.models import Sequential
from keras import layers
from keras.applications import ResNet50
from keras.models import Sequential

from doctr.datasets import VOCABS

Expand All @@ -24,35 +24,35 @@
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.4.1/resnet18-d4634669.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet18-f42d3854.weights.h5&src=0",
},
"resnet31": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet31-5a47a60b.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet31-ab75f78c.weights.h5&src=0",
},
"resnet34": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet34-5dcc97ca.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34-03967df9.weights.h5&src=0",
},
"resnet50": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet50-e75e4cdf.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet50-82358f34.weights.h5&src=0",
},
"resnet34_wide": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.5.0/resnet34_wide-c1271816.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34_wide-b18fdf79.weights.h5&src=0",
},
}

Expand Down Expand Up @@ -212,7 +212,11 @@ def _resnet(
)
# Load pretrained parameters
if pretrained:
load_pretrained_params(model, default_cfgs[arch]["url"])
# The number of classes is not the same as the number of classes in the pretrained model =>
# skip the mismatching layers for fine tuning
load_pretrained_params(
model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
)

return model

Expand Down Expand Up @@ -357,7 +361,13 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet:

# Load pretrained parameters
if pretrained:
load_pretrained_params(model, default_cfgs["resnet50"]["url"])
# The number of classes is not the same as the number of classes in the pretrained model =>
# skip the mismatching layers for fine tuning
load_pretrained_params(
model,
default_cfgs["resnet50"]["url"],
skip_mismatch=kwargs["num_classes"] != len(default_cfgs["resnet50"]["classes"]),
)

return model

Expand Down
14 changes: 9 additions & 5 deletions doctr/models/classification/textnet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple

from tensorflow.keras import Sequential, layers
from keras import Sequential, layers

from doctr.datasets import VOCABS

Expand All @@ -22,21 +22,21 @@
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_tiny-fe9cc245.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_tiny-a29eeb4a.weights.h5&src=0",
},
"textnet_small": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_small-29c39c82.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_small-1c2df0e3.weights.h5&src=0",
},
"textnet_base": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.8.1/textnet_base-168aa82c.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_base-8b4b89bc.weights.h5&src=0",
},
}

Expand Down Expand Up @@ -113,7 +113,11 @@ def _textnet(
model = TextNet(cfg=_cfg, **kwargs)
# Load pretrained parameters
if pretrained:
load_pretrained_params(model, default_cfgs[arch]["url"])
# The number of classes is not the same as the number of classes in the pretrained model =>
# skip the mismatching layers for fine tuning
load_pretrained_params(
model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
)

return model

Expand Down
12 changes: 8 additions & 4 deletions doctr/models/classification/vgg/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple

from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from keras import layers
from keras.models import Sequential

from doctr.datasets import VOCABS

Expand All @@ -22,7 +22,7 @@
"std": (1.0, 1.0, 1.0),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.4.1/vgg16_bn_r-c5836cea.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/vgg16_bn_r-b4d69212.weights.h5&src=0",
},
}

Expand Down Expand Up @@ -83,7 +83,11 @@ def _vgg(
model = VGG(num_blocks, planes, rect_pools, cfg=_cfg, **kwargs)
# Load pretrained parameters
if pretrained:
load_pretrained_params(model, default_cfgs[arch]["url"])
# The number of classes is not the same as the number of classes in the pretrained model =>
# skip the mismatching layers for fine tuning
load_pretrained_params(
model, default_cfgs[arch]["url"], skip_mismatch=kwargs["num_classes"] != len(default_cfgs[arch]["classes"])
)

return model

Expand Down
Loading

0 comments on commit dccc26b

Please sign in to comment.