From 0cf1dc30eab8a840898626dcace5e3652e25a85b Mon Sep 17 00:00:00 2001
From: Benjamin Bossan
Date: Wed, 16 Aug 2023 14:30:43 +0200
Subject: [PATCH] Non-functional cleanups related to lr schedulers

While working on the fixes in this PR, I also cleaned up some lr scheduler
code. These cleanups are non-functional:

1. We imported CyclicLR as TorchCyclicLR inside a try/except block. This was
   a backwards-compatibility shim for very old PyTorch versions (>= 1.0,
   < 1.1) that we no longer support, so I removed the alias and import
   CyclicLR directly.
2. Fixed the indentation of some multi-line conditional checks to improve
   readability.
---
 skorch/callbacks/lr_scheduler.py            | 34 ++++++++++-----------
 skorch/tests/callbacks/test_lr_scheduler.py | 13 ++++----
 2 files changed, 23 insertions(+), 24 deletions(-)

diff --git a/skorch/callbacks/lr_scheduler.py b/skorch/callbacks/lr_scheduler.py
index 64934b598..17acb1bff 100644
--- a/skorch/callbacks/lr_scheduler.py
+++ b/skorch/callbacks/lr_scheduler.py
@@ -9,17 +9,12 @@
 import torch
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.optim.lr_scheduler import CosineAnnealingLR
+from torch.optim.lr_scheduler import CyclicLR
 from torch.optim.lr_scheduler import ExponentialLR
 from torch.optim.lr_scheduler import LambdaLR
 from torch.optim.lr_scheduler import MultiStepLR
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torch.optim.lr_scheduler import StepLR
-
-try:
-    from torch.optim.lr_scheduler import CyclicLR as TorchCyclicLR
-except ImportError:
-    # Backward compatibility with torch >= 1.0 && < 1.1
-    TorchCyclicLR = None
 from torch.optim.optimizer import Optimizer
 
 from skorch.callbacks import Callback
@@ -152,7 +147,7 @@ def _step(self, net, lr_scheduler, score=None):
            certain conditions.
 
         For more info on the latter, see:
-        https://huggingface.co/docs/accelerate/v0.21.0/en/quicktour#mixed-precision-training
+        https://huggingface.co/docs/accelerate/quicktour#mixed-precision-training
 
         """
         accelerator_maybe = getattr(net, 'accelerator', None)
@@ -187,19 +182,22 @@ def on_epoch_end(self, net, **kwargs):
             self._step(net, self.lr_scheduler_, score=score)
             # ReduceLROnPlateau does not expose the current lr so it can't be recorded
         else:
-            if self.event_name is not None and hasattr(
-                    self.lr_scheduler_, "get_last_lr"):
-                net.history.record(self.event_name,
-                                   self.lr_scheduler_.get_last_lr()[0])
+            if (
+                (self.event_name is not None)
+                and hasattr(self.lr_scheduler_, "get_last_lr")
+            ):
+                net.history.record(self.event_name, self.lr_scheduler_.get_last_lr()[0])
             self._step(net, self.lr_scheduler_)
 
     def on_batch_end(self, net, training, **kwargs):
         if not training or self.step_every != 'batch':
             return
-        if self.event_name is not None and hasattr(
-                self.lr_scheduler_, "get_last_lr"):
-            net.history.record_batch(self.event_name,
-                                     self.lr_scheduler_.get_last_lr()[0])
+        if (
+            (self.event_name is not None)
+            and hasattr(self.lr_scheduler_, "get_last_lr")
+        ):
+            net.history.record_batch(
+                self.event_name, self.lr_scheduler_.get_last_lr()[0])
         self._step(net, self.lr_scheduler_)
         self.batch_idx_ += 1
 
@@ -207,8 +205,10 @@ def _get_scheduler(self, net, policy, **scheduler_kwargs):
         """Return scheduler, based on indicated policy, with appropriate
         parameters.
""" - if policy not in [ReduceLROnPlateau] and \ - 'last_epoch' not in scheduler_kwargs: + if ( + (policy not in [ReduceLROnPlateau]) + and ('last_epoch' not in scheduler_kwargs) + ): last_epoch = len(net.history) - 1 scheduler_kwargs['last_epoch'] = last_epoch diff --git a/skorch/tests/callbacks/test_lr_scheduler.py b/skorch/tests/callbacks/test_lr_scheduler.py index a4ff76216..44c96f248 100644 --- a/skorch/tests/callbacks/test_lr_scheduler.py +++ b/skorch/tests/callbacks/test_lr_scheduler.py @@ -3,7 +3,6 @@ import numpy as np import pytest -import torch from sklearn.base import clone from torch.optim import SGD from torch.optim.lr_scheduler import CosineAnnealingLR @@ -12,7 +11,7 @@ from torch.optim.lr_scheduler import MultiStepLR from torch.optim.lr_scheduler import ReduceLROnPlateau from torch.optim.lr_scheduler import StepLR -from torch.optim.lr_scheduler import CyclicLR as TorchCyclicLR +from torch.optim.lr_scheduler import CyclicLR from skorch import NeuralNetClassifier from skorch.callbacks.lr_scheduler import WarmRestartLR, LRScheduler @@ -28,7 +27,7 @@ def test_simulate_lrs_epoch_step(self, policy): expected = np.array([1.0, 1.0, 0.1, 0.1, 0.01, 0.01]) assert np.allclose(expected, lrs) - @pytest.mark.parametrize('policy', [TorchCyclicLR]) + @pytest.mark.parametrize('policy', [CyclicLR]) def test_simulate_lrs_batch_step(self, policy): lr_sch = LRScheduler( policy, base_lr=1, max_lr=5, step_size_up=4, step_every='batch') @@ -96,7 +95,7 @@ def test_lr_callback_steps_correctly( assert lr_policy.lr_scheduler_.last_epoch == max_epochs @pytest.mark.parametrize('policy, kwargs', [ - (TorchCyclicLR, {'base_lr': 1e-3, 'max_lr': 6e-3, 'step_every': 'batch'}), + (CyclicLR, {'base_lr': 1e-3, 'max_lr': 6e-3, 'step_every': 'batch'}), ]) def test_lr_callback_batch_steps_correctly( self, @@ -125,7 +124,7 @@ def test_lr_callback_batch_steps_correctly( assert lr_policy.batch_idx_ == expected @pytest.mark.parametrize('policy, kwargs', [ - (TorchCyclicLR, {'base_lr': 1e-3, 'max_lr': 6e-3, 'step_every': 'batch'}), + (CyclicLR, {'base_lr': 1e-3, 'max_lr': 6e-3, 'step_every': 'batch'}), ]) def test_lr_callback_batch_steps_correctly_fallback( self, @@ -177,7 +176,7 @@ def test_lr_scheduler_cloneable(self): def test_lr_scheduler_set_params(self, classifier_module, classifier_data): scheduler = LRScheduler( - TorchCyclicLR, base_lr=123, max_lr=999, step_every='batch') + CyclicLR, base_lr=123, max_lr=999, step_every='batch') net = NeuralNetClassifier( classifier_module, max_epochs=0, @@ -212,7 +211,7 @@ def test_lr_scheduler_record_batch_step(self, classifier_module, classifier_data batch_size = 128 scheduler = LRScheduler( - TorchCyclicLR, + CyclicLR, base_lr=1, max_lr=5, step_size_up=4,