Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT] Keras v3 [DRAFT] #1749

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/using_doctr/using_model_export.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ Advantages:
.. code:: python3

import tensorflow as tf
from tensorflow.keras import mixed_precision
from keras import mixed_precision
mixed_precision.set_global_policy('mixed_float16')
predictor = ocr_predictor(reco_arch="crnn_mobilenet_v3_small", det_arch="linknet_resnet34", pretrained=True)

Expand Down
15 changes: 0 additions & 15 deletions doctr/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,6 @@
logging.info("Disabling PyTorch because USE_TF is set")
_torch_available = False

# Compatibility fix to make sure tensorflow.keras stays at Keras 2
if "TF_USE_LEGACY_KERAS" not in os.environ:
os.environ["TF_USE_LEGACY_KERAS"] = "1"

elif os.environ["TF_USE_LEGACY_KERAS"] != "1":
raise ValueError(
"docTR is only compatible with Keras 2, but you have explicitly set `TF_USE_LEGACY_KERAS` to `0`. "
)


def ensure_keras_v2() -> None: # pragma: no cover
if not os.environ.get("TF_USE_LEGACY_KERAS") == "1":
os.environ["TF_USE_LEGACY_KERAS"] = "1"


if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
_tf_available = importlib.util.find_spec("tensorflow") is not None
Expand Down Expand Up @@ -79,7 +65,6 @@ def ensure_keras_v2() -> None: # pragma: no cover
_tf_available = False
else:
logging.info(f"TensorFlow version {_tf_version} available.")
ensure_keras_v2()
import tensorflow as tf

# Enable eager execution - this is required for some models to work properly
Expand Down
2 changes: 1 addition & 1 deletion doctr/io/image/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

import numpy as np
import tensorflow as tf
from keras.utils import img_to_array
from PIL import Image
from tensorflow.keras.utils import img_to_array

from doctr.utils.common_types import AbstractPath

Expand Down
5 changes: 2 additions & 3 deletions doctr/models/classification/magc_resnet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from typing import Any, Dict, List, Optional, Tuple

import tensorflow as tf
from tensorflow.keras import activations, layers
from tensorflow.keras.models import Sequential
from keras import Sequential, activations, layers

from doctr.datasets import VOCABS

Expand All @@ -26,7 +25,7 @@
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/magc_resnet31-16aa7d71.weights.h5&src=0",
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/magc_resnet31-6c266055.weights.h5",
},
}

Expand Down
15 changes: 7 additions & 8 deletions doctr/models/classification/mobilenet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from keras import Sequential, layers

from ....datasets import VOCABS
from ...utils import _build_model, conv_sequence, load_pretrained_params
Expand All @@ -32,42 +31,42 @@
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_large-d857506e.weights.h5&src=0",
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/mobilenet_v3_large-d857506e.weights.h5",
},
"mobilenet_v3_large_r": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_large_r-eef2e3c6.weights.h5&src=0",
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/mobilenet_v3_large_r-eef2e3c6.weights.h5",
},
"mobilenet_v3_small": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small-3fcebad7.weights.h5&src=0",
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/mobilenet_v3_small-3fcebad7.weights.h5",
},
"mobilenet_v3_small_r": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_r-dd50218d.weights.h5&src=0",
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/mobilenet_v3_small_r-dd50218d.weights.h5",
},
"mobilenet_v3_small_crop_orientation": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (128, 128, 3),
"classes": [0, -90, 180, 90],
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_crop_orientation-ef019b6b.weights.h5&src=0",
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/mobilenet_v3_small_crop_orientation-ef019b6b.weights.h5",
},
"mobilenet_v3_small_page_orientation": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (512, 512, 3),
"classes": [0, -90, 180, 90],
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/mobilenet_v3_small_page_orientation-0071d55d.weights.h5&src=0",
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/mobilenet_v3_small_page_orientation-0071d55d.weights.h5",
},
}

Expand Down
2 changes: 1 addition & 1 deletion doctr/models/classification/predictor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model
from keras import Model

from doctr.models.preprocessor import PreProcessor
from doctr.utils.repr import NestedObject
Expand Down
16 changes: 7 additions & 9 deletions doctr/models/classification/resnet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
from typing import Any, Callable, Dict, List, Optional, Tuple

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.models import Sequential
from keras import Sequential, applications, layers

from doctr.datasets import VOCABS

Expand All @@ -24,35 +22,35 @@
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet18-f42d3854.weights.h5&src=0",
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/resnet18-4138682e.weights.h5",
},
"resnet31": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet31-ab75f78c.weights.h5&src=0",
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/resnet31-61808f41.weights.h5",
},
"resnet34": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34-03967df9.weights.h5&src=0",
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/resnet34-2288ee52.weights.h5",
},
"resnet50": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet50-82358f34.weights.h5&src=0",
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/resnet50-82358f34.weights.h5",
},
"resnet34_wide": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/resnet34_wide-b18fdf79.weights.h5&src=0",
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/resnet34_wide-4c788e90.weights.h5",
},
}

Expand Down Expand Up @@ -350,7 +348,7 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet:
_cfg["input_shape"] = kwargs["input_shape"]
kwargs.pop("classes")

