From e5b3f462ac8defee22002a3581e0294bc49038cc Mon Sep 17 00:00:00 2001 From: Olivier Dulcy <106678676+odulcy-mindee@users.noreply.github.com> Date: Mon, 8 Jan 2024 17:29:52 +0100 Subject: [PATCH] feat: :sparkles: new TF Linknet Resnet checkpoints (#1424) --- doctr/models/detection/linknet/tensorflow.py | 82 ++++++++++---------- 1 file changed, 43 insertions(+), 39 deletions(-) diff --git a/doctr/models/detection/linknet/tensorflow.py b/doctr/models/detection/linknet/tensorflow.py index 6950e07220..cfb15b3108 100644 --- a/doctr/models/detection/linknet/tensorflow.py +++ b/doctr/models/detection/linknet/tensorflow.py @@ -27,39 +27,41 @@ "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), "input_shape": (1024, 1024, 3), - "url": "https://doctr-static.mindee.com/models?id=v0.5.0/linknet_resnet18-a48e6ed3.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet18-b9ee56e6.zip&src=0", }, "linknet_resnet34": { "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), "input_shape": (1024, 1024, 3), - "url": "https://doctr-static.mindee.com/models?id=v0.6.0/linknet_resnet34-bf30afb1.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet34-51909c56.zip&src=0", }, "linknet_resnet50": { "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), "input_shape": (1024, 1024, 3), - "url": "https://doctr-static.mindee.com/models?id=v0.6.0/linknet_resnet50-cd299262.zip&src=0", + "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet50-ac9f3829.zip&src=0", }, } def decoder_block(in_chan: int, out_chan: int, stride: int, **kwargs: Any) -> Sequential: """Creates a LinkNet decoder block""" - return Sequential([ - *conv_sequence(in_chan // 4, "relu", True, kernel_size=1, **kwargs), - layers.Conv2DTranspose( - filters=in_chan // 4, - kernel_size=3, - strides=stride, - padding="same", - use_bias=False, - kernel_initializer="he_normal", - ), - layers.BatchNormalization(), - layers.Activation("relu"), - *conv_sequence(out_chan, "relu", True, kernel_size=1), - ]) + return Sequential( + [ + *conv_sequence(in_chan // 4, "relu", True, kernel_size=1, **kwargs), + layers.Conv2DTranspose( + filters=in_chan // 4, + kernel_size=3, + strides=stride, + padding="same", + use_bias=False, + kernel_initializer="he_normal", + ), + layers.BatchNormalization(), + layers.Activation("relu"), + *conv_sequence(out_chan, "relu", True, kernel_size=1), + ] + ) class LinkNetFPN(Model, NestedObject): @@ -129,28 +131,30 @@ def __init__( self.fpn = LinkNetFPN(fpn_channels, [_shape[1:] for _shape in self.feat_extractor.output_shape]) self.fpn.build(self.feat_extractor.output_shape) - self.classifier = Sequential([ - layers.Conv2DTranspose( - filters=32, - kernel_size=3, - strides=2, - padding="same", - use_bias=False, - kernel_initializer="he_normal", - input_shape=self.fpn.decoders[-1].output_shape[1:], - ), - layers.BatchNormalization(), - layers.Activation("relu"), - *conv_sequence(32, "relu", True, kernel_size=3, strides=1), - layers.Conv2DTranspose( - filters=num_classes, - kernel_size=2, - strides=2, - padding="same", - use_bias=True, - kernel_initializer="he_normal", - ), - ]) + self.classifier = Sequential( + [ + layers.Conv2DTranspose( + filters=32, + kernel_size=3, + strides=2, + padding="same", + use_bias=False, + kernel_initializer="he_normal", + input_shape=self.fpn.decoders[-1].output_shape[1:], + ), + layers.BatchNormalization(), + layers.Activation("relu"), + *conv_sequence(32, "relu", True, kernel_size=3, strides=1), + layers.Conv2DTranspose( + filters=num_classes, + kernel_size=2, + strides=2, + padding="same", + use_bias=True, + kernel_initializer="he_normal", + ), + ] + ) self.postprocessor = LinkNetPostProcessor(assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh)