Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 committed Oct 24, 2024
1 parent 93af653 commit 9dca02f
Show file tree
Hide file tree
Showing 32 changed files with 89 additions and 120 deletions.
2 changes: 1 addition & 1 deletion docs/source/using_doctr/using_model_export.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ Advantages:
.. code:: python3
import tensorflow as tf
from tensorflow.keras import mixed_precision
from keras import mixed_precision
mixed_precision.set_global_policy('mixed_float16')
predictor = ocr_predictor(reco_arch="crnn_mobilenet_v3_small", det_arch="linknet_resnet34", pretrained=True)
Expand Down
2 changes: 1 addition & 1 deletion doctr/io/image/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

import numpy as np
import tensorflow as tf
from keras.utils import img_to_array
from PIL import Image
from tensorflow.keras.utils import img_to_array

from doctr.utils.common_types import AbstractPath

Expand Down
8 changes: 3 additions & 5 deletions doctr/models/classification/magc_resnet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@
from typing import Any, Dict, List, Optional, Tuple

import tensorflow as tf
from tensorflow.keras import activations, layers
from tensorflow.keras.models import Sequential
from keras import Sequential, activations, layers

from doctr.datasets import VOCABS

from ...utils import _build_model, load_pretrained_params
from ...utils import load_pretrained_params
from ..resnet.tensorflow import ResNet

__all__ = ["magc_resnet31"]
Expand All @@ -26,7 +25,7 @@
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": None,
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/magc_resnet31-6c266055.weights.h5",
},
}

Expand Down Expand Up @@ -152,7 +151,6 @@ def _magc_resnet(
cfg=_cfg,
**kwargs,
)
_build_model(model)

# Load pretrained parameters
if pretrained:
Expand Down
18 changes: 8 additions & 10 deletions doctr/models/classification/mobilenet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from keras import Sequential, layers

from ....datasets import VOCABS
from ...utils import _build_model, conv_sequence, load_pretrained_params
from ...utils import conv_sequence, load_pretrained_params

__all__ = [
"MobileNetV3",
Expand All @@ -32,42 +31,42 @@
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": None,
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/mobilenet_v3_large-d857506e.weights.h5",
},
"mobilenet_v3_large_r": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": None,
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/mobilenet_v3_large_r-eef2e3c6.weights.h5",
},
"mobilenet_v3_small": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": None,
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/mobilenet_v3_small-3fcebad7.weights.h5",
},
"mobilenet_v3_small_r": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": None,
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/mobilenet_v3_small_r-dd50218d.weights.h5",
},
"mobilenet_v3_small_crop_orientation": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (128, 128, 3),
"classes": [0, -90, 180, 90],
"url": None,
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/mobilenet_v3_small_crop_orientation-ef019b6b.weights.h5",
},
"mobilenet_v3_small_page_orientation": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (512, 512, 3),
"classes": [0, -90, 180, 90],
"url": None,
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/mobilenet_v3_small_page_orientation-0071d55d.weights.h5",
},
}

Expand Down Expand Up @@ -295,7 +294,6 @@ def _mobilenet_v3(arch: str, pretrained: bool, rect_strides: bool = False, **kwa
cfg=_cfg,
**kwargs,
)
_build_model(model)

# Load pretrained parameters
if pretrained:
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/classification/predictor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model
from keras import Model

from doctr.models.preprocessor import PreProcessor
from doctr.utils.repr import NestedObject
Expand Down
20 changes: 8 additions & 12 deletions doctr/models/classification/resnet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,11 @@
from typing import Any, Callable, Dict, List, Optional, Tuple

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.models import Sequential
from keras import Sequential, applications, layers

from doctr.datasets import VOCABS

from ...utils import _build_model, conv_sequence, load_pretrained_params
from ...utils import conv_sequence, load_pretrained_params

__all__ = ["ResNet", "resnet18", "resnet31", "resnet34", "resnet50", "resnet34_wide"]

Expand All @@ -24,35 +22,35 @@
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": None,
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/resnet18-4138682e.weights.h5",
},
"resnet31": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": None,
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/resnet31-61808f41.weights.h5",
},
"resnet34": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": None,
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/resnet34-2288ee52.weights.h5",
},
"resnet50": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": None,
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/resnet50-82358f34.weights.h5",
},
"resnet34_wide": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": None,
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/resnet34_wide-4c788e90.weights.h5",
},
}

