Skip to content

Commit

Permalink
Reviewer feedback: rename function, fix typo
Browse files Browse the repository at this point in the history
  • Loading branch information
BenjaminBossan committed Sep 3, 2024
1 parent ab9c536 commit f4162ac
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 10 deletions.
4 changes: 2 additions & 2 deletions skorch/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from skorch.utils import to_device
from skorch.utils import to_numpy
from skorch.utils import to_tensor
from skorch.utils import get_torch_load_kwargs
from skorch.utils import get_default_torch_load_kwargs


# pylint: disable=too-many-instance-attributes
Expand Down Expand Up @@ -2652,7 +2652,7 @@ def _get_state_dict(f_name):
else:
torch_load_kwargs = self.torch_load_kwargs
if torch_load_kwargs is None:
torch_load_kwargs = get_torch_load_kwargs()
torch_load_kwargs = get_default_torch_load_kwargs()

def _get_state_dict(f_name):
map_location = get_map_location(self.device)
Expand Down
13 changes: 7 additions & 6 deletions skorch/tests/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -3011,7 +3011,7 @@ def test_torch_load_kwargs_auto_weights_only_false_when_load_params(
mock_torch_load = Mock(return_value=state_dict)
monkeypatch.setattr(torch, "load", mock_torch_load)
monkeypatch.setattr(
skorch.net, "get_torch_load_kwargs", lambda: expected_kwargs
skorch.net, "get_default_torch_load_kwargs", lambda: expected_kwargs
)

net.load_params(f_params=tmp_path / 'params.pkl')
Expand All @@ -3035,7 +3035,7 @@ def test_torch_load_kwargs_auto_weights_only_true_when_load_params(
mock_torch_load = Mock(return_value=state_dict)
monkeypatch.setattr(torch, "load", mock_torch_load)
monkeypatch.setattr(
skorch.net, "get_torch_load_kwargs", lambda: expected_kwargs
skorch.net, "get_default_torch_load_kwargs", lambda: expected_kwargs
)

net.load_params(f_params=tmp_path / 'params.pkl')
Expand Down Expand Up @@ -3067,10 +3067,11 @@ def test_torch_load_kwargs_forwarded_to_torch_load(
def test_torch_load_kwargs_auto_weights_false_pytorch_lt_2_6(
self, net_cls, module_cls, monkeypatch, tmp_path
):
# Same test as test_torch_load_kwargs_auto_weights_only_false_when_load_params
# but without monkeypatching get_torch_load_kwargs. There is no corresponding
# test for >= 2.6.0 since it's not clear yet if the switch will be made in that
# version.
# Same test as
# test_torch_load_kwargs_auto_weights_only_false_when_load_params but
# without monkeypatching get_default_torch_load_kwargs. There is no
# corresponding test for >= 2.6.0 since it's not clear yet if the switch
# will be made in that version.
# See discussion in 1063.
from skorch._version import Version

Expand Down
4 changes: 2 additions & 2 deletions skorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,8 +771,8 @@ def _check_f_arguments(caller_name, **kwargs):
return kwargs_module, kwargs_other


def get_torch_load_kwargs():
"""Returns the kwargs passed to torch.load the correspond to the current
def get_default_torch_load_kwargs():
"""Returns the kwargs passed to torch.load that correspond to the current
torch version.
The plan is to switch from weights_only=False to True in PyTorch version
Expand Down

0 comments on commit f4162ac

Please sign in to comment.