From f4162acb55a3ec6bdf45aef3ec08811f383a3992 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 3 Sep 2024 14:07:15 +0200 Subject: [PATCH] Reviewer feedback: rename function, fix typo --- skorch/net.py | 4 ++-- skorch/tests/test_net.py | 13 +++++++------ skorch/utils.py | 4 ++-- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/skorch/net.py b/skorch/net.py index 05e5ec88..6b7748be 100644 --- a/skorch/net.py +++ b/skorch/net.py @@ -46,7 +46,7 @@ from skorch.utils import to_device from skorch.utils import to_numpy from skorch.utils import to_tensor -from skorch.utils import get_torch_load_kwargs +from skorch.utils import get_default_torch_load_kwargs # pylint: disable=too-many-instance-attributes @@ -2652,7 +2652,7 @@ def _get_state_dict(f_name): else: torch_load_kwargs = self.torch_load_kwargs if torch_load_kwargs is None: - torch_load_kwargs = get_torch_load_kwargs() + torch_load_kwargs = get_default_torch_load_kwargs() def _get_state_dict(f_name): map_location = get_map_location(self.device) diff --git a/skorch/tests/test_net.py b/skorch/tests/test_net.py index b4f056c7..768ef00b 100644 --- a/skorch/tests/test_net.py +++ b/skorch/tests/test_net.py @@ -3011,7 +3011,7 @@ def test_torch_load_kwargs_auto_weights_only_false_when_load_params( mock_torch_load = Mock(return_value=state_dict) monkeypatch.setattr(torch, "load", mock_torch_load) monkeypatch.setattr( - skorch.net, "get_torch_load_kwargs", lambda: expected_kwargs + skorch.net, "get_default_torch_load_kwargs", lambda: expected_kwargs ) net.load_params(f_params=tmp_path / 'params.pkl') @@ -3035,7 +3035,7 @@ def test_torch_load_kwargs_auto_weights_only_true_when_load_params( mock_torch_load = Mock(return_value=state_dict) monkeypatch.setattr(torch, "load", mock_torch_load) monkeypatch.setattr( - skorch.net, "get_torch_load_kwargs", lambda: expected_kwargs + skorch.net, "get_default_torch_load_kwargs", lambda: expected_kwargs ) net.load_params(f_params=tmp_path / 'params.pkl') @@ -3067,10 +3067,11 @@ def test_torch_load_kwargs_forwarded_to_torch_load( def test_torch_load_kwargs_auto_weights_false_pytorch_lt_2_6( self, net_cls, module_cls, monkeypatch, tmp_path ): - # Same test as test_torch_load_kwargs_auto_weights_only_false_when_load_params - # but without monkeypatching get_torch_load_kwargs. There is no corresponding - # test for >= 2.6.0 since it's not clear yet if the switch will be made in that - # version. + # Same test as + # test_torch_load_kwargs_auto_weights_only_false_when_load_params but + # without monkeypatching get_default_torch_load_kwargs. There is no + # corresponding test for >= 2.6.0 since it's not clear yet if the switch + # will be made in that version. # See discussion in 1063. from skorch._version import Version diff --git a/skorch/utils.py b/skorch/utils.py index b1320c2a..851936db 100644 --- a/skorch/utils.py +++ b/skorch/utils.py @@ -771,8 +771,8 @@ def _check_f_arguments(caller_name, **kwargs): return kwargs_module, kwargs_other -def get_torch_load_kwargs(): - """Returns the kwargs passed to torch.load the correspond to the current +def get_default_torch_load_kwargs(): + """Returns the kwargs passed to torch.load that correspond to the current torch version. The plan is to switch from weights_only=False to True in PyTorch version