Non-functional clean ups related to lr schedulers
While working on the fixes in this PR, I also cleaned up some lr
scheduler code. These clean-ups are non-functional.

1. We imported CyclicLR as TorchCyclicLR inside a try/except block, a
   leftover for backward compatibility with very old PyTorch versions
   that we no longer support, so I removed the alias and now import
   CyclicLR directly (see the short usage sketch below).
2. Fixed the indentation of some multi-line conditional checks to
   improve readability.
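
For context, a minimal sketch of how the affected callback is typically wired up after this clean-up, with CyclicLR handed straight to skorch's LRScheduler. The toy module, data, and hyperparameters below are illustrative placeholders only, not part of this commit; the callback arguments mirror the tests further down.

import numpy as np
import torch
from torch.optim.lr_scheduler import CyclicLR

from skorch import NeuralNetClassifier
from skorch.callbacks.lr_scheduler import LRScheduler

# toy data and module, just to make the sketch self-contained
X = np.random.rand(128, 20).astype('float32')
y = np.random.randint(0, 2, size=128).astype('int64')

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(20, 2)

    def forward(self, X):
        return self.lin(X)  # raw logits, paired with CrossEntropyLoss below

# step the cyclic schedule after every training batch; the current lr is
# recorded in net.history under the callback's event_name
scheduler = LRScheduler(CyclicLR, base_lr=1e-3, max_lr=6e-3, step_every='batch')
net = NeuralNetClassifier(
    MyModule,
    criterion=torch.nn.CrossEntropyLoss,
    max_epochs=2,
    callbacks=[scheduler],
)
net.fit(X, y)

With step_every='batch' the scheduler is stepped once per training batch instead of once per epoch, which is exactly the code path touched by the indentation clean-up below.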
BenjaminBossan committed Aug 16, 2023
1 parent ba5f153 commit 0cf1dc3
Showing 2 changed files with 23 additions and 24 deletions.
34 changes: 17 additions & 17 deletions skorch/callbacks/lr_scheduler.py
@@ -9,17 +9,12 @@
 import torch
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.optim.lr_scheduler import CosineAnnealingLR
+from torch.optim.lr_scheduler import CyclicLR
 from torch.optim.lr_scheduler import ExponentialLR
 from torch.optim.lr_scheduler import LambdaLR
 from torch.optim.lr_scheduler import MultiStepLR
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torch.optim.lr_scheduler import StepLR
-
-try:
-    from torch.optim.lr_scheduler import CyclicLR as TorchCyclicLR
-except ImportError:
-    # Backward compatibility with torch >= 1.0 && < 1.1
-    TorchCyclicLR = None
 from torch.optim.optimizer import Optimizer
 from skorch.callbacks import Callback

@@ -152,7 +147,7 @@ def _step(self, net, lr_scheduler, score=None):
            certain conditions.
         For more info on the latter, see:
-        https://huggingface.co/docs/accelerate/v0.21.0/en/quicktour#mixed-precision-training
+        https://huggingface.co/docs/accelerate/quicktour#mixed-precision-training
         """
         accelerator_maybe = getattr(net, 'accelerator', None)

@@ -187,28 +182,33 @@ def on_epoch_end(self, net, **kwargs):
             self._step(net, self.lr_scheduler_, score=score)
             # ReduceLROnPlateau does not expose the current lr so it can't be recorded
         else:
-            if self.event_name is not None and hasattr(
-                    self.lr_scheduler_, "get_last_lr"):
-                net.history.record(self.event_name,
-                                   self.lr_scheduler_.get_last_lr()[0])
+            if (
+                (self.event_name is not None)
+                and hasattr(self.lr_scheduler_, "get_last_lr")
+            ):
+                net.history.record(self.event_name, self.lr_scheduler_.get_last_lr()[0])
             self._step(net, self.lr_scheduler_)

     def on_batch_end(self, net, training, **kwargs):
         if not training or self.step_every != 'batch':
             return
-        if self.event_name is not None and hasattr(
-                self.lr_scheduler_, "get_last_lr"):
-            net.history.record_batch(self.event_name,
-                                      self.lr_scheduler_.get_last_lr()[0])
+        if (
+            (self.event_name is not None)
+            and hasattr(self.lr_scheduler_, "get_last_lr")
+        ):
+            net.history.record_batch(
+                self.event_name, self.lr_scheduler_.get_last_lr()[0])
         self._step(net, self.lr_scheduler_)
         self.batch_idx_ += 1

     def _get_scheduler(self, net, policy, **scheduler_kwargs):
         """Return scheduler, based on indicated policy, with appropriate
         parameters.
         """
-        if policy not in [ReduceLROnPlateau] and \
-           'last_epoch' not in scheduler_kwargs:
+        if (
+            (policy not in [ReduceLROnPlateau])
+            and ('last_epoch' not in scheduler_kwargs)
+        ):
             last_epoch = len(net.history) - 1
             scheduler_kwargs['last_epoch'] = last_epoch

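An aside on the record_batch call above: with step_every='batch', the current learning rate lands in the batch rows of net.history under the callback's event_name (which defaults to 'event_lr', if memory serves). A hedged sketch of reading it back from a net fitted as in the earlier example:

# per-batch learning rates recorded during the last epoch; batch rows that
# lack the key (e.g. validation batches) should be skipped by the history
# slicing -- 'event_lr' is assumed to be the default event_name
batch_lrs = net.history[-1, 'batches', :, 'event_lr']
print(batch_lrs)
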
13 changes: 6 additions & 7 deletions skorch/tests/callbacks/test_lr_scheduler.py
@@ -3,7 +3,6 @@

 import numpy as np
 import pytest
-import torch
 from sklearn.base import clone
 from torch.optim import SGD
 from torch.optim.lr_scheduler import CosineAnnealingLR
@@ -12,7 +11,7 @@
 from torch.optim.lr_scheduler import MultiStepLR
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torch.optim.lr_scheduler import StepLR
-from torch.optim.lr_scheduler import CyclicLR as TorchCyclicLR
+from torch.optim.lr_scheduler import CyclicLR

 from skorch import NeuralNetClassifier
 from skorch.callbacks.lr_scheduler import WarmRestartLR, LRScheduler
@@ -28,7 +27,7 @@ def test_simulate_lrs_epoch_step(self, policy):
         expected = np.array([1.0, 1.0, 0.1, 0.1, 0.01, 0.01])
         assert np.allclose(expected, lrs)

-    @pytest.mark.parametrize('policy', [TorchCyclicLR])
+    @pytest.mark.parametrize('policy', [CyclicLR])
     def test_simulate_lrs_batch_step(self, policy):
         lr_sch = LRScheduler(
             policy, base_lr=1, max_lr=5, step_size_up=4, step_every='batch')
@@ -96,7 +95,7 @@ def test_lr_callback_steps_correctly(
         assert lr_policy.lr_scheduler_.last_epoch == max_epochs

     @pytest.mark.parametrize('policy, kwargs', [
-        (TorchCyclicLR, {'base_lr': 1e-3, 'max_lr': 6e-3, 'step_every': 'batch'}),
+        (CyclicLR, {'base_lr': 1e-3, 'max_lr': 6e-3, 'step_every': 'batch'}),
     ])
     def test_lr_callback_batch_steps_correctly(
             self,
@@ -125,7 +124,7 @@ def test_lr_callback_batch_steps_correctly(
         assert lr_policy.batch_idx_ == expected

     @pytest.mark.parametrize('policy, kwargs', [
-        (TorchCyclicLR, {'base_lr': 1e-3, 'max_lr': 6e-3, 'step_every': 'batch'}),
+        (CyclicLR, {'base_lr': 1e-3, 'max_lr': 6e-3, 'step_every': 'batch'}),
     ])
     def test_lr_callback_batch_steps_correctly_fallback(
             self,
@@ -177,7 +176,7 @@ def test_lr_scheduler_cloneable(self):

     def test_lr_scheduler_set_params(self, classifier_module, classifier_data):
         scheduler = LRScheduler(
-            TorchCyclicLR, base_lr=123, max_lr=999, step_every='batch')
+            CyclicLR, base_lr=123, max_lr=999, step_every='batch')
         net = NeuralNetClassifier(
             classifier_module,
             max_epochs=0,
@@ -212,7 +211,7 @@ def test_lr_scheduler_record_batch_step(self, classifier_module, classifier_data
         batch_size = 128

         scheduler = LRScheduler(
-            TorchCyclicLR,
+            CyclicLR,
             base_lr=1,
             max_lr=5,
             step_size_up=4,
