Skip to content

Commit

Permalink
feat: ✨ new TF Linknet Resnet checkpoints (mindee#1424)
Browse files Browse the repository at this point in the history
  • Loading branch information
odulcy-mindee authored Jan 8, 2024
1 parent a48b72a commit e5b3f46
Showing 1 changed file with 43 additions and 39 deletions.
82 changes: 43 additions & 39 deletions doctr/models/detection/linknet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,39 +27,41 @@
"mean": (0.798, 0.785, 0.772),
"std": (0.264, 0.2749, 0.287),
"input_shape": (1024, 1024, 3),
"url": "https://doctr-static.mindee.com/models?id=v0.5.0/linknet_resnet18-a48e6ed3.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet18-b9ee56e6.zip&src=0",
},
"linknet_resnet34": {
"mean": (0.798, 0.785, 0.772),
"std": (0.264, 0.2749, 0.287),
"input_shape": (1024, 1024, 3),
"url": "https://doctr-static.mindee.com/models?id=v0.6.0/linknet_resnet34-bf30afb1.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet34-51909c56.zip&src=0",
},
"linknet_resnet50": {
"mean": (0.798, 0.785, 0.772),
"std": (0.264, 0.2749, 0.287),
"input_shape": (1024, 1024, 3),
"url": "https://doctr-static.mindee.com/models?id=v0.6.0/linknet_resnet50-cd299262.zip&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet50-ac9f3829.zip&src=0",
},
}


def decoder_block(in_chan: int, out_chan: int, stride: int, **kwargs: Any) -> Sequential:
    """Creates a LinkNet decoder block

    The block is a ``Sequential`` made of three stages, applied in order:

    1. a ``conv_sequence`` with kernel size 1, called with ``in_chan // 4``
       (presumably its output channel count -- confirm against
       ``conv_sequence``);
    2. a 3x3 ``Conv2DTranspose`` with ``in_chan // 4`` filters and stride
       ``stride``, followed by batch normalization and a ReLU activation;
    3. a ``conv_sequence`` with kernel size 1, called with ``out_chan``.

    Args:
        in_chan: channel count the intermediate width (``in_chan // 4``) is derived from
        out_chan: value passed to the final ``conv_sequence``
        stride: stride of the transposed convolution
        **kwargs: extra keyword arguments forwarded to the first ``conv_sequence`` only

    Returns:
        the decoder block as a ``Sequential`` model
    """
    # Integer division: the same reduced width feeds both the first 1x1 stage
    # and the transposed convolution.
    mid_chan = in_chan // 4

    return Sequential(
        [
            *conv_sequence(mid_chan, "relu", True, kernel_size=1, **kwargs),
            layers.Conv2DTranspose(
                filters=mid_chan,
                kernel_size=3,
                strides=stride,
                padding="same",
                # No bias here: the BatchNormalization layer that follows
                # provides its own learned offset.
                use_bias=False,
                kernel_initializer="he_normal",
            ),
            layers.BatchNormalization(),
            layers.Activation("relu"),
            # kwargs are intentionally not forwarded to this last stage,
            # matching the original call.
            *conv_sequence(out_chan, "relu", True, kernel_size=1),
        ]
    )


class LinkNetFPN(Model, NestedObject):
Expand Down Expand Up @@ -129,28 +131,30 @@ def __init__(
self.fpn = LinkNetFPN(fpn_channels, [_shape[1:] for _shape in self.feat_extractor.output_shape])
self.fpn.build(self.feat_extractor.output_shape)

self.classifier = Sequential([
layers.Conv2DTranspose(
filters=32,
kernel_size=3,
strides=2,
padding="same",
use_bias=False,
kernel_initializer="he_normal",
input_shape=self.fpn.decoders[-1].output_shape[1:],
),
layers.BatchNormalization(),
layers.Activation("relu"),
*conv_sequence(32, "relu", True, kernel_size=3, strides=1),
layers.Conv2DTranspose(
filters=num_classes,
kernel_size=2,
strides=2,
padding="same",
use_bias=True,
kernel_initializer="he_normal",
),
])
self.classifier = Sequential(
[
layers.Conv2DTranspose(
filters=32,
kernel_size=3,
strides=2,
padding="same",
use_bias=False,
kernel_initializer="he_normal",
input_shape=self.fpn.decoders[-1].output_shape[1:],
),
layers.BatchNormalization(),
layers.Activation("relu"),
*conv_sequence(32, "relu", True, kernel_size=3, strides=1),
layers.Conv2DTranspose(
filters=num_classes,
kernel_size=2,
strides=2,
padding="same",
use_bias=True,
kernel_initializer="he_normal",
),
]
)

self.postprocessor = LinkNetPostProcessor(assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh)

Expand Down

0 comments on commit e5b3f46

Please sign in to comment.