Commit

[Chore] update to new torchvision API as well (#1291)
felixT2K authored Aug 28, 2023
1 parent 9e78906 commit 1f0169b
Showing 4 changed files with 10 additions and 7 deletions.
4 changes: 2 additions & 2 deletions doctr/models/classification/mobilenet/pytorch.py
@@ -77,9 +77,9 @@ def _mobilenet_v3(
     kwargs.pop("classes")

     if arch.startswith("mobilenet_v3_small"):
-        model = mobilenetv3.mobilenet_v3_small(**kwargs)
+        model = mobilenetv3.mobilenet_v3_small(**kwargs, weights=None)
     else:
-        model = mobilenetv3.mobilenet_v3_large(**kwargs)
+        model = mobilenetv3.mobilenet_v3_large(**kwargs, weights=None)

     # Rectangular strides
     if isinstance(rect_strides, list):
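The torchvision context for this change: since torchvision 0.13 the classification constructors take a `weights` enum instead of the old `pretrained` boolean, and omitting the argument triggers a deprecation warning. Passing `weights=None` explicitly requests a randomly initialized model, which is what is wanted here since docTR loads its own checkpoints afterwards. A minimal sketch of the two call styles (illustrative only, not part of this diff):

from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights

# Old API (deprecated since torchvision 0.13): a boolean flag
# model = mobilenet_v3_small(pretrained=False)

# New API: an explicit weights enum; None means "no pretrained weights"
model = mobilenet_v3_small(weights=None)

# Loading ImageNet weights under the new API would look like this instead:
model = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.IMAGENET1K_V1)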
2 changes: 1 addition & 1 deletion doctr/models/classification/resnet/pytorch.py
@@ -200,7 +200,7 @@ def _tv_resnet(
     kwargs.pop("classes")

     # Build the model
-    model = arch_fn(**kwargs)
+    model = arch_fn(**kwargs, weights=None)
     # Load pretrained parameters
     if pretrained:
         # The number of classes is not the same as the number of classes in the pretrained model =>
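The `# Load pretrained parameters` branch above is why `weights=None` is safe: the torchvision model is built uninitialized and docTR's own checkpoint is loaded right after, skipping the classification head when its number of classes differs from the checkpoint. A rough sketch of that load-and-skip pattern in plain PyTorch (the helper name and the `ignore_keys` handling are illustrative, not docTR's actual API):

import torch
from torch import nn

def load_filtered_checkpoint(model: nn.Module, checkpoint: dict, ignore_keys=None) -> None:
    # Drop the parameters listed in ignore_keys (typically the classifier head,
    # whose shape no longer matches), then load the rest non-strictly so the
    # skipped keys simply stay randomly initialized.
    ignore_keys = set(ignore_keys or [])
    filtered = {k: v for k, v in checkpoint.items() if k not in ignore_keys}
    model.load_state_dict(filtered, strict=False)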
2 changes: 1 addition & 1 deletion doctr/models/classification/vgg/pytorch.py
@@ -44,7 +44,7 @@ def _vgg(
     kwargs.pop("classes")

     # Build the model
-    model = tv_vgg.__dict__[tv_arch](**kwargs)
+    model = tv_vgg.__dict__[tv_arch](**kwargs, weights=None)
     # List the MaxPool2d
     pool_idcs = [idx for idx, m in enumerate(model.features) if isinstance(m, nn.MaxPool2d)]
     # Replace their kernel with rectangular ones
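The surrounding lines in this hunk show what happens after the VGG trunk is built: the MaxPool2d layers are located and some of them get rectangular kernels, so text-shaped inputs are downsampled more along the height than the width. A hedged sketch of that adaptation; which pools are replaced and the exact (2, 1) kernel are assumptions for illustration:

from torch import nn
from torchvision.models import vgg16_bn

model = vgg16_bn(weights=None)
# Locate the MaxPool2d layers in the feature stack ...
pool_idcs = [idx for idx, m in enumerate(model.features) if isinstance(m, nn.MaxPool2d)]
# ... and swap the last few for rectangular pooling (halve the height, keep the width)
for idx in pool_idcs[-3:]:
    model.features[idx] = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1), ceil_mode=True)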
9 changes: 6 additions & 3 deletions doctr/models/detection/differentiable_binarization/pytorch.py
@@ -292,12 +292,15 @@ def _dbnet(
     ignore_keys: Optional[List[str]] = None,
     **kwargs: Any,
 ) -> DBNet:
-    # Starting with Imagenet pretrained params introduces some NaNs in layer3 & layer4 of resnet50
-    pretrained_backbone = pretrained_backbone and not arch.split("_")[1].startswith("resnet")
     pretrained_backbone = pretrained_backbone and not pretrained

     # Feature extractor
-    backbone = backbone_fn(pretrained_backbone)
+    backbone = (
+        backbone_fn(pretrained_backbone)
+        if not arch.split("_")[1].startswith("resnet")
+        # Starting with Imagenet pretrained params introduces some NaNs in layer3 & layer4 of resnet50
+        else backbone_fn(weights=None)  # type: ignore[call-arg]
+    )
     if isinstance(backbone_submodule, str):
         backbone = getattr(backbone, backbone_submodule)
     feat_extractor = IntermediateLayerGetter(
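Two things are folded into the new conditional: backbones whose constructor still takes a positional boolean keep the old call, while torchvision resnet backbones are always built with `weights=None`, both to match the new API and because starting from ImageNet weights produces NaNs in layer3/layer4 during DB training. The chosen backbone is then wrapped in torchvision's IntermediateLayerGetter, as the truncated `feat_extractor = ...` line shows. A sketch of that wrapping with an assumed return_layers mapping (the actual layer names and output keys used by docTR may differ):

from torchvision.models import resnet50
from torchvision.models._utils import IntermediateLayerGetter

backbone = resnet50(weights=None)
# Expose several intermediate stages so a DBNet-style FPN can fuse multi-scale features;
# the mapping is {module name inside the backbone: key in the returned dict}.
feat_extractor = IntermediateLayerGetter(
    backbone,
    return_layers={"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"},
)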