Expand Down Expand Up @@ -210,7 +208,6 @@ def _resnet(
model = ResNet(
num_blocks, output_channels, stage_downsample, stage_conv, stage_pooling, origin_stem, cfg=_cfg, **kwargs
)
_build_model(model)

# Load pretrained parameters
if pretrained:
Expand Down Expand Up @@ -350,7 +347,7 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet:
_cfg["input_shape"] = kwargs["input_shape"]
kwargs.pop("classes")

model = ResNet50(
model = applications.ResNet50(
weights=None,
include_top=True,
pooling=True,
Expand All @@ -360,7 +357,6 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet:
)

model.cfg = _cfg
_build_model(model)

# Load pretrained parameters
if pretrained:
Expand Down
11 changes: 5 additions & 6 deletions doctr/models/classification/textnet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple

from tensorflow.keras import Sequential, layers
from keras import Sequential, layers

from doctr.datasets import VOCABS

from ...modules.layers.tensorflow import FASTConvLayer
from ...utils import _build_model, conv_sequence, load_pretrained_params
from ...utils import conv_sequence, load_pretrained_params

__all__ = ["textnet_tiny", "textnet_small", "textnet_base"]

Expand All @@ -22,21 +22,21 @@
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": None,
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/textnet_tiny-99fb9158.weights.h5",
},
"textnet_small": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": None,
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/textnet_small-44072f65.weights.h5",
},
"textnet_base": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": None,
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/textnet_base-a92df1c0.weights.h5",
},
}

Expand Down Expand Up @@ -111,7 +111,6 @@ def _textnet(

# Build the model
model = TextNet(cfg=_cfg, **kwargs)
_build_model(model)

# Load pretrained parameters
if pretrained:
Expand Down
8 changes: 3 additions & 5 deletions doctr/models/classification/vgg/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple

from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from keras import Sequential, layers

from doctr.datasets import VOCABS

from ...utils import _build_model, conv_sequence, load_pretrained_params
from ...utils import conv_sequence, load_pretrained_params

__all__ = ["VGG", "vgg16_bn_r"]

Expand All @@ -22,7 +21,7 @@
"std": (1.0, 1.0, 1.0),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": None,
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/vgg16_bn_r-b4d69212.weights.h5",
},
}

Expand Down Expand Up @@ -81,7 +80,6 @@ def _vgg(

# Build the model
model = VGG(num_blocks, planes, rect_pools, cfg=_cfg, **kwargs)
_build_model(model)

# Load pretrained parameters
if pretrained:
Expand Down
9 changes: 4 additions & 5 deletions doctr/models/classification/vit/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
from typing import Any, Dict, Optional, Tuple

import tensorflow as tf
from tensorflow.keras import Sequential, layers
from keras import Sequential, layers

from doctr.datasets import VOCABS
from doctr.models.modules.transformer import EncoderBlock
from doctr.models.modules.vision_transformer.tensorflow import PatchEmbedding
from doctr.utils.repr import NestedObject

from ...utils import _build_model, load_pretrained_params
from ...utils import load_pretrained_params

__all__ = ["vit_s", "vit_b"]

Expand All @@ -25,14 +25,14 @@
"std": (0.299, 0.296, 0.301),
"input_shape": (3, 32, 32),
"classes": list(VOCABS["french"]),
"url": None,
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/vit_s-d68b3d5b.weights.h5",
},
"vit_b": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (32, 32, 3),
"classes": list(VOCABS["french"]),
"url": None,
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/vit_b-f01181f0.weights.h5",
},
}

Expand Down Expand Up @@ -121,7 +121,6 @@ def _vit(

# Build the model
model = VisionTransformer(cfg=_cfg, **kwargs)
_build_model(model)

# Load pretrained parameters
if pretrained:
Expand Down
11 changes: 4 additions & 7 deletions doctr/models/detection/differentiable_binarization/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,12 @@

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, Sequential, layers, losses
from tensorflow.keras.applications import ResNet50
from keras import Model, Sequential, applications, layers, losses

from doctr.file_utils import CLASS_NAME
from doctr.models.utils import (
IntermediateLayerGetter,
_bf16_to_float32,
_build_model,
conv_sequence,
load_pretrained_params,
)
Expand All @@ -34,7 +32,7 @@
"mean": (0.798, 0.785, 0.772),
"std": (0.264, 0.2749, 0.287),
"input_shape": (1024, 1024, 3),
"url": None,
"url": "https://github.com/mindee/doctr/releases/download/v0.10.0/db_resnet50-fe92475b.weights.h5",
},
"db_mobilenet_v3_large": {
"mean": (0.798, 0.785, 0.772),
Expand Down Expand Up @@ -310,7 +308,6 @@ def _db_resnet(

# Build the model
model = DBNet(feat_extractor, cfg=_cfg, **kwargs)
_build_model(model)

# Load pretrained parameters
if pretrained:
Expand Down Expand Up @@ -355,7 +352,7 @@ def _db_mobilenet(

# Build the model
model = DBNet(feat_extractor, cfg=_cfg, **kwargs)
_build_model(model)

# Load pretrained parameters
if pretrained:
# The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
Expand Down Expand Up @@ -390,7 +387,7 @@ def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet:
return _db_resnet(
"db_resnet50",
pretrained,
ResNet50,
applications.ResNet50,
["conv2_block3_out", "conv3_block4_out", "conv4_block6_out", "conv5_block3_out"],
**kwargs,
)
Expand Down
Loading

0 comments on commit 9dca02f

Please sign in to comment.