From b6276b71aad3dff65b76dae4c75e521cd47b5577 Mon Sep 17 00:00:00 2001 From: kigi Date: Sat, 5 Oct 2024 13:06:24 +0800 Subject: [PATCH 1/6] Implements bidirectional RNN --- candle-nn/src/rnn.rs | 87 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 80 insertions(+), 7 deletions(-) diff --git a/candle-nn/src/rnn.rs b/candle-nn/src/rnn.rs index 798db6ac4d..202fe76e30 100644 --- a/candle-nn/src/rnn.rs +++ b/candle-nn/src/rnn.rs @@ -6,6 +6,9 @@ use candle::{DType, Device, IndexOp, Result, Tensor}; pub trait RNN { type State: Clone; + /// Returns the direction of the RNN. + fn direction(&self) -> Direction; + /// A zero state from which the recurrent network is usually initialized. fn zero_state(&self, batch_dim: usize) -> Result; @@ -31,7 +34,12 @@ pub trait RNN { let (_b_size, seq_len, _features) = input.dims3()?; let mut output = Vec::with_capacity(seq_len); for seq_index in 0..seq_len { - let input = input.i((.., seq_index, ..))?.contiguous()?; + let index = if self.direction() == Direction::Forward { + seq_index + } else { + seq_len - seq_index - 1 + }; + let input = input.i((.., index, ..))?.contiguous()?; let state = if seq_index == 0 { self.step(&input, init_state)? } else { @@ -39,11 +47,21 @@ pub trait RNN { }; output.push(state); } + if self.direction() == Direction::Backward { + output.reverse(); + } Ok(output) } /// Converts a sequence of state to a tensor. fn states_to_tensor(&self, states: &[Self::State]) -> Result; + + /// Combines forward and backward states to a tensor. + fn combine_states_to_tensor( + &self, + forward_states: &[Self::State], + backward_states: &[Self::State], + ) -> Result; } /// The state for a LSTM network, this contains two tensors. @@ -70,7 +88,7 @@ impl LSTMState { } } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq)] pub enum Direction { Forward, Backward, @@ -198,6 +216,10 @@ pub fn lstm( impl RNN for LSTM { type State = LSTMState; + fn direction(&self) -> Direction { + self.config.direction + } + fn zero_state(&self, batch_dim: usize) -> Result { let zeros = Tensor::zeros((batch_dim, self.hidden_dim), self.dtype, &self.device)?.contiguous()?; @@ -236,6 +258,22 @@ impl RNN for LSTM { let states = states.iter().map(|s| s.h.clone()).collect::>(); Tensor::stack(&states, 1) } + + fn combine_states_to_tensor( + &self, + forward_states: &[Self::State], + backward_states: &[Self::State], + ) -> Result { + let combine_states = forward_states + .iter() + .zip(backward_states.iter()) + .collect::>(); + let mut states = Vec::with_capacity(combine_states.len()); + for (f, b) in combine_states { + states.push(Tensor::cat(&[&f.h, &b.h], 1)?); + } + Tensor::stack(&states, 1) + } } /// The state for a GRU network, this contains a single tensor. 
@@ -259,6 +297,7 @@ pub struct GRUConfig { pub w_hh_init: super::Init, pub b_ih_init: Option, pub b_hh_init: Option, + pub direction: Direction, } impl Default for GRUConfig { @@ -268,6 +307,7 @@ impl Default for GRUConfig { w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM, b_ih_init: Some(super::Init::Const(0.)), b_hh_init: Some(super::Init::Const(0.)), + direction: Direction::Forward, } } } @@ -279,6 +319,7 @@ impl GRUConfig { w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM, b_ih_init: None, b_hh_init: None, + direction: Direction::Forward, } } } @@ -307,22 +348,34 @@ impl GRU { config: GRUConfig, vb: crate::VarBuilder, ) -> Result { + let direction_str = match config.direction { + Direction::Forward => "", + Direction::Backward => "_reverse", + }; let w_ih = vb.get_with_hints( (3 * hidden_dim, in_dim), - "weight_ih_l0", // Only a single layer is supported. + &format!("weight_ih_l0{direction_str}"), // Only a single layer is supported. config.w_ih_init, )?; let w_hh = vb.get_with_hints( (3 * hidden_dim, hidden_dim), - "weight_hh_l0", // Only a single layer is supported. + &format!("weight_hh_l0{direction_str}"), // Only a single layer is supported. config.w_hh_init, )?; let b_ih = match config.b_ih_init { - Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_ih_l0", init)?), + Some(init) => Some(vb.get_with_hints( + 3 * hidden_dim, + &format!("bias_ih_l0{direction_str}"), + init, + )?), None => None, }; let b_hh = match config.b_hh_init { - Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_hh_l0", init)?), + Some(init) => Some(vb.get_with_hints( + 3 * hidden_dim, + &format!("bias_hh_l0{direction_str}"), + init, + )?), None => None, }; Ok(Self { @@ -354,6 +407,10 @@ pub fn gru( impl RNN for GRU { type State = GRUState; + fn direction(&self) -> Direction { + self.config.direction + } + fn zero_state(&self, batch_dim: usize) -> Result { let h = Tensor::zeros((batch_dim, self.hidden_dim), self.dtype, &self.device)?.contiguous()?; @@ -383,6 +440,22 @@ impl RNN for GRU { fn states_to_tensor(&self, states: &[Self::State]) -> Result { let states = states.iter().map(|s| s.h.clone()).collect::>(); - Tensor::cat(&states, 1) + Tensor::stack(&states, 1) + } + + fn combine_states_to_tensor( + &self, + forward_states: &[Self::State], + backward_states: &[Self::State], + ) -> Result { + let combine_states = forward_states + .iter() + .zip(backward_states.iter()) + .collect::>(); + let mut states = Vec::with_capacity(combine_states.len()); + for (f, b) in combine_states { + states.push(Tensor::cat(&[&f.h, &b.h], 1)?); + } + Tensor::stack(&states, 1) } } From 2a3bec7151743215ee244763a09fe0d620e5fe51 Mon Sep 17 00:00:00 2001 From: kigi Date: Tue, 8 Oct 2024 21:19:32 +0800 Subject: [PATCH 2/6] GRU support multi-layer --- candle-nn/src/rnn.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/candle-nn/src/rnn.rs b/candle-nn/src/rnn.rs index 202fe76e30..500ca08300 100644 --- a/candle-nn/src/rnn.rs +++ b/candle-nn/src/rnn.rs @@ -297,6 +297,7 @@ pub struct GRUConfig { pub w_hh_init: super::Init, pub b_ih_init: Option, pub b_hh_init: Option, + pub layer_idx: usize, pub direction: Direction, } @@ -307,6 +308,7 @@ impl Default for GRUConfig { w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM, b_ih_init: Some(super::Init::Const(0.)), b_hh_init: Some(super::Init::Const(0.)), + layer_idx: 0, direction: Direction::Forward, } } @@ -319,6 +321,7 @@ impl GRUConfig { w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM, b_ih_init: None, b_hh_init: None, + layer_idx: 0, direction: 
Direction::Forward, } } @@ -348,24 +351,25 @@ impl GRU { config: GRUConfig, vb: crate::VarBuilder, ) -> Result { + let layer_idx = config.layer_idx; let direction_str = match config.direction { Direction::Forward => "", Direction::Backward => "_reverse", }; let w_ih = vb.get_with_hints( (3 * hidden_dim, in_dim), - &format!("weight_ih_l0{direction_str}"), // Only a single layer is supported. + &format!("weight_ih_l{layer_idx}{direction_str}"), config.w_ih_init, )?; let w_hh = vb.get_with_hints( (3 * hidden_dim, hidden_dim), - &format!("weight_hh_l0{direction_str}"), // Only a single layer is supported. + &format!("weight_hh_l{layer_idx}{direction_str}"), config.w_hh_init, )?; let b_ih = match config.b_ih_init { Some(init) => Some(vb.get_with_hints( 3 * hidden_dim, - &format!("bias_ih_l0{direction_str}"), + &format!("bias_ih_l{layer_idx}{direction_str}"), init, )?), None => None, @@ -373,7 +377,7 @@ impl GRU { let b_hh = match config.b_hh_init { Some(init) => Some(vb.get_with_hints( 3 * hidden_dim, - &format!("bias_hh_l0{direction_str}"), + &format!("bias_hh_l{layer_idx}{direction_str}"), init, )?), None => None, From 0e3c8410c913d84c76afda67cdbd8318b987dd8d Mon Sep 17 00:00:00 2001 From: kigi Date: Tue, 8 Oct 2024 21:22:04 +0800 Subject: [PATCH 3/6] Example about RNN multi-layer and bidirection --- candle-examples/examples/rnn/README.md | 15 ++ candle-examples/examples/rnn/main.rs | 226 +++++++++++++++++++++++++ 2 files changed, 241 insertions(+) create mode 100644 candle-examples/examples/rnn/README.md create mode 100644 candle-examples/examples/rnn/main.rs diff --git a/candle-examples/examples/rnn/README.md b/candle-examples/examples/rnn/README.md new file mode 100644 index 0000000000..3f508bf005 --- /dev/null +++ b/candle-examples/examples/rnn/README.md @@ -0,0 +1,15 @@ +# candle-rnn: Recurrent Neural Network + +This example demonstrates how to use the `candle_nn::rnn` module to run LSTM and GRU networks, including bidirectional and multi-layer variants. + +## Running the example + +```bash +$ cargo run --example rnn --release +``` + +Choose LSTM or GRU via the `--model` argument, set the number of layers via `--layers`, and enable bidirectional processing via `--bidirection`. 
+ +```bash +$ cargo run --example rnn --release -- --model lstm --layers 3 --bidirection +``` diff --git a/candle-examples/examples/rnn/main.rs b/candle-examples/examples/rnn/main.rs new file mode 100644 index 0000000000..7998ce893d --- /dev/null +++ b/candle-examples/examples/rnn/main.rs @@ -0,0 +1,226 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use anyhow::Result; +use candle::{DType, Tensor}; +use candle_nn::{rnn, LSTMConfig, RNN}; +use clap::Parser; + +#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)] +enum WhichModel { + #[value(name = "lstm")] + LSTM, + #[value(name = "gru")] + GRU, +} + +#[derive(Debug, Parser)] +#[command(author, version, about, long_about = None)] +struct Args { + #[arg(long)] + cpu: bool, + + #[arg(long, default_value_t = 10)] + input_dim: usize, + + #[arg(long, default_value_t = 20)] + hidden_dim: usize, + + #[arg(long, default_value_t = 1)] + layers: usize, + + #[arg(long)] + bidirection: bool, + + #[arg(long, default_value_t = 5)] + batch_size: usize, + + #[arg(long, default_value_t = 3)] + seq_len: usize, + + #[arg(long, default_value = "lstm")] + model: WhichModel, +} + +fn lstm_config(layer_idx: usize, direction: rnn::Direction) -> LSTMConfig { + let mut config = LSTMConfig::default(); + config.layer_idx = layer_idx; + config.direction = direction; + config +} + +fn gru_config(layer_idx: usize, direction: rnn::Direction) -> rnn::GRUConfig { + let mut config = rnn::GRUConfig::default(); + config.layer_idx = layer_idx; + config.direction = direction; + config +} + +fn run_lstm(args: Args) -> Result { + let device = candle_examples::device(args.cpu)?; + let map = candle_nn::VarMap::new(); + let vb = candle_nn::VarBuilder::from_varmap(&map, DType::F32, &device); + + let mut layers = Vec::with_capacity(args.layers); + + for layer_idx in 0..args.layers { + let input_dim = if layer_idx == 0 { + args.input_dim + } else { + args.hidden_dim + }; + let config = lstm_config(layer_idx, rnn::Direction::Forward); + let lstm = candle_nn::lstm(input_dim, args.hidden_dim, config, vb.clone())?; + layers.push(lstm); + } + + let mut input = Tensor::randn( + 0.0_f32, + 1.0, + (args.batch_size, args.seq_len, args.input_dim), + &device, + )?; + + for layer in &layers { + let states = layer.seq(&input)?; + input = layer.states_to_tensor(&states)?; + } + + Ok(input) +} + +fn run_gru(args: Args) -> Result { + let device = candle_examples::device(args.cpu)?; + let map = candle_nn::VarMap::new(); + let vb = candle_nn::VarBuilder::from_varmap(&map, DType::F32, &device); + + let mut layers = Vec::with_capacity(args.layers); + + for layer_idx in 0..args.layers { + let input_dim = if layer_idx == 0 { + args.input_dim + } else { + args.hidden_dim + }; + let config = gru_config(layer_idx, rnn::Direction::Forward); + let gru = candle_nn::gru(input_dim, args.hidden_dim, config, vb.clone())?; + layers.push(gru); + } + + let mut input = Tensor::randn( + 0.0_f32, + 1.0, + (args.batch_size, args.seq_len, args.input_dim), + &device, + )?; + + for layer in &layers { + let states = layer.seq(&input)?; + input = layer.states_to_tensor(&states)?; + } + + Ok(input) +} + +fn run_bidirectional_lstm(args: Args) -> Result { + let device = candle_examples::device(args.cpu)?; + let map = candle_nn::VarMap::new(); + let vb = candle_nn::VarBuilder::from_varmap(&map, DType::F32, &device); + + let mut layers = Vec::with_capacity(args.layers); + + for layer_idx in 0..args.layers { + let input_dim = if layer_idx == 0 { + 
args.input_dim + } else { + args.hidden_dim * 2 + }; + + let forward_config = lstm_config(layer_idx, rnn::Direction::Forward); + let forward = candle_nn::lstm(input_dim, args.hidden_dim, forward_config, vb.clone())?; + + let backward_config = lstm_config(layer_idx, rnn::Direction::Backward); + let backward = candle_nn::lstm(input_dim, args.hidden_dim, backward_config, vb.clone())?; + + layers.push((forward, backward)); + } + + let mut input = Tensor::randn( + 0.0_f32, + 1.0, + (args.batch_size, args.seq_len, args.input_dim), + &device, + )?; + + for (forward, backward) in &layers { + let forward_states = forward.seq(&input)?; + let backward_states = backward.seq(&input)?; + input = forward.combine_states_to_tensor(&forward_states, &backward_states)?; + } + Ok(input) +} + +fn run_bidirectional_gru(args: Args) -> Result { + let device = candle_examples::device(args.cpu)?; + let map = candle_nn::VarMap::new(); + let vb = candle_nn::VarBuilder::from_varmap(&map, DType::F32, &device); + + let mut layers = Vec::with_capacity(args.layers); + for layer_idx in 0..args.layers { + let input_dim = if layer_idx == 0 { + args.input_dim + } else { + args.hidden_dim * 2 + }; + + let forward_config = gru_config(layer_idx, rnn::Direction::Forward); + let forward = candle_nn::gru(input_dim, args.hidden_dim, forward_config, vb.clone())?; + + let backward_config = gru_config(layer_idx, rnn::Direction::Backward); + let backward = candle_nn::gru(input_dim, args.hidden_dim, backward_config, vb.clone())?; + + layers.push((forward, backward)); + } + + let mut input = Tensor::randn( + 0.0_f32, + 1.0, + (args.batch_size, args.seq_len, args.input_dim), + &device, + )?; + + for (forward, backward) in &layers { + let forward_states = forward.seq(&input)?; + let backward_states = backward.seq(&input)?; + input = forward.combine_states_to_tensor(&forward_states, &backward_states)?; + } + + Ok(input) +} + +fn main() -> Result<()> { + let args = Args::parse(); + let runs = if args.bidirection { 2 } else { 1 }; + let batch_size = args.batch_size; + let seq_len = args.seq_len; + let hidden_dim = args.hidden_dim; + + println!( + "Running {:?} bidirection: {} layers: {}", + args.model, args.bidirection, args.layers + ); + + let output = match (args.model, args.bidirection) { + (WhichModel::LSTM, false) => run_lstm(args), + (WhichModel::GRU, false) => run_gru(args), + (WhichModel::LSTM, true) => run_bidirectional_lstm(args), + (WhichModel::GRU, true) => run_bidirectional_gru(args), + }?; + + assert_eq!(output.dims3()?, (batch_size, seq_len, hidden_dim * runs)); + + Ok(()) +} From d16a45e89308b45261c8f9e2a15d9090adfb0f60 Mon Sep 17 00:00:00 2001 From: kigi Date: Fri, 11 Oct 2024 16:37:38 +0800 Subject: [PATCH 4/6] refactor function name --- candle-nn/src/rnn.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/candle-nn/src/rnn.rs b/candle-nn/src/rnn.rs index 500ca08300..e3f3e1bacb 100644 --- a/candle-nn/src/rnn.rs +++ b/candle-nn/src/rnn.rs @@ -57,7 +57,7 @@ pub trait RNN { fn states_to_tensor(&self, states: &[Self::State]) -> Result; /// Combines forward and backward states to a tensor. 
- fn combine_states_to_tensor( + fn bidirectional_states_to_tensor( &self, forward_states: &[Self::State], backward_states: &[Self::State], @@ -259,7 +259,7 @@ impl RNN for LSTM { Tensor::stack(&states, 1) } - fn combine_states_to_tensor( + fn bidirectional_states_to_tensor( &self, forward_states: &[Self::State], backward_states: &[Self::State], @@ -447,7 +447,7 @@ impl RNN for GRU { Tensor::stack(&states, 1) } - fn combine_states_to_tensor( + fn bidirectional_states_to_tensor( &self, forward_states: &[Self::State], backward_states: &[Self::State], From 54d881943a62f0ecdc785e74531cd2e5c42586a4 Mon Sep 17 00:00:00 2001 From: kigi Date: Fri, 11 Oct 2024 16:41:32 +0800 Subject: [PATCH 5/6] add test for example to confirm that the results are similar to pytorch --- candle-examples/examples/rnn/README.md | 152 +++++++++++ candle-examples/examples/rnn/main.rs | 341 +++++++++++++++++++------ 2 files changed, 410 insertions(+), 83 deletions(-) diff --git a/candle-examples/examples/rnn/README.md b/candle-examples/examples/rnn/README.md index 3f508bf005..567c9a96b2 100644 --- a/candle-examples/examples/rnn/README.md +++ b/candle-examples/examples/rnn/README.md @@ -13,3 +13,155 @@ Choose LSTM or GRU via the `--model` argument, number of layers via `--layer`, a ```bash $ cargo run --example rnn --release -- --model lstm --layers 3 --bidirection ``` + +## Running the example test + +Add argument `--test` to run test of this example. + +```bash +$ cargo run --example rnn --release -- --test +``` + +Test models are generated by Pytorch [LSTM](https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html) and [GRU](https://pytorch.org/docs/stable/generated/torch.nn.GRU.html). These models include input and output tensors and can be downloaded from [here](https://huggingface.co/kigichang/test_rnn). + +Test models are generated by the following codes: + +- lstm_test.pt: A simple LSTM model with 1 layer. + + ```python + import torch + import torch.nn as nn + + rnn = nn.LSTM(10, 20, num_layers=1, batch_first=True) + input = torch.randn(5, 3, 10) + output, (hn, cn) = rnn(input) + + state_dict = rnn.state_dict() + state_dict['input'] = input + state_dict['output'] = output.contiguous() + state_dict['hn'] = hn + state_dict['cn'] = cn + torch.save(state_dict, "lstm_test.pt") + ``` + +- gru_test.pt: A simple GRU model with 1 layer. + + ```python + import torch + import torch.nn as nn + + rnn = nn.GRU(10, 20, num_layers=1, batch_first=True) + input = torch.randn(5, 3, 10) + output, hn = rnn(input) + + state_dict = rnn.state_dict() + state_dict['input'] = input + state_dict['output'] = output.contiguous() + state_dict['hn'] = hn + torch.save(state_dict, "gru_test.pt") + ``` + +- bi_lstm_test.pt: A bidirectional LSTM model with 1 layer. + + ```python + import torch + import torch.nn as nn + + rnn = nn.LSTM(10, 20, num_layers=1, bidirectional=True, batch_first=True) + input = torch.randn(5, 3, 10) + output, (hn, cn) = rnn(input) + + state_dict = rnn.state_dict() + state_dict['input'] = input + state_dict['output'] = output.contiguous() + state_dict['hn'] = hn + state_dict['cn'] = cn + torch.save(state_dict, "bi_lstm_test.pt") + ``` + +- bi_gru_test.pt: A bidirectional GRU model with 1 layer. 
+ + ```python + import torch + import torch.nn as nn + + rnn = nn.GRU(10, 20, num_layers=1, bidirectional=True, batch_first=True) + input = torch.randn(5, 3, 10) + output, hn = rnn(input) + + state_dict = rnn.state_dict() + state_dict['input'] = input + state_dict['output'] = output.contiguous() + state_dict['hn'] = hn + torch.save(state_dict, "bi_gru_test.pt") + ``` + +- lstm_nlayer_test.pt: A LSTM model with 3 layers. + + ```python + import torch + import torch.nn as nn + + rnn = nn.LSTM(10, 20, num_layers=3, batch_first=True) + input = torch.randn(5, 3, 10) + output, (hn, cn) = rnn(input) + + state_dict = rnn.state_dict() + state_dict['input'] = input + state_dict['output'] = output.contiguous() + state_dict['hn'] = hn + state_dict['cn'] = cn + torch.save(state_dict, "lstm_nlayer_test.pt") + ``` + +- bi_lstm_nlayer_test.pt: A bidirectional LSTM model with 3 layers. + + ```python + import torch + import torch.nn as nn + + rnn = nn.LSTM(10, 20, num_layers=3, bidirectional=True, batch_first=True) + input = torch.randn(5, 3, 10) + output, (hn, cn) = rnn(input) + + state_dict = rnn.state_dict() + state_dict['input'] = input + state_dict['output'] = output.contiguous() + state_dict['hn'] = hn + state_dict['cn'] = cn + torch.save(state_dict, "bi_lstm_nlayer_test.pt") + ``` + +- gru_nlayer_test.pt: A GRU model with 3 layers. + + ```python + import torch + import torch.nn as nn + + rnn = nn.GRU(10, 20, num_layers=3, batch_first=True) + input = torch.randn(5, 3, 10) + output, hn = rnn(input) + + state_dict = rnn.state_dict() + state_dict['input'] = input + state_dict['output'] = output.contiguous() + state_dict['hn'] = hn + torch.save(state_dict, "gru_nlayer_test.pt") + ``` + +- bi_gru_nlayer_test.pt: A bidirectional GRU model with 3 layers. + + ```python + import torch + import torch.nn as nn + + rnn = nn.GRU(10, 20, num_layers=3, bidirectional=True, batch_first=True) + input = torch.randn(5, 3, 10) + output, hn = rnn(input) + + state_dict = rnn.state_dict() + state_dict['input'] = input + state_dict['output'] = output.contiguous() + state_dict['hn'] = hn + torch.save(state_dict, "bi_gru_nlayer_test.pt") + ``` diff --git a/candle-examples/examples/rnn/main.rs b/candle-examples/examples/rnn/main.rs index 7998ce893d..83364a8ad0 100644 --- a/candle-examples/examples/rnn/main.rs +++ b/candle-examples/examples/rnn/main.rs @@ -5,9 +5,12 @@ extern crate intel_mkl_src; extern crate accelerate_src; use anyhow::Result; -use candle::{DType, Tensor}; -use candle_nn::{rnn, LSTMConfig, RNN}; +use candle::{DType, Device, Tensor, D}; +use candle_nn::{rnn, LSTMConfig, VarBuilder, RNN}; use clap::Parser; +use hf_hub::{api::sync::Api, Repo, RepoType}; + +const ACCURACY: f32 = 1e-6; #[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)] enum WhichModel { @@ -17,7 +20,7 @@ enum WhichModel { GRU, } -#[derive(Debug, Parser)] +#[derive(Clone, Copy, Debug, Parser)] #[command(author, version, about, long_about = None)] struct Args { #[arg(long)] @@ -43,6 +46,102 @@ struct Args { #[arg(long, default_value = "lstm")] model: WhichModel, + + #[arg(long)] + test: bool, +} + +impl Args { + pub fn load_model(&self) -> Result<(Config, VarBuilder<'static>, Tensor)> { + let device = self.device()?; + if self.test { + // run unit test and download model from huggingface hub. 
+ let model = match self.model { + WhichModel::LSTM => "lstm", + WhichModel::GRU => "gru", + }; + + let bidirection = if self.bidirection { "bi_" } else { "" }; + let layer = if self.layers > 1 { "_nlayer" } else { "" }; + let model = format!("{}{}{}_test", bidirection, model, layer); + let (config, vb) = load_model(&model, &device)?; + let input = vb.get( + (config.batch_size, config.sequence_length, config.input), + "input", + )?; + Ok((config, vb, input)) + } else { + let map = candle_nn::VarMap::new(); + let vb = candle_nn::VarBuilder::from_varmap(&map, DType::F32, &device); + let input = Tensor::randn( + 0.0_f32, + 1.0, + (self.batch_size, self.seq_len, self.input_dim), + &device, + )?; + Ok((self.into(), vb, input)) + } + } + + pub fn device(&self) -> Result { + Ok(candle_examples::device(self.cpu)?) + } +} + +#[derive(Debug, Clone, PartialEq, serde::Deserialize)] +struct Config { + pub input: usize, + pub batch_size: usize, + pub sequence_length: usize, + pub hidden: usize, + pub layers: usize, + pub bidirection: bool, +} + +impl From<&Args> for Config { + fn from(args: &Args) -> Self { + Config { + input: args.input_dim, + batch_size: args.batch_size, + sequence_length: args.seq_len, + hidden: args.hidden_dim, + layers: args.layers, + bidirection: args.bidirection, + } + } +} + +fn load_model(model: &str, device: &Device) -> Result<(Config, VarBuilder<'static>)> { + let api = Api::new()?; + let repo_id = "kigichang/test_rnn".to_string(); + let repo = api.repo(Repo::with_revision( + repo_id, + RepoType::Model, + "main".to_string(), + )); + + let filename = repo.get(&format!("{}.pt", model))?; + let config_file = repo.get(&format!("{}.json", model))?; + + let config: Config = serde_json::from_slice(&std::fs::read(config_file)?)?; + let vb = VarBuilder::from_pth(filename, DType::F32, device)?; + + Ok((config, vb)) +} + +fn assert_tensor(a: &Tensor, b: &Tensor, v: f32) -> Result<()> { + assert_eq!(a.dims(), b.dims()); + let dim = a.dims().len(); + let mut t = (a - b)?.abs()?; + + for _i in 0..dim { + t = t.max(D::Minus1)?; + } + + let t = t.to_scalar::()?; + println!("max diff = {}", t); + assert!(t < v); + Ok(()) } fn lstm_config(layer_idx: usize, direction: rnn::Direction) -> LSTMConfig { @@ -60,142 +159,139 @@ fn gru_config(layer_idx: usize, direction: rnn::Direction) -> rnn::GRUConfig { } fn run_lstm(args: Args) -> Result { - let device = candle_examples::device(args.cpu)?; - let map = candle_nn::VarMap::new(); - let vb = candle_nn::VarBuilder::from_varmap(&map, DType::F32, &device); + let (config, vb, mut input) = args.load_model()?; - let mut layers = Vec::with_capacity(args.layers); + let mut layers = Vec::with_capacity(config.layers); - for layer_idx in 0..args.layers { + for layer_idx in 0..config.layers { let input_dim = if layer_idx == 0 { - args.input_dim + config.input } else { - args.hidden_dim + config.hidden }; - let config = lstm_config(layer_idx, rnn::Direction::Forward); - let lstm = candle_nn::lstm(input_dim, args.hidden_dim, config, vb.clone())?; + let lstm_config = lstm_config(layer_idx, rnn::Direction::Forward); + let lstm = candle_nn::lstm(input_dim, config.hidden, lstm_config, vb.clone())?; layers.push(lstm); } - let mut input = Tensor::randn( - 0.0_f32, - 1.0, - (args.batch_size, args.seq_len, args.input_dim), - &device, - )?; - for layer in &layers { let states = layer.seq(&input)?; input = layer.states_to_tensor(&states)?; } + if args.test { + let answer = vb.get( + (config.batch_size, config.sequence_length, config.hidden), + "output", + )?; + 
assert_tensor(&input, &answer, ACCURACY)?; + } + Ok(input) } fn run_gru(args: Args) -> Result { - let device = candle_examples::device(args.cpu)?; - let map = candle_nn::VarMap::new(); - let vb = candle_nn::VarBuilder::from_varmap(&map, DType::F32, &device); + let (config, vb, mut input) = args.load_model()?; - let mut layers = Vec::with_capacity(args.layers); + let mut layers = Vec::with_capacity(config.layers); - for layer_idx in 0..args.layers { + for layer_idx in 0..config.layers { let input_dim = if layer_idx == 0 { - args.input_dim + config.input } else { - args.hidden_dim + config.hidden }; - let config = gru_config(layer_idx, rnn::Direction::Forward); - let gru = candle_nn::gru(input_dim, args.hidden_dim, config, vb.clone())?; + let gru_config = gru_config(layer_idx, rnn::Direction::Forward); + let gru = candle_nn::gru(input_dim, config.hidden, gru_config, vb.clone())?; layers.push(gru); } - let mut input = Tensor::randn( - 0.0_f32, - 1.0, - (args.batch_size, args.seq_len, args.input_dim), - &device, - )?; - for layer in &layers { let states = layer.seq(&input)?; input = layer.states_to_tensor(&states)?; } + if args.test { + let answer = vb.get( + (config.batch_size, config.sequence_length, config.hidden), + "output", + )?; + assert_tensor(&input, &answer, ACCURACY)?; + } + Ok(input) } fn run_bidirectional_lstm(args: Args) -> Result { - let device = candle_examples::device(args.cpu)?; - let map = candle_nn::VarMap::new(); - let vb = candle_nn::VarBuilder::from_varmap(&map, DType::F32, &device); + let (config, vb, mut input) = args.load_model()?; - let mut layers = Vec::with_capacity(args.layers); + let mut layers = Vec::with_capacity(config.layers); - for layer_idx in 0..args.layers { + for layer_idx in 0..config.layers { let input_dim = if layer_idx == 0 { - args.input_dim + config.input } else { - args.hidden_dim * 2 + config.hidden * 2 }; let forward_config = lstm_config(layer_idx, rnn::Direction::Forward); - let forward = candle_nn::lstm(input_dim, args.hidden_dim, forward_config, vb.clone())?; + let forward = candle_nn::lstm(input_dim, config.hidden, forward_config, vb.clone())?; let backward_config = lstm_config(layer_idx, rnn::Direction::Backward); - let backward = candle_nn::lstm(input_dim, args.hidden_dim, backward_config, vb.clone())?; + let backward = candle_nn::lstm(input_dim, config.hidden, backward_config, vb.clone())?; layers.push((forward, backward)); } - let mut input = Tensor::randn( - 0.0_f32, - 1.0, - (args.batch_size, args.seq_len, args.input_dim), - &device, - )?; - for (forward, backward) in &layers { let forward_states = forward.seq(&input)?; let backward_states = backward.seq(&input)?; - input = forward.combine_states_to_tensor(&forward_states, &backward_states)?; + input = forward.bidirectional_states_to_tensor(&forward_states, &backward_states)?; + } + + if args.test { + let answer = vb.get( + (config.batch_size, config.sequence_length, config.hidden * 2), + "output", + )?; + assert_tensor(&input, &answer, ACCURACY)?; } + Ok(input) } fn run_bidirectional_gru(args: Args) -> Result { - let device = candle_examples::device(args.cpu)?; - let map = candle_nn::VarMap::new(); - let vb = candle_nn::VarBuilder::from_varmap(&map, DType::F32, &device); + let (config, vb, mut input) = args.load_model()?; - let mut layers = Vec::with_capacity(args.layers); - for layer_idx in 0..args.layers { + let mut layers = Vec::with_capacity(config.layers); + for layer_idx in 0..config.layers { let input_dim = if layer_idx == 0 { - args.input_dim + config.input } else { - 
args.hidden_dim * 2 + config.hidden * 2 }; let forward_config = gru_config(layer_idx, rnn::Direction::Forward); - let forward = candle_nn::gru(input_dim, args.hidden_dim, forward_config, vb.clone())?; + let forward = candle_nn::gru(input_dim, config.hidden, forward_config, vb.clone())?; let backward_config = gru_config(layer_idx, rnn::Direction::Backward); - let backward = candle_nn::gru(input_dim, args.hidden_dim, backward_config, vb.clone())?; + let backward = candle_nn::gru(input_dim, config.hidden, backward_config, vb.clone())?; layers.push((forward, backward)); } - let mut input = Tensor::randn( - 0.0_f32, - 1.0, - (args.batch_size, args.seq_len, args.input_dim), - &device, - )?; - for (forward, backward) in &layers { let forward_states = forward.seq(&input)?; let backward_states = backward.seq(&input)?; - input = forward.combine_states_to_tensor(&forward_states, &backward_states)?; + input = forward.bidirectional_states_to_tensor(&forward_states, &backward_states)?; + } + + if args.test { + let answer = vb.get( + (config.batch_size, config.sequence_length, config.hidden * 2), + "output", + )?; + assert_tensor(&input, &answer, ACCURACY)?; } Ok(input) @@ -203,24 +299,103 @@ fn run_bidirectional_gru(args: Args) -> Result { fn main() -> Result<()> { let args = Args::parse(); - let runs = if args.bidirection { 2 } else { 1 }; - let batch_size = args.batch_size; - let seq_len = args.seq_len; - let hidden_dim = args.hidden_dim; println!( - "Running {:?} bidirection: {} layers: {}", - args.model, args.bidirection, args.layers + "Running {:?} bidirection: {} layers: {} example-test: {}", + args.model, args.bidirection, args.layers, args.test ); - let output = match (args.model, args.bidirection) { - (WhichModel::LSTM, false) => run_lstm(args), - (WhichModel::GRU, false) => run_gru(args), - (WhichModel::LSTM, true) => run_bidirectional_lstm(args), - (WhichModel::GRU, true) => run_bidirectional_gru(args), - }?; - - assert_eq!(output.dims3()?, (batch_size, seq_len, hidden_dim * runs)); + if args.test { + let test_args = Args { + model: WhichModel::LSTM, + bidirection: false, + layers: 1, + ..args + }; + print!("Testing LSTM with 1 layer: "); + run_lstm(test_args)?; + + let test_args = Args { + model: WhichModel::GRU, + bidirection: false, + layers: 1, + ..args + }; + print!("Testing GRU with 1 layer: "); + run_gru(test_args)?; + + let test_args = Args { + model: WhichModel::LSTM, + bidirection: true, + layers: 1, + ..args + }; + print!("Testing bidirectional LSTM with 1 layer: "); + run_bidirectional_lstm(test_args)?; + + let test_args = Args { + model: WhichModel::GRU, + bidirection: true, + layers: 1, + ..args + }; + print!("Testing bidirectional GRU with 1 layer: "); + run_bidirectional_gru(test_args)?; + + let test_args = Args { + model: WhichModel::LSTM, + bidirection: false, + layers: 3, + ..args + }; + print!("Testing LSTM with 3 layers: "); + run_lstm(test_args)?; + + let test_args = Args { + model: WhichModel::GRU, + bidirection: false, + layers: 3, + ..args + }; + print!("Testing GRU with 3 layers: "); + run_gru(test_args)?; + + let test_args = Args { + model: WhichModel::LSTM, + bidirection: true, + layers: 3, + ..args + }; + print!("Testing bidirectional LSTM with 3 layers: "); + run_bidirectional_lstm(test_args)?; + + let test_args = Args { + model: WhichModel::GRU, + bidirection: true, + layers: 3, + ..args + }; + print!("Testing bidirectional GRU with 3 layers: "); + run_bidirectional_gru(test_args)?; + } else { + let num_directions = if args.bidirection { 2 } else { 1 }; + let 
batch_size = args.batch_size; + let seq_len = args.seq_len; + let hidden_dim = args.hidden_dim; + + let output = match (args.model, args.bidirection) { + (WhichModel::LSTM, false) => run_lstm(args), + (WhichModel::GRU, false) => run_gru(args), + (WhichModel::LSTM, true) => run_bidirectional_lstm(args), + (WhichModel::GRU, true) => run_bidirectional_gru(args), + }?; + + assert_eq!( + output.dims3()?, + (batch_size, seq_len, hidden_dim * num_directions) + ); + println!("result dims: {:?}", output.dims()); + } Ok(()) } From 096ab15771d91f2a49cb92cc5679fcec68b584aa Mon Sep 17 00:00:00 2001 From: kigi Date: Sat, 12 Oct 2024 11:08:12 +0800 Subject: [PATCH 6/6] update readme --- candle-examples/examples/rnn/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candle-examples/examples/rnn/README.md b/candle-examples/examples/rnn/README.md index 567c9a96b2..5ba378f892 100644 --- a/candle-examples/examples/rnn/README.md +++ b/candle-examples/examples/rnn/README.md @@ -22,7 +22,7 @@ Add argument `--test` to run test of this example. $ cargo run --example rnn --release -- --test ``` -Test models are generated by Pytorch [LSTM](https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html) and [GRU](https://pytorch.org/docs/stable/generated/torch.nn.GRU.html). These models include input and output tensors and can be downloaded from [here](https://huggingface.co/kigichang/test_rnn). +Test models are generated with reference to the Pytorch examples [LSTM](https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html) and [GRU](https://pytorch.org/docs/stable/generated/torch.nn.GRU.html). These models include input and output tensors and can be downloaded from [here](https://huggingface.co/kigichang/test_rnn). Test models are generated by the following codes:
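
Each test model on the Hugging Face Hub is accompanied by a `<name>.json` file that the example's `load_model` helper downloads and deserializes into its `Config` struct (`input`, `batch_size`, `sequence_length`, `hidden`, `layers`, `bidirection`). The hosted JSON files are not reproduced in these patches; the snippet below is a minimal sketch of how such a companion config could be written for `lstm_test.pt`, assuming the keys follow the `Config` field names and the values match the PyTorch generation script above.

```python
import json

# Hypothetical companion config for lstm_test.pt: key names mirror the
# example's `Config` struct, values match nn.LSTM(10, 20, num_layers=1)
# with an input tensor of shape (5, 3, 10).
config = {
    "input": 10,
    "batch_size": 5,
    "sequence_length": 3,
    "hidden": 20,
    "layers": 1,
    "bidirection": False,
}

with open("lstm_test.json", "w") as f:
    json.dump(config, f, indent=2)
```

The `bi_*` and `*_nlayer` variants would only differ in the `bidirection` and `layers` values.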