From d3af255dffeb8b6eb3b5a7e7a373c2acf11b3f26 Mon Sep 17 00:00:00 2001
From: DimitriTimoz
Date: Thu, 7 Nov 2024 13:25:05 +0100
Subject: [PATCH 1/4] Add LRScheduler

---
 candle-nn/examples/basic_optimizer.rs | 10 +++++++-
 candle-nn/src/lib.rs                  |  2 ++
 candle-nn/src/scheduler.rs            | 37 +++++++++++++++++++++++++++
 3 files changed, 48 insertions(+), 1 deletion(-)
 create mode 100644 candle-nn/src/scheduler.rs

diff --git a/candle-nn/examples/basic_optimizer.rs b/candle-nn/examples/basic_optimizer.rs
index 810f7a7a75..d16da3b2d5 100644
--- a/candle-nn/examples/basic_optimizer.rs
+++ b/candle-nn/examples/basic_optimizer.rs
@@ -5,7 +5,10 @@ extern crate intel_mkl_src;
 extern crate accelerate_src;
 
 use candle::{DType, Device, Result, Tensor};
-use candle_nn::{linear, AdamW, Linear, Module, Optimizer, ParamsAdamW, VarBuilder, VarMap};
+use candle_nn::{
+    linear, AdamW, FnLRScheduler, LRScheduler, Linear, Module, Optimizer, ParamsAdamW, VarBuilder,
+    VarMap,
+};
 
 fn gen_data() -> Result<(Tensor, Tensor)> {
     // Generate some sample linear data.
@@ -29,7 +32,12 @@ fn main() -> Result<()> {
         ..Default::default()
     };
     let mut opt = AdamW::new(varmap.all_vars(), params)?;
+    let mut scheduler = FnLRScheduler::<usize>::new(Box::new(|step| {
+        Ok(0.2 * 0.9f64.powi((step as f64 / 1000f64).floor() as i32))
+    }));
+
     for step in 0..10000 {
+        opt.set_learning_rate(scheduler.step(step)?);
         let ys = model.forward(&sample_xs)?;
         let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
         opt.backward_step(&loss)?;
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index fcac58308c..8880f79420 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -12,6 +12,7 @@ pub mod linear;
 pub mod loss;
 pub mod ops;
 pub mod optim;
+pub mod scheduler;
 pub mod rnn;
 pub mod rotary_emb;
 pub mod sequential;
@@ -33,6 +34,7 @@ pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
 pub use linear::{linear, linear_b, linear_no_bias, Linear};
 pub use ops::Dropout;
 pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD};
+pub use scheduler::{FnLRScheduler, LRScheduler};
 pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN};
 pub use sequential::{seq, Sequential};
 pub use var_builder::VarBuilder;
diff --git a/candle-nn/src/scheduler.rs b/candle-nn/src/scheduler.rs
new file mode 100644
index 0000000000..6e7bda4c10
--- /dev/null
+++ b/candle-nn/src/scheduler.rs
@@ -0,0 +1,37 @@
+use candle::Result;
+
+/// The interface LR Schedulers should implement.
+pub trait LRScheduler<T> {
+    /// Step the scheduler and return the new learning rate.
+    fn step(&mut self, params: T) -> Result<f64>;
+
+    /// Get the current learning rate.
+    fn get_lr(&self) -> f64;
+}
+
+/// A learning rate scheduler that uses a function to determine the learning rate.
+/// The function should take a parameter of type `T` and return a `f64`.
+pub struct FnLRScheduler<T> {
+    pub func: Box<dyn Fn(T) -> Result<f64>>,
+    pub lr: f64,
+}
+
+impl<T> FnLRScheduler<T> {
+    pub fn new(func: Box<dyn Fn(T) -> Result<f64>>) -> Self {
+        Self {
+            func,
+            lr: 0.0,
+        }
+    }
+}
+
+impl<T> LRScheduler<T> for FnLRScheduler<T> {
+    fn step(&mut self, params: T) -> Result<f64> {
+        self.lr = (self.func)(params)?;
+        Ok(self.lr)
+    }
+
+    fn get_lr(&self) -> f64 {
+        self.lr
+    }
+}

From 0638745008ef702ddeb855b77be8c78b415a46e9 Mon Sep 17 00:00:00 2001
From: DimitriTimoz
Date: Thu, 7 Nov 2024 13:42:57 +0100
Subject: [PATCH 2/4] Implement StepLR

---
 candle-nn/src/scheduler.rs | 39 ++++++++++++++++++++++++++++++++++----
 1 file changed, 35 insertions(+), 4 deletions(-)

diff --git a/candle-nn/src/scheduler.rs b/candle-nn/src/scheduler.rs
index 6e7bda4c10..795a0b468f 100644
--- a/candle-nn/src/scheduler.rs
+++ b/candle-nn/src/scheduler.rs
@@ -18,10 +18,7 @@ pub struct FnLRScheduler<T> {
 
 impl<T> FnLRScheduler<T> {
     pub fn new(func: Box<dyn Fn(T) -> Result<f64>>) -> Self {
-        Self {
-            func,
-            lr: 0.0,
-        }
+        Self { func, lr: 0.0 }
     }
 }
 
@@ -35,3 +32,37 @@ impl<T> LRScheduler<T> for FnLRScheduler<T> {
         self.lr
     }
 }
+
+/// Decays the learning rate of each parameter group by gamma every step_size epochs.
+// https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html#torch.optim.lr_scheduler.StepLR
+pub struct StepLR {
+    step_size: usize,
+    last_epoch: usize,
+    gamma: f64,
+    lr: f64,
+}
+
+impl StepLR {
+    pub fn new(step_size: usize, gamma: f64, lr: f64) -> Self {
+        Self {
+            step_size,
+            last_epoch: 0,
+            gamma,
+            lr,
+        }
+    }
+}
+
+impl LRScheduler<()> for StepLR {
+    fn step(&mut self, _params: ()) -> Result<f64> {
+        self.last_epoch += 1;
+        if self.last_epoch % self.step_size == 0 {
+            self.lr *= self.gamma;
+        }
+        Ok(self.lr)
+    }
+
+    fn get_lr(&self) -> f64 {
+        self.lr
+    }
+}

From b06d8bb62f362bc3c6b16d8474891265ea89aa81 Mon Sep 17 00:00:00 2001
From: DimitriTimoz
Date: Thu, 7 Nov 2024 13:54:09 +0100
Subject: [PATCH 3/4] MultiStepLR

---
 candle-nn/src/scheduler.rs | 41 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 41 insertions(+)

diff --git a/candle-nn/src/scheduler.rs b/candle-nn/src/scheduler.rs
index 795a0b468f..361a2bf673 100644
--- a/candle-nn/src/scheduler.rs
+++ b/candle-nn/src/scheduler.rs
@@ -66,3 +66,44 @@ impl LRScheduler<()> for StepLR {
         self.lr
     }
 }
+/// Decays the learning rate of each parameter group by gamma once the number of epochs reaches one of the milestones.
+// https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.MultiStepLR.html#torch.optim.lr_scheduler.MultiStepLR
+pub struct MultiStepLR {
+    milestones: Vec<usize>,
+    gamma: f64,
+    last_epoch: usize,
+    lr: f64,
+}
+
+impl MultiStepLR {
+    pub fn new(milestones: Vec<usize>, gamma: f64, lr: f64) -> Result<Self> {
+        // Ensure milestones are sorted.
+        if !milestones.is_sorted() {
+            candle::bail!("milestones should be sorted")
+        }
+
+        Ok(Self {
+            milestones,
+            gamma,
+            last_epoch: 0,
+            lr,
+        })
+    }
+}
+
+impl LRScheduler<()> for MultiStepLR {
+    fn step(&mut self, _params: ()) -> Result<f64> {
+        self.last_epoch += 1;
+        if let Some(step) = self.milestones.first() {
+            if self.last_epoch == *step {
+                self.milestones.remove(0);
+                self.lr *= self.gamma;
+            }
+        }
+        Ok(self.lr)
+    }
+
+    fn get_lr(&self) -> f64 {
+        self.lr
+    }
+}

From 874a6a482961841ac58065ef2a26629d22f99478 Mon Sep 17 00:00:00 2001
From: DimitriTimoz
Date: Thu, 7 Nov 2024 14:03:41 +0100
Subject: [PATCH 4/4] Add CosineAnnealingLR

---
 candle-nn/src/scheduler.rs | 38 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 38 insertions(+)

diff --git a/candle-nn/src/scheduler.rs b/candle-nn/src/scheduler.rs
index 361a2bf673..a46ebb818d 100644
--- a/candle-nn/src/scheduler.rs
+++ b/candle-nn/src/scheduler.rs
@@ -107,3 +107,41 @@ impl LRScheduler<()> for MultiStepLR {
         self.lr
     }
 }
+
+/// Set the learning rate of each parameter group using a cosine annealing schedule.
+// https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingLR.html#torch.optim.lr_scheduler.CosineAnnealingLR
+pub struct CosineAnnealingLR {
+    t_max: usize,
+    last_epoch: usize,
+    eta_min: f64,
+    eta_max: f64,
+    lr: f64,
+}
+
+impl CosineAnnealingLR {
+    pub fn new(t_max: usize, eta_min: f64, lr: f64) -> Self {
+        Self {
+            t_max,
+            last_epoch: 0,
+            eta_min,
+            eta_max: lr,
+            lr,
+        }
+    }
+}
+
+impl LRScheduler<()> for CosineAnnealingLR {
+    fn step(&mut self, _params: ()) -> Result<f64> {
+        self.lr = self.eta_min
+            + 0.5
+                * (self.eta_max - self.eta_min)
+                * (1. + (self.last_epoch as f64 / self.t_max as f64 * std::f64::consts::PI).cos());
+        self.last_epoch += 1;
+        self.last_epoch = self.last_epoch.min(self.t_max);
+        Ok(self.lr)
+    }
+
+    fn get_lr(&self) -> f64 {
+        self.lr
+    }
+}
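
Note (illustrative sketch, not part of the patches above): a minimal usage example for the StepLR scheduler added in PATCH 2, wired into the same training-loop shape as the basic_optimizer.rs example touched by PATCH 1. This series only re-exports FnLRScheduler and LRScheduler at the crate root, so StepLR is reached through the scheduler module. The toy data, model size and hyper-parameters (lr 0.1, gamma 0.9, step_size 1000) are assumptions made for the example, not values taken from the patches.

use candle::{DType, Device, Result, Tensor};
use candle_nn::scheduler::{LRScheduler, StepLR};
use candle_nn::{linear, AdamW, Module, Optimizer, ParamsAdamW, VarBuilder, VarMap};

fn main() -> Result<()> {
    // Tiny regression problem: y = 3x + 1 on four points.
    let sample_xs = Tensor::new(&[[0f32], [1.], [2.], [3.]], &Device::Cpu)?;
    let sample_ys = Tensor::new(&[[1f32], [4.], [7.], [10.]], &Device::Cpu)?;

    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
    let model = linear(1, 1, vb.pp("linear"))?;
    let params = ParamsAdamW {
        lr: 0.1,
        ..Default::default()
    };
    let mut opt = AdamW::new(varmap.all_vars(), params)?;

    // Start at lr = 0.1 and multiply it by 0.9 every 1000 optimizer steps.
    let mut scheduler = StepLR::new(1000, 0.9, 0.1);
    for _step in 0..10000 {
        // StepLR is stepped with `()`: it only advances its own internal epoch counter.
        opt.set_learning_rate(scheduler.step(())?);
        let ys = model.forward(&sample_xs)?;
        let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
        opt.backward_step(&loss)?;
    }
    println!("final lr: {}", scheduler.get_lr());
    Ok(())
}

Unlike FnLRScheduler, which is stepped with whatever parameter type its closure expects (a usize step count in the example patch), the fixed-policy schedulers implement LRScheduler<()> and keep their own counters.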
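
A second sketch, also illustrative rather than part of the series: stepping MultiStepLR (PATCH 3) and CosineAnnealingLR (PATCH 4) on their own and printing the schedule is a quick way to check that the drops land on the chosen milestones and that the cosine curve follows lr_t = eta_min + 0.5 * (eta_max - eta_min) * (1 + cos(pi * t / t_max)). The milestones (30, 80), gamma 0.1, t_max 10 and base lr 0.1 below are arbitrary example values.

use candle::Result;
use candle_nn::scheduler::{CosineAnnealingLR, LRScheduler, MultiStepLR};

fn main() -> Result<()> {
    // Anneal from a base lr of 0.1 towards eta_min = 0.0 over t_max = 10 steps.
    let mut cosine = CosineAnnealingLR::new(10, 0.0, 0.1);
    for call in 0..10 {
        println!("cosine call {call}: lr = {:.5}", cosine.step(())?);
    }

    // Multiply the lr by 0.1 when the epoch counter reaches 30, and again at 80.
    let mut multi = MultiStepLR::new(vec![30, 80], 0.1, 0.1)?;
    for epoch in 1..=100 {
        let lr = multi.step(())?;
        if epoch == 30 || epoch == 80 {
            println!("epoch {epoch}: lr dropped to {lr}");
        }
    }
    Ok(())
}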