diff --git a/doctr/models/classification/mobilenet/pytorch.py b/doctr/models/classification/mobilenet/pytorch.py index e641763b1f..1aa9211fc5 100644 --- a/doctr/models/classification/mobilenet/pytorch.py +++ b/doctr/models/classification/mobilenet/pytorch.py @@ -77,9 +77,9 @@ def _mobilenet_v3( kwargs.pop("classes") if arch.startswith("mobilenet_v3_small"): - model = mobilenetv3.mobilenet_v3_small(**kwargs) + model = mobilenetv3.mobilenet_v3_small(**kwargs, weights=None) else: - model = mobilenetv3.mobilenet_v3_large(**kwargs) + model = mobilenetv3.mobilenet_v3_large(**kwargs, weights=None) # Rectangular strides if isinstance(rect_strides, list): diff --git a/doctr/models/classification/resnet/pytorch.py b/doctr/models/classification/resnet/pytorch.py index 061dd0079e..3ed61140d8 100644 --- a/doctr/models/classification/resnet/pytorch.py +++ b/doctr/models/classification/resnet/pytorch.py @@ -200,7 +200,7 @@ def _tv_resnet( kwargs.pop("classes") # Build the model - model = arch_fn(**kwargs) + model = arch_fn(**kwargs, weights=None) # Load pretrained parameters if pretrained: # The number of classes is not the same as the number of classes in the pretrained model => diff --git a/doctr/models/classification/vgg/pytorch.py b/doctr/models/classification/vgg/pytorch.py index 28afdd47da..0cbfe33f50 100644 --- a/doctr/models/classification/vgg/pytorch.py +++ b/doctr/models/classification/vgg/pytorch.py @@ -44,7 +44,7 @@ def _vgg( kwargs.pop("classes") # Build the model - model = tv_vgg.__dict__[tv_arch](**kwargs) + model = tv_vgg.__dict__[tv_arch](**kwargs, weights=None) # List the MaxPool2d pool_idcs = [idx for idx, m in enumerate(model.features) if isinstance(m, nn.MaxPool2d)] # Replace their kernel with rectangular ones diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py index 919de2044a..df47f34d08 100644 --- a/doctr/models/detection/differentiable_binarization/pytorch.py +++ b/doctr/models/detection/differentiable_binarization/pytorch.py @@ -292,12 +292,15 @@ def _dbnet( ignore_keys: Optional[List[str]] = None, **kwargs: Any, ) -> DBNet: - # Starting with Imagenet pretrained params introduces some NaNs in layer3 & layer4 of resnet50 - pretrained_backbone = pretrained_backbone and not arch.split("_")[1].startswith("resnet") pretrained_backbone = pretrained_backbone and not pretrained # Feature extractor - backbone = backbone_fn(pretrained_backbone) + backbone = ( + backbone_fn(pretrained_backbone) + if not arch.split("_")[1].startswith("resnet") + # Starting with Imagenet pretrained params introduces some NaNs in layer3 & layer4 of resnet50 + else backbone_fn(weights=None) # type: ignore[call-arg] + ) if isinstance(backbone_submodule, str): backbone = getattr(backbone, backbone_submodule) feat_extractor = IntermediateLayerGetter(