model = ResNet50(
model = applications.ResNet50(
weights=None,
include_top=True,
pooling=True,
Expand Down
8 changes: 4 additions & 4 deletions doctr/models/classification/textnet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple

from tensorflow.keras import Sequential, layers
from keras import Sequential, layers

from doctr.datasets import VOCABS

Expand All @@ -22,21 +22,21 @@
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_tiny-a29eeb4a.weights.h5&src=0",
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/textnet_tiny-99fb9158.weights.h5",
},
"textnet_small": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_small-1c2df0e3.weights.h5&src=0",
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/textnet_small-44072f65.weights.h5",
},
"textnet_base": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/textnet_base-8b4b89bc.weights.h5&src=0",
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/textnet_base-a92df1c0.weights.h5",
},
}

Expand Down
5 changes: 2 additions & 3 deletions doctr/models/classification/vgg/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple

from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from keras import Sequential, layers

from doctr.datasets import VOCABS

Expand All @@ -22,7 +21,7 @@
"std": (1.0, 1.0, 1.0),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/vgg16_bn_r-b4d69212.weights.h5&src=0",
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/vgg16_bn_r-b4d69212.weights.h5",
},
}

Expand Down
3 changes: 1 addition & 2 deletions doctr/models/classification/vit/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from torch import nn

from doctr.datasets import VOCABS
from doctr.models.modules.transformer import EncoderBlock
from doctr.models.modules.vision_transformer.pytorch import PatchEmbedding
from doctr.models.modules import EncoderBlock, PatchEmbedding

from ...utils.pytorch import load_pretrained_params

Expand Down
6 changes: 3 additions & 3 deletions doctr/models/classification/vit/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, Dict, Optional, Tuple

import tensorflow as tf
from tensorflow.keras import Sequential, layers
from keras import Sequential, layers

from doctr.datasets import VOCABS
from doctr.models.modules.transformer import EncoderBlock
Expand All @@ -25,14 +25,14 @@
"std": (0.299, 0.296, 0.301),
"input_shape": (3, 32, 32),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/vit_s-69bc459e.weights.h5&src=0",
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/vit_s-d68b3d5b.weights.h5",
},
"vit_b": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/vit_b-c64705bd.weights.h5&src=0",
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/vit_b-f01181f0.weights.h5",
},
}

Expand Down
10 changes: 5 additions & 5 deletions doctr/models/detection/differentiable_binarization/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, Sequential, layers, losses
from tensorflow.keras.applications import ResNet50
from keras import Model, Sequential, applications, layers, losses

from doctr.file_utils import CLASS_NAME
from doctr.models.utils import (
Expand All @@ -34,13 +33,13 @@
"mean": (0.798, 0.785, 0.772),
"std": (0.264, 0.2749, 0.287),
"input_shape": (1024, 1024, 3),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/db_resnet50-649fa22b.weights.h5&src=0",
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/db_resnet50-fe92475b.weights.h5",
},
"db_mobilenet_v3_large": {
"mean": (0.798, 0.785, 0.772),
"std": (0.264, 0.2749, 0.287),
"input_shape": (1024, 1024, 3),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/db_mobilenet_v3_large-ee2e1dbe.weights.h5&src=0",
"url": None,
},
}

Expand Down Expand Up @@ -356,6 +355,7 @@ def _db_mobilenet(
# Build the model
model = DBNet(feat_extractor, cfg=_cfg, **kwargs)
_build_model(model)

# Load pretrained parameters
if pretrained:
# The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
Expand Down Expand Up @@ -390,7 +390,7 @@ def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet:
return _db_resnet(
"db_resnet50",
pretrained,
ResNet50,
applications.ResNet50,
["conv2_block3_out", "conv3_block4_out", "conv4_block6_out", "conv5_block3_out"],
**kwargs,
)
Expand Down
8 changes: 4 additions & 4 deletions doctr/models/detection/fast/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, Sequential, layers
from keras import Model, Sequential, layers

from doctr.file_utils import CLASS_NAME
from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, _build_model, load_pretrained_params
Expand All @@ -28,19 +28,19 @@
"input_shape": (1024, 1024, 3),
"mean": (0.798, 0.785, 0.772),
"std": (0.264, 0.2749, 0.287),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_tiny-d7379d7b.weights.h5&src=0",
"url": None,
},
"fast_small": {
"input_shape": (1024, 1024, 3),
"mean": (0.798, 0.785, 0.772),
"std": (0.264, 0.2749, 0.287),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_small-44b27eb6.weights.h5&src=0",
"url": None,
},
"fast_base": {
"input_shape": (1024, 1024, 3),
"mean": (0.798, 0.785, 0.772),
"std": (0.264, 0.2749, 0.287),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_base-f2c6c736.weights.h5&src=0",
"url": None,
},
}

Expand Down
8 changes: 4 additions & 4 deletions doctr/models/detection/linknet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, Sequential, layers, losses
from keras import Model, Sequential, layers, losses

from doctr.file_utils import CLASS_NAME
from doctr.models.classification import resnet18, resnet34, resnet50
Expand All @@ -32,19 +32,19 @@
"mean": (0.798, 0.785, 0.772),
"std": (0.264, 0.2749, 0.287),
"input_shape": (1024, 1024, 3),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet18-615a82c5.weights.h5&src=0",
"url": None,
},
"linknet_resnet34": {
"mean": (0.798, 0.785, 0.772),
"std": (0.264, 0.2749, 0.287),
"input_shape": (1024, 1024, 3),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet34-9d772be5.weights.h5&src=0",
"url": None,
},
"linknet_resnet50": {
"mean": (0.798, 0.785, 0.772),
"std": (0.264, 0.2749, 0.287),
"input_shape": (1024, 1024, 3),
"url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet50-6bf6c8b5.weights.h5&src=0",
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/linknet_resnet50-fdea2b5f.weights.h5",
},
}

Expand Down
Loading
Loading