Skip to content

Commit

Permalink
fix obj det train and suppress endless warning prints (#1267)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored Jul 27, 2023
1 parent 0a47726 commit efe7ca0
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
8 changes: 5 additions & 3 deletions doctr/models/obj_detection/faster_rcnn/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from typing import Any, Dict

from torchvision.models.detection import FasterRCNN, faster_rcnn
from torchvision.models.detection import FasterRCNN, FasterRCNN_MobileNet_V3_Large_FPN_Weights, faster_rcnn

from ...utils import load_pretrained_params

Expand Down Expand Up @@ -37,7 +37,7 @@ def _fasterrcnn(arch: str, pretrained: bool, **kwargs: Any) -> FasterRCNN:

# Build the model
_kwargs.update(kwargs)
model = faster_rcnn.__dict__[arch](pretrained=False, pretrained_backbone=False, **_kwargs)
model = faster_rcnn.__dict__[arch](weights=None, weights_backbone=None, **_kwargs)
model.cfg = default_cfgs[arch]

if pretrained:
Expand All @@ -47,7 +47,9 @@ def _fasterrcnn(arch: str, pretrained: bool, **kwargs: Any) -> FasterRCNN:
# Filter keys
state_dict = {
k: v
for k, v in faster_rcnn.__dict__[arch](pretrained=True).state_dict().items()
for k, v in faster_rcnn.__dict__[arch](weights=FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT)
.state_dict()
.items()
if not k.startswith("roi_heads.")
}

Expand Down
4 changes: 2 additions & 2 deletions doctr/transforms/modules/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(
preserve_aspect_ratio: bool = False,
symmetric_pad: bool = False,
) -> None:
super().__init__(size, interpolation)
super().__init__(size, interpolation, antialias=True)
self.preserve_aspect_ratio = preserve_aspect_ratio
self.symmetric_pad = symmetric_pad

Expand Down Expand Up @@ -64,7 +64,7 @@ def forward(
tmp_size = (self.size, max(int(self.size / actual_ratio), 1))

# Scale image
img = F.resize(img, tmp_size, self.interpolation)
img = F.resize(img, tmp_size, self.interpolation, antialias=True)
raw_shape = img.shape[-2:]
if isinstance(self.size, (tuple, list)):
# Pad (inverted in pytorch)
Expand Down

0 comments on commit efe7ca0

Please sign in to comment.