From 2f9f50e5a7ec5ea4045fd4baff5aa93c9d2ed906 Mon Sep 17 00:00:00 2001 From: Felix Dittrich Date: Wed, 9 Oct 2024 17:15:13 +0200 Subject: [PATCH] [TF] Move model building & unify train scripts (#1744) --- .github/workflows/references.yml | 12 ++++++------ .../classification/magc_resnet/tensorflow.py | 6 ++++-- .../classification/mobilenet/tensorflow.py | 4 +++- doctr/models/classification/resnet/tensorflow.py | 5 ++++- .../models/classification/textnet/tensorflow.py | 4 +++- doctr/models/classification/vgg/tensorflow.py | 4 +++- doctr/models/classification/vit/tensorflow.py | 4 +++- .../differentiable_binarization/tensorflow.py | 11 ++++++++++- doctr/models/detection/fast/tensorflow.py | 7 +++---- doctr/models/detection/linknet/tensorflow.py | 14 +++++++++++--- doctr/models/factory/hub.py | 6 ------ doctr/models/preprocessor/tensorflow.py | 2 +- doctr/models/recognition/crnn/tensorflow.py | 3 ++- doctr/models/recognition/master/tensorflow.py | 4 +++- doctr/models/recognition/parseq/tensorflow.py | 4 +++- doctr/models/recognition/sar/tensorflow.py | 3 ++- doctr/models/recognition/vitstr/tensorflow.py | 4 +++- doctr/models/utils/tensorflow.py | 16 +++++++++++----- references/classification/README.md | 4 ++-- .../classification/train_pytorch_orientation.py | 6 +++--- .../classification/train_tensorflow_character.py | 2 -- .../train_tensorflow_orientation.py | 10 ++++------ references/detection/README.md | 4 ++-- references/detection/evaluate_tensorflow.py | 2 +- references/detection/train_pytorch.py | 4 ++-- references/detection/train_tensorflow.py | 12 ++---------- references/detection/utils.py | 8 -------- references/recognition/README.md | 2 +- references/recognition/evaluate_tensorflow.py | 2 +- references/recognition/train_tensorflow.py | 2 -- 30 files changed, 93 insertions(+), 78 deletions(-) diff --git a/.github/workflows/references.yml b/.github/workflows/references.yml index 56856ba1d3..f79784244a 100644 --- a/.github/workflows/references.yml +++ b/.github/workflows/references.yml @@ -114,16 +114,16 @@ jobs: unzip toy_recogition_set-036a4d80.zip -d reco_set - if: matrix.framework == 'tensorflow' name: Train for a short epoch (TF) (document orientation) - run: python references/classification/train_tensorflow_orientation.py ./det_set ./det_set resnet18 page -b 2 --epochs 1 + run: python references/classification/train_tensorflow_orientation.py resnet18 --type page --train_path ./det_set --val_path ./det_set -b 2 --epochs 1 - if: matrix.framework == 'pytorch' name: Train for a short epoch (PT) (document orientation) - run: python references/classification/train_pytorch_orientation.py ./det_set ./det_set resnet18 page -b 2 --epochs 1 + run: python references/classification/train_pytorch_orientation.py resnet18 --type page --train_path ./det_set --val_path ./det_set -b 2 --epochs 1 - if: matrix.framework == 'tensorflow' name: Train for a short epoch (TF) (crop orientation) - run: python references/classification/train_tensorflow_orientation.py ./reco_set ./reco_set resnet18 crop -b 4 --epochs 1 + run: python references/classification/train_tensorflow_orientation.py resnet18 --type crop --train_path ./reco_set --val_path ./reco_set -b 4 --epochs 1 - if: matrix.framework == 'pytorch' name: Train for a short epoch (PT) (crop orientation) - run: python references/classification/train_pytorch_orientation.py ./reco_set ./reco_set resnet18 crop -b 4 --epochs 1 + run: python references/classification/train_pytorch_orientation.py resnet18 --type crop --train_path ./reco_set --val_path ./reco_set -b 4 --epochs 1 train-text-recognition: runs-on: ${{ matrix.os }} @@ -318,10 +318,10 @@ jobs: unzip toy_detection_set-bbbb4243.zip -d det_set - if: matrix.framework == 'tensorflow' name: Train for a short epoch (TF) - run: python references/detection/train_tensorflow.py --train_path ./det_set --val_path ./det_set linknet_resnet18 -b 2 --epochs 1 + run: python references/detection/train_tensorflow.py linknet_resnet18 --train_path ./det_set --val_path ./det_set -b 2 --epochs 1 - if: matrix.framework == 'pytorch' name: Train for a short epoch (PT) - run: python references/detection/train_pytorch.py ./det_set ./det_set db_mobilenet_v3_large -b 2 --epochs 1 + run: python references/detection/train_pytorch.py db_mobilenet_v3_large --train_path ./det_set --val_path ./det_set -b 2 --epochs 1 evaluate-text-detection: runs-on: ${{ matrix.os }} diff --git a/doctr/models/classification/magc_resnet/tensorflow.py b/doctr/models/classification/magc_resnet/tensorflow.py index fc7678f661..d920ca44a4 100644 --- a/doctr/models/classification/magc_resnet/tensorflow.py +++ b/doctr/models/classification/magc_resnet/tensorflow.py @@ -14,7 +14,7 @@ from doctr.datasets import VOCABS -from ...utils import load_pretrained_params +from ...utils import _build_model, load_pretrained_params from ..resnet.tensorflow import ResNet __all__ = ["magc_resnet31"] @@ -115,7 +115,7 @@ def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor: # Context modeling: B, H, W, C -> B, 1, 1, C context = self.context_modeling(inputs) # Transform: B, 1, 1, C -> B, 1, 1, C - transformed = self.transform(context) + transformed = self.transform(context, **kwargs) return inputs + transformed @@ -152,6 +152,8 @@ def _magc_resnet( cfg=_cfg, **kwargs, ) + _build_model(model) + # Load pretrained parameters if pretrained: # The number of classes is not the same as the number of classes in the pretrained model => diff --git a/doctr/models/classification/mobilenet/tensorflow.py b/doctr/models/classification/mobilenet/tensorflow.py index ff57c221dc..ae3535d947 100644 --- a/doctr/models/classification/mobilenet/tensorflow.py +++ b/doctr/models/classification/mobilenet/tensorflow.py @@ -13,7 +13,7 @@ from tensorflow.keras.models import Sequential from ....datasets import VOCABS -from ...utils import conv_sequence, load_pretrained_params +from ...utils import _build_model, conv_sequence, load_pretrained_params __all__ = [ "MobileNetV3", @@ -295,6 +295,8 @@ def _mobilenet_v3(arch: str, pretrained: bool, rect_strides: bool = False, **kwa cfg=_cfg, **kwargs, ) + _build_model(model) + # Load pretrained parameters if pretrained: # The number of classes is not the same as the number of classes in the pretrained model => diff --git a/doctr/models/classification/resnet/tensorflow.py b/doctr/models/classification/resnet/tensorflow.py index 364b03c3a2..662a43c3a0 100644 --- a/doctr/models/classification/resnet/tensorflow.py +++ b/doctr/models/classification/resnet/tensorflow.py @@ -13,7 +13,7 @@ from doctr.datasets import VOCABS -from ...utils import conv_sequence, load_pretrained_params +from ...utils import _build_model, conv_sequence, load_pretrained_params __all__ = ["ResNet", "resnet18", "resnet31", "resnet34", "resnet50", "resnet34_wide"] @@ -210,6 +210,8 @@ def _resnet( model = ResNet( num_blocks, output_channels, stage_downsample, stage_conv, stage_pooling, origin_stem, cfg=_cfg, **kwargs ) + _build_model(model) + # Load pretrained parameters if pretrained: # The number of classes is not the same as the number of classes in the pretrained model => @@ -358,6 +360,7 @@ def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet: ) model.cfg = _cfg + _build_model(model) # Load pretrained parameters if pretrained: diff --git a/doctr/models/classification/textnet/tensorflow.py b/doctr/models/classification/textnet/tensorflow.py index b0bb9a7205..e5e6105a7e 100644 --- a/doctr/models/classification/textnet/tensorflow.py +++ b/doctr/models/classification/textnet/tensorflow.py @@ -12,7 +12,7 @@ from doctr.datasets import VOCABS from ...modules.layers.tensorflow import FASTConvLayer -from ...utils import conv_sequence, load_pretrained_params +from ...utils import _build_model, conv_sequence, load_pretrained_params __all__ = ["textnet_tiny", "textnet_small", "textnet_base"] @@ -111,6 +111,8 @@ def _textnet( # Build the model model = TextNet(cfg=_cfg, **kwargs) + _build_model(model) + # Load pretrained parameters if pretrained: # The number of classes is not the same as the number of classes in the pretrained model => diff --git a/doctr/models/classification/vgg/tensorflow.py b/doctr/models/classification/vgg/tensorflow.py index 9ecdabd040..c42e369bcd 100644 --- a/doctr/models/classification/vgg/tensorflow.py +++ b/doctr/models/classification/vgg/tensorflow.py @@ -11,7 +11,7 @@ from doctr.datasets import VOCABS -from ...utils import conv_sequence, load_pretrained_params +from ...utils import _build_model, conv_sequence, load_pretrained_params __all__ = ["VGG", "vgg16_bn_r"] @@ -81,6 +81,8 @@ def _vgg( # Build the model model = VGG(num_blocks, planes, rect_pools, cfg=_cfg, **kwargs) + _build_model(model) + # Load pretrained parameters if pretrained: # The number of classes is not the same as the number of classes in the pretrained model => diff --git a/doctr/models/classification/vit/tensorflow.py b/doctr/models/classification/vit/tensorflow.py index 8531193939..386065bca6 100644 --- a/doctr/models/classification/vit/tensorflow.py +++ b/doctr/models/classification/vit/tensorflow.py @@ -14,7 +14,7 @@ from doctr.models.modules.vision_transformer.tensorflow import PatchEmbedding from doctr.utils.repr import NestedObject -from ...utils import load_pretrained_params +from ...utils import _build_model, load_pretrained_params __all__ = ["vit_s", "vit_b"] @@ -121,6 +121,8 @@ def _vit( # Build the model model = VisionTransformer(cfg=_cfg, **kwargs) + _build_model(model) + # Load pretrained parameters if pretrained: # The number of classes is not the same as the number of classes in the pretrained model => diff --git a/doctr/models/detection/differentiable_binarization/tensorflow.py b/doctr/models/detection/differentiable_binarization/tensorflow.py index 45e522b872..b0ca1f08e5 100644 --- a/doctr/models/detection/differentiable_binarization/tensorflow.py +++ b/doctr/models/detection/differentiable_binarization/tensorflow.py @@ -14,7 +14,13 @@ from tensorflow.keras.applications import ResNet50 from doctr.file_utils import CLASS_NAME -from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, conv_sequence, load_pretrained_params +from doctr.models.utils import ( + IntermediateLayerGetter, + _bf16_to_float32, + _build_model, + conv_sequence, + load_pretrained_params, +) from doctr.utils.repr import NestedObject from ...classification import mobilenet_v3_large @@ -304,6 +310,8 @@ def _db_resnet( # Build the model model = DBNet(feat_extractor, cfg=_cfg, **kwargs) + _build_model(model) + # Load pretrained parameters if pretrained: # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning @@ -347,6 +355,7 @@ def _db_mobilenet( # Build the model model = DBNet(feat_extractor, cfg=_cfg, **kwargs) + _build_model(model) # Load pretrained parameters if pretrained: # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning diff --git a/doctr/models/detection/fast/tensorflow.py b/doctr/models/detection/fast/tensorflow.py index 91d6c8cc4d..b0043494ed 100644 --- a/doctr/models/detection/fast/tensorflow.py +++ b/doctr/models/detection/fast/tensorflow.py @@ -13,7 +13,7 @@ from tensorflow.keras import Model, Sequential, layers from doctr.file_utils import CLASS_NAME -from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, load_pretrained_params +from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, _build_model, load_pretrained_params from doctr.utils.repr import NestedObject from ...classification import textnet_base, textnet_small, textnet_tiny @@ -333,6 +333,8 @@ def _fast( # Build the model model = FAST(feat_extractor, cfg=_cfg, **kwargs) + _build_model(model) + # Load pretrained parameters if pretrained: # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning @@ -342,9 +344,6 @@ def _fast( skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]), ) - # Build the model for reparameterization to access the layers - _ = model(tf.random.uniform(shape=[1, *_cfg["input_shape"]], maxval=1, dtype=tf.float32), training=False) - return model diff --git a/doctr/models/detection/linknet/tensorflow.py b/doctr/models/detection/linknet/tensorflow.py index df8233cf20..9c991c6f4c 100644 --- a/doctr/models/detection/linknet/tensorflow.py +++ b/doctr/models/detection/linknet/tensorflow.py @@ -14,7 +14,13 @@ from doctr.file_utils import CLASS_NAME from doctr.models.classification import resnet18, resnet34, resnet50 -from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, conv_sequence, load_pretrained_params +from doctr.models.utils import ( + IntermediateLayerGetter, + _bf16_to_float32, + _build_model, + conv_sequence, + load_pretrained_params, +) from doctr.utils.repr import NestedObject from .base import LinkNetPostProcessor, _LinkNet @@ -79,10 +85,10 @@ def __init__( for in_chan, out_chan, s, in_shape in zip(i_chans, o_chans, strides, in_shapes[::-1]) ] - def call(self, x: List[tf.Tensor]) -> tf.Tensor: + def call(self, x: List[tf.Tensor], **kwargs: Any) -> tf.Tensor: out = 0 for decoder, fmap in zip(self.decoders, x[::-1]): - out = decoder(out + fmap) + out = decoder(out + fmap, **kwargs) return out def extra_repr(self) -> str: @@ -274,6 +280,8 @@ def _linknet( # Build the model model = LinkNet(feat_extractor, cfg=_cfg, **kwargs) + _build_model(model) + # Load pretrained parameters if pretrained: # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning diff --git a/doctr/models/factory/hub.py b/doctr/models/factory/hub.py index b5844dd30b..dd9fc5d776 100644 --- a/doctr/models/factory/hub.py +++ b/doctr/models/factory/hub.py @@ -27,8 +27,6 @@ if is_torch_available(): import torch -elif is_tf_available(): - import tensorflow as tf __all__ = ["login_to_hub", "push_to_hf_hub", "from_hub", "_save_model_and_config_for_hf_hub"] @@ -76,8 +74,6 @@ def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task torch.save(model.state_dict(), weights_path) elif is_tf_available(): weights_path = save_directory / "tf_model.weights.h5" - # NOTE: `model.build` is not an option because it doesn't runs in eager mode - _ = model(tf.ones((1, *model.cfg["input_shape"])), training=False) model.save_weights(str(weights_path)) config_path = save_directory / "config.json" @@ -229,8 +225,6 @@ def from_hub(repo_id: str, **kwargs: Any): model.load_state_dict(state_dict) else: # tf weights = hf_hub_download(repo_id, filename="tf_model.weights.h5", **kwargs) - # NOTE: `model.build` is not an option because it doesn't runs in eager mode - _ = model(tf.ones((1, *model.cfg["input_shape"])), training=False) model.load_weights(weights) return model diff --git a/doctr/models/preprocessor/tensorflow.py b/doctr/models/preprocessor/tensorflow.py index 5a211004f3..85e06fca3e 100644 --- a/doctr/models/preprocessor/tensorflow.py +++ b/doctr/models/preprocessor/tensorflow.py @@ -41,7 +41,7 @@ def __init__( self.resize = Resize(output_size, **kwargs) # Perform the division by 255 at the same time self.normalize = Normalize(mean, std) - self._runs_on_cuda = tf.test.is_gpu_available() + self._runs_on_cuda = tf.config.list_physical_devices("GPU") != [] def batch_inputs(self, samples: List[tf.Tensor]) -> List[tf.Tensor]: """Gather samples into batches for inference purposes diff --git a/doctr/models/recognition/crnn/tensorflow.py b/doctr/models/recognition/crnn/tensorflow.py index fb5cb72dff..9f74882673 100644 --- a/doctr/models/recognition/crnn/tensorflow.py +++ b/doctr/models/recognition/crnn/tensorflow.py @@ -13,7 +13,7 @@ from doctr.datasets import VOCABS from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r -from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params +from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params from ..core import RecognitionModel, RecognitionPostProcessor __all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"] @@ -245,6 +245,7 @@ def _crnn( # Build the model model = CRNN(feat_extractor, cfg=_cfg, **kwargs) + _build_model(model) # Load pretrained parameters if pretrained: # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning diff --git a/doctr/models/recognition/master/tensorflow.py b/doctr/models/recognition/master/tensorflow.py index 42cd216b2c..e01c089012 100644 --- a/doctr/models/recognition/master/tensorflow.py +++ b/doctr/models/recognition/master/tensorflow.py @@ -13,7 +13,7 @@ from doctr.models.classification import magc_resnet31 from doctr.models.modules.transformer import Decoder, PositionalEncoding -from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params +from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params from .base import _MASTER, _MASTERPostProcessor __all__ = ["MASTER", "master"] @@ -290,6 +290,8 @@ def _master(arch: str, pretrained: bool, backbone_fn, pretrained_backbone: bool cfg=_cfg, **kwargs, ) + _build_model(model) + # Load pretrained parameters if pretrained: # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning diff --git a/doctr/models/recognition/parseq/tensorflow.py b/doctr/models/recognition/parseq/tensorflow.py index b0e21a50d6..d8c54527be 100644 --- a/doctr/models/recognition/parseq/tensorflow.py +++ b/doctr/models/recognition/parseq/tensorflow.py @@ -16,7 +16,7 @@ from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward from ...classification import vit_s -from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params +from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params from .base import _PARSeq, _PARSeqPostProcessor __all__ = ["PARSeq", "parseq"] @@ -473,6 +473,8 @@ def _parseq( # Build the model model = PARSeq(feat_extractor, cfg=_cfg, **kwargs) + _build_model(model) + # Load pretrained parameters if pretrained: # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning diff --git a/doctr/models/recognition/sar/tensorflow.py b/doctr/models/recognition/sar/tensorflow.py index 89e93ea51e..bcb0b207ef 100644 --- a/doctr/models/recognition/sar/tensorflow.py +++ b/doctr/models/recognition/sar/tensorflow.py @@ -13,7 +13,7 @@ from doctr.utils.repr import NestedObject from ...classification import resnet31 -from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params +from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params from ..core import RecognitionModel, RecognitionPostProcessor __all__ = ["SAR", "sar_resnet31"] @@ -392,6 +392,7 @@ def _sar( # Build the model model = SAR(feat_extractor, cfg=_cfg, **kwargs) + _build_model(model) # Load pretrained parameters if pretrained: # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning diff --git a/doctr/models/recognition/vitstr/tensorflow.py b/doctr/models/recognition/vitstr/tensorflow.py index 6b38cf7548..9b121171f8 100644 --- a/doctr/models/recognition/vitstr/tensorflow.py +++ b/doctr/models/recognition/vitstr/tensorflow.py @@ -12,7 +12,7 @@ from doctr.datasets import VOCABS from ...classification import vit_b, vit_s -from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params +from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params from .base import _ViTSTR, _ViTSTRPostProcessor __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"] @@ -216,6 +216,8 @@ def _vitstr( # Build the model model = ViTSTR(feat_extractor, cfg=_cfg, **kwargs) + _build_model(model) + # Load pretrained parameters if pretrained: # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning diff --git a/doctr/models/utils/tensorflow.py b/doctr/models/utils/tensorflow.py index 6f7dc14ab3..c04a4b2893 100644 --- a/doctr/models/utils/tensorflow.py +++ b/doctr/models/utils/tensorflow.py @@ -17,6 +17,7 @@ __all__ = [ "load_pretrained_params", + "_build_model", "conv_sequence", "IntermediateLayerGetter", "export_model_to_onnx", @@ -34,6 +35,16 @@ def _bf16_to_float32(x: tf.Tensor) -> tf.Tensor: return tf.cast(x, tf.float32) if x.dtype == tf.bfloat16 else x +def _build_model(model: Model): + """Build a model by calling it once with dummy input + + Args: + ---- + model: the model to be built + """ + model(tf.zeros((1, *model.cfg["input_shape"])), training=False) + + def load_pretrained_params( model: Model, url: Optional[str] = None, @@ -58,11 +69,6 @@ def load_pretrained_params( logging.warning("Invalid model URL, using default initialization.") else: archive_path = download_from_url(url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs) - - # Build the model - # NOTE: `model.build` is not an option because it doesn't runs in eager mode - _ = model(tf.ones((1, *model.cfg["input_shape"])), training=False) - # Load weights model.load_weights(archive_path, skip_mismatch=skip_mismatch) diff --git a/references/classification/README.md b/references/classification/README.md index 6646b0d8ca..885cc0b565 100644 --- a/references/classification/README.md +++ b/references/classification/README.md @@ -30,13 +30,13 @@ python references/classification/train_pytorch_character.py mobilenet_v3_large - You can start your training in TensorFlow: ```shell -python references/classification/train_tensorflow_orientation.py path/to/your/train_set path/to/your/val_set resnet18 page --epochs 5 +python references/classification/train_tensorflow_orientation.py resnet18 --type page --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5 ``` or PyTorch: ```shell -python references/classification/train_pytorch_orientation.py path/to/your/train_set path/to/your/val_set resnet18 page --epochs 5 +python references/classification/train_pytorch_orientation.py resnet18 --type page --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5 ``` The type can be either `page` for document images or `crop` for word crops. diff --git a/references/classification/train_pytorch_orientation.py b/references/classification/train_pytorch_orientation.py index 46c77d4c38..8324f0aa37 100644 --- a/references/classification/train_pytorch_orientation.py +++ b/references/classification/train_pytorch_orientation.py @@ -375,10 +375,10 @@ def parse_args(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument("train_path", type=str, help="path to training data folder") - parser.add_argument("val_path", type=str, help="path to validation data folder") parser.add_argument("arch", type=str, help="classification model to train") - parser.add_argument("type", type=str, choices=["page", "crop"], help="type of data to train on") + parser.add_argument("--type", type=str, required=True, choices=["page", "crop"], help="type of data to train on") + parser.add_argument("--train_path", type=str, required=True, help="path to training data folder") + parser.add_argument("--val_path", type=str, required=True, help="path to validation data folder") parser.add_argument("--name", type=str, default=None, help="Name of your training experiment") parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on") parser.add_argument("-b", "--batch_size", type=int, default=2, help="batch size for training") diff --git a/references/classification/train_tensorflow_character.py b/references/classification/train_tensorflow_character.py index d3b6e16a0c..0b1b648d93 100644 --- a/references/classification/train_tensorflow_character.py +++ b/references/classification/train_tensorflow_character.py @@ -185,8 +185,6 @@ def main(args): # Resume weights if isinstance(args.resume, str): - # Build the model first to load the weights - _ = model(tf.zeros((1, args.input_size, args.input_size, 3)), training=False) model.load_weights(args.resume) batch_transforms = T.Compose([ diff --git a/references/classification/train_tensorflow_orientation.py b/references/classification/train_tensorflow_orientation.py index 00cfe98add..297a5674f4 100644 --- a/references/classification/train_tensorflow_orientation.py +++ b/references/classification/train_tensorflow_orientation.py @@ -196,8 +196,6 @@ def main(args): # Resume weights if isinstance(args.resume, str): - # Build the model first to load the weights - _ = model(tf.zeros((1, *input_size, 3)), training=False) model.load_weights(args.resume) batch_transforms = T.Compose([ @@ -340,7 +338,7 @@ def main(args): if args.export_onnx: print("Exporting model to ONNX...") - if args.arch == "vit_b": + if args.arch in ["vit_s", "vit_b"]: # fixed batch size for vit dummy_input = [tf.TensorSpec([1, *(input_size), 3], tf.float32, name="input")] else: @@ -358,10 +356,10 @@ def parse_args(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument("train_path", type=str, help="path to training data folder") - parser.add_argument("val_path", type=str, help="path to validation data folder") parser.add_argument("arch", type=str, help="classification model to train") - parser.add_argument("type", type=str, choices=["page", "crop"], help="type of data to train on") + parser.add_argument("--type", type=str, required=True, choices=["page", "crop"], help="type of data to train on") + parser.add_argument("--train_path", type=str, help="path to training data folder") + parser.add_argument("--val_path", type=str, required=True, help="path to validation data folder") parser.add_argument("--name", type=str, default=None, help="Name of your training experiment") parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on") parser.add_argument("-b", "--batch_size", type=int, default=2, help="batch size for training") diff --git a/references/detection/README.md b/references/detection/README.md index 7a07b4cb6b..35d1481877 100644 --- a/references/detection/README.md +++ b/references/detection/README.md @@ -16,13 +16,13 @@ pip install -r references/requirements.txt You can start your training in TensorFlow: ```shell -python references/detection/train_tensorflow.py path/to/your/train_set path/to/your/val_set db_resnet50 --epochs 5 +python references/detection/train_tensorflow.py db_resnet50 --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5 ``` or PyTorch: ```shell -python references/detection/train_pytorch.py path/to/your/train_set path/to/your/val_set db_resnet50 --epochs 5 --device 0 +python references/detection/train_pytorch.py db_resnet50 --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5 ``` ## Data format diff --git a/references/detection/evaluate_tensorflow.py b/references/detection/evaluate_tensorflow.py index c224e07a91..a2c5bbe49c 100644 --- a/references/detection/evaluate_tensorflow.py +++ b/references/detection/evaluate_tensorflow.py @@ -40,7 +40,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric): for images, targets in tqdm(val_loader): images = batch_transforms(images) targets = [{CLASS_NAME: t} for t in targets] - out = model(images, targets, training=False, return_preds=True) + out = model(images, target=targets, training=False, return_preds=True) # Compute metric loc_preds = out["preds"] for target, loc_pred in zip(targets, loc_preds): diff --git a/references/detection/train_pytorch.py b/references/detection/train_pytorch.py index 0c30925146..091d257898 100644 --- a/references/detection/train_pytorch.py +++ b/references/detection/train_pytorch.py @@ -427,9 +427,9 @@ def parse_args(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument("train_path", type=str, help="path to training data folder") - parser.add_argument("val_path", type=str, help="path to validation data folder") parser.add_argument("arch", type=str, help="text-detection model to train") + parser.add_argument("--train_path", type=str, required=True, help="path to training data folder") + parser.add_argument("--val_path", type=str, required=True, help="path to validation data folder") parser.add_argument("--name", type=str, default=None, help="Name of your training experiment") parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on") parser.add_argument("-b", "--batch_size", type=int, default=2, help="batch size for training") diff --git a/references/detection/train_tensorflow.py b/references/detection/train_tensorflow.py index 0a535cd7cd..f054879e8f 100644 --- a/references/detection/train_tensorflow.py +++ b/references/detection/train_tensorflow.py @@ -31,7 +31,7 @@ from doctr.datasets import DataLoader, DetectionDataset from doctr.models import detection from doctr.utils.metrics import LocalizationConfusion -from utils import EarlyStopper, load_backbone, plot_recorder, plot_samples +from utils import EarlyStopper, plot_recorder, plot_samples def record_lr( @@ -193,15 +193,8 @@ def main(args): # Resume weights if isinstance(args.resume, str): - # Build the model first to load the weights - _ = model(tf.zeros((1, args.input_size, args.input_size, 3)), training=False) model.load_weights(args.resume) - if isinstance(args.pretrained_backbone, str): - print("Loading backbone weights.") - model = load_backbone(model, args.pretrained_backbone) - print("Done.") - # Metrics val_metric = LocalizationConfusion(use_polygons=args.rotation and not args.eval_straight) @@ -411,7 +404,7 @@ def parse_args(): parser.add_argument("arch", type=str, help="text-detection model to train") parser.add_argument("--train_path", type=str, required=True, help="path to training data folder") - parser.add_argument("--val_path", type=str, help="path to validation data folder") + parser.add_argument("--val_path", type=str, required=True, help="path to validation data folder") parser.add_argument("--name", type=str, default=None, help="Name of your training experiment") parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on") parser.add_argument("-b", "--batch_size", type=int, default=2, help="batch size for training") @@ -421,7 +414,6 @@ def parse_args(): parser.add_argument("--input_size", type=int, default=1024, help="model input size, H = W") parser.add_argument("--lr", type=float, default=0.001, help="learning rate for the optimizer (Adam)") parser.add_argument("--resume", type=str, default=None, help="Path to your checkpoint") - parser.add_argument("--pretrained-backbone", type=str, default=None, help="Path to your backbone weights") parser.add_argument("--test-only", dest="test_only", action="store_true", help="Run the validation loop") parser.add_argument( "--freeze-backbone", dest="freeze_backbone", action="store_true", help="freeze model backbone for fine-tuning" diff --git a/references/detection/utils.py b/references/detection/utils.py index 7983ee4d51..1a84f2340d 100644 --- a/references/detection/utils.py +++ b/references/detection/utils.py @@ -3,7 +3,6 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. -import pickle from typing import Dict, List import cv2 @@ -86,13 +85,6 @@ def plot_recorder(lr_recorder, loss_recorder, beta: float = 0.95, **kwargs) -> N plt.show(**kwargs) -def load_backbone(model, weights_path): - pretrained_backbone_weights = pickle.load(open(weights_path, "rb")) - model.feat_extractor.set_weights(pretrained_backbone_weights[0]) - model.fpn.set_weights(pretrained_backbone_weights[1]) - return model - - class EarlyStopper: def __init__(self, patience: int = 5, min_delta: float = 0.01): self.patience = patience diff --git a/references/recognition/README.md b/references/recognition/README.md index 5823030120..9087cbc210 100644 --- a/references/recognition/README.md +++ b/references/recognition/README.md @@ -22,7 +22,7 @@ python references/recognition/train_tensorflow.py crnn_vgg16_bn --train_path pat or PyTorch: ```shell -python references/recognition/train_pytorch.py crnn_vgg16_bn --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5 --device 0 +python references/recognition/train_pytorch.py crnn_vgg16_bn --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5 ``` ### Multi-GPU support (PyTorch only - Experimental) diff --git a/references/recognition/evaluate_tensorflow.py b/references/recognition/evaluate_tensorflow.py index dc034d333f..b6ca50b516 100644 --- a/references/recognition/evaluate_tensorflow.py +++ b/references/recognition/evaluate_tensorflow.py @@ -38,7 +38,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric): for images, targets in tqdm(val_iter): try: images = batch_transforms(images) - out = model(images, targets, return_preds=True, training=False) + out = model(images, target=targets, return_preds=True, training=False) # Compute metric if len(out["preds"]): words, _ = zip(*out["preds"]) diff --git a/references/recognition/train_tensorflow.py b/references/recognition/train_tensorflow.py index c12752a3e1..348f3a3869 100644 --- a/references/recognition/train_tensorflow.py +++ b/references/recognition/train_tensorflow.py @@ -193,8 +193,6 @@ def main(args): ) # Resume weights if isinstance(args.resume, str): - # Build the model first to load the weights - _ = model(tf.zeros((1, args.input_size, 4 * args.input_size, 3)), training=False) model.load_weights(args.resume) # Metrics