author    | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-31 09:43:10 +0200
committer | GitHub <noreply@github.com>               | 2023-08-31 08:43:10 +0100
commit    | db598160876a841fcf372ae923e4c5328a47f653 (patch)
tree      | 8d927dad45d515859d8e5b8a5f4c5822fafcc636
parent    | d210c71d77a6044c2a42c2e75487b6180e957158 (diff)
Add a GRU layer. (#688)
* Add a GRU layer.
* Fix the n gate computation.
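For reference, the recurrence implemented by the new `step` method in candle-nn/src/rnn.rs is the usual GRU update; the three chunks of `weight_ih_l0`/`weight_hh_l0` correspond to the r, z and n gates. The notation below is an editorial reading of the added code, not part of the commit:

\begin{aligned}
r_t &= \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr}) \\
z_t &= \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz}) \\
n_t &= \tanh\left(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{t-1} + b_{hn})\right) \\
h_t &= z_t \odot h_{t-1} - (z_t - 1) \odot n_t = (1 - z_t) \odot n_t + z_t \odot h_{t-1}
\end{aligned}

The "Fix the n gate computation" item above concerns the third line: in the final code the reset gate r_t multiplies the hidden-to-hidden term inside the tanh, matching PyTorch's nn.GRU definition, which is what the PyTorch comparison test added below checks.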
-rw-r--r-- | candle-nn/src/lib.rs   |   2
-rw-r--r-- | candle-nn/src/rnn.rs   | 142
-rw-r--r-- | candle-nn/tests/rnn.rs |  44
3 files changed, 187 insertions(+), 1 deletion(-)
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index e9552e83..48046081 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -25,7 +25,7 @@
 pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
 pub use linear::{linear, linear_no_bias, Linear};
 pub use ops::Dropout;
 pub use optim::{AdamW, ParamsAdamW, SGD};
-pub use rnn::{lstm, LSTM, RNN};
+pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN};
 pub use var_builder::VarBuilder;
 pub use var_map::VarMap;
diff --git a/candle-nn/src/rnn.rs b/candle-nn/src/rnn.rs
index 681f2b2b..06aaf190 100644
--- a/candle-nn/src/rnn.rs
+++ b/candle-nn/src/rnn.rs
@@ -184,3 +184,145 @@ impl RNN for LSTM {
         Ok((output, state))
     }
 }
+
+/// The state for a GRU network; this contains a single tensor.
+#[allow(clippy::upper_case_acronyms)]
+#[derive(Debug, Clone)]
+pub struct GRUState {
+    h: Tensor,
+}
+
+impl GRUState {
+    /// The hidden state vector, which is also the output of the GRU.
+    pub fn h(&self) -> &Tensor {
+        &self.h
+    }
+}
+
+#[allow(clippy::upper_case_acronyms)]
+#[derive(Debug, Clone, Copy)]
+pub struct GRUConfig {
+    pub w_ih_init: super::Init,
+    pub w_hh_init: super::Init,
+    pub b_ih_init: Option<super::Init>,
+    pub b_hh_init: Option<super::Init>,
+}
+
+impl Default for GRUConfig {
+    fn default() -> Self {
+        Self {
+            w_ih_init: super::init::DEFAULT_KAIMING_UNIFORM,
+            w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM,
+            b_ih_init: Some(super::Init::Const(0.)),
+            b_hh_init: Some(super::Init::Const(0.)),
+        }
+    }
+}
+
+impl GRUConfig {
+    pub fn default_no_bias() -> Self {
+        Self {
+            w_ih_init: super::init::DEFAULT_KAIMING_UNIFORM,
+            w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM,
+            b_ih_init: None,
+            b_hh_init: None,
+        }
+    }
+}
+
+/// A Gated Recurrent Unit (GRU) layer.
+///
+/// <https://en.wikipedia.org/wiki/Gated_recurrent_unit>
+#[allow(clippy::upper_case_acronyms, unused)]
+#[derive(Debug)]
+pub struct GRU {
+    w_ih: Tensor,
+    w_hh: Tensor,
+    b_ih: Option<Tensor>,
+    b_hh: Option<Tensor>,
+    hidden_dim: usize,
+    config: GRUConfig,
+    device: Device,
+    dtype: DType,
+}
+
+/// Creates a GRU layer.
+pub fn gru(
+    in_dim: usize,
+    hidden_dim: usize,
+    config: GRUConfig,
+    vb: crate::VarBuilder,
+) -> Result<GRU> {
+    let w_ih = vb.get_with_hints(
+        (3 * hidden_dim, in_dim),
+        "weight_ih_l0", // Only a single layer is supported.
+        config.w_ih_init,
+    )?;
+    let w_hh = vb.get_with_hints(
+        (3 * hidden_dim, hidden_dim),
+        "weight_hh_l0", // Only a single layer is supported.
+        config.w_hh_init,
+    )?;
+    let b_ih = match config.b_ih_init {
+        Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_ih_l0", init)?),
+        None => None,
+    };
+    let b_hh = match config.b_hh_init {
+        Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_hh_l0", init)?),
+        None => None,
+    };
+    Ok(GRU {
+        w_ih,
+        w_hh,
+        b_ih,
+        b_hh,
+        hidden_dim,
+        config,
+        device: vb.device().clone(),
+        dtype: vb.dtype(),
+    })
+}
+
+impl RNN for GRU {
+    type State = GRUState;
+
+    fn zero_state(&self, batch_dim: usize) -> Result<Self::State> {
+        let h = Tensor::zeros((batch_dim, self.hidden_dim), self.dtype, &self.device)?;
+        Ok(Self::State { h })
+    }
+
+    fn step(&self, input: &Tensor, in_state: &Self::State) -> Result<Self::State> {
+        let w_ih = input.matmul(&self.w_ih.t()?)?;
+        let w_hh = in_state.h.matmul(&self.w_hh.t()?)?;
+        let w_ih = match &self.b_ih {
+            None => w_ih,
+            Some(b_ih) => w_ih.broadcast_add(b_ih)?,
+        };
+        let w_hh = match &self.b_hh {
+            None => w_hh,
+            Some(b_hh) => w_hh.broadcast_add(b_hh)?,
+        };
+        let chunks_ih = w_ih.chunk(3, 1)?;
+        let chunks_hh = w_hh.chunk(3, 1)?;
+        let r_gate = crate::ops::sigmoid(&(&chunks_ih[0] + &chunks_hh[0])?)?;
+        let z_gate = crate::ops::sigmoid(&(&chunks_ih[1] + &chunks_hh[1])?)?;
+        let n_gate = (&chunks_ih[2] + (r_gate * &chunks_hh[2])?)?.tanh();
+
+        let next_h = ((&z_gate * &in_state.h)? - ((&z_gate - 1.)? * n_gate)?)?;
+        Ok(GRUState { h: next_h })
+    }
+
+    /// The input should have dimensions [batch_size, seq_len, features].
+    fn seq_init(&self, input: &Tensor, in_state: &Self::State) -> Result<(Tensor, Self::State)> {
+        let (_b_size, seq_len, _features) = input.dims3()?;
+        let mut state = in_state.clone();
+        let mut output: Vec<Tensor> = Vec::with_capacity(seq_len);
+        for seq_index in 0..seq_len {
+            let input = input.i((.., seq_index, ..))?;
+            state = self.step(&input, &state)?;
+            output.push(state.h.clone());
+        }
+        let output = Tensor::cat(&output, 1)?;
+        Ok((output, state))
+    }
+}
diff --git a/candle-nn/tests/rnn.rs b/candle-nn/tests/rnn.rs
index eda1a381..498c9188 100644
--- a/candle-nn/tests/rnn.rs
+++ b/candle-nn/tests/rnn.rs
@@ -55,3 +55,47 @@ fn lstm() -> Result<()> {
     assert_eq!(to_vec2_round(c, 4)?, &[[5.725, 0.4458, -0.2908]]);
     Ok(())
 }
+
+/* The following test can be verified against PyTorch using the following snippet.
+import torch
+from torch import nn
+gru = nn.GRU(2, 3, 1)
+gru.weight_ih_l0 = torch.nn.Parameter(torch.arange(0., 18.).reshape(9, 2).cos())
+gru.weight_hh_l0 = torch.nn.Parameter(torch.arange(0., 27.).reshape(9, 3).sin())
+gru.bias_ih_l0 = torch.nn.Parameter(torch.tensor([-1., 1., -0.5, 2, -1, 1, -0.5, 2, -1]))
+gru.bias_hh_l0 = torch.nn.Parameter(torch.tensor([-1., 1., -0.5, 2, -1, 1, -0.5, 2, -1]).cos())
+state = torch.zeros((1, 3))
+for inp in [3., 1., 4., 1., 5., 9., 2.]:
+    inp = torch.tensor([[inp, inp * 0.5]])
+    _out, state = gru(inp, state)
+print(state)
+# tensor([[ 0.0579,  0.8836, -0.9991]], grad_fn=<SqueezeBackward1>)
+*/
+#[test]
+fn gru() -> Result<()> {
+    let cpu = &Device::Cpu;
+    let w_ih = Tensor::arange(0f32, 18f32, cpu)?.reshape((9, 2))?;
+    let w_ih = w_ih.cos()?;
+    let w_hh = Tensor::arange(0f32, 27f32, cpu)?.reshape((9, 3))?;
+    let w_hh = w_hh.sin()?;
+    let b_ih = Tensor::new(&[-1f32, 1., -0.5, 2., -1., 1., -0.5, 2., -1.], cpu)?;
+    let b_hh = b_ih.cos()?;
+    let tensors: std::collections::HashMap<_, _> = [
+        ("weight_ih_l0".to_string(), w_ih),
+        ("weight_hh_l0".to_string(), w_hh),
+        ("bias_ih_l0".to_string(), b_ih),
+        ("bias_hh_l0".to_string(), b_hh),
+    ]
+    .into_iter()
+    .collect();
+    let vb = candle_nn::VarBuilder::from_tensors(tensors, DType::F32, cpu);
+    let gru = candle_nn::gru(2, 3, Default::default(), vb)?;
+    let mut state = gru.zero_state(1)?;
+    for inp in [3f32, 1., 4., 1., 5., 9., 2.] {
+        let inp = Tensor::new(&[[inp, inp * 0.5]], cpu)?;
+        state = gru.step(&inp, &state)?
+    }
+    let h = state.h();
+    assert_eq!(to_vec2_round(h, 4)?, &[[0.0579, 0.8836, -0.9991]]);
+    Ok(())
+}
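For anyone picking up the new API, a minimal usage sketch follows. It is not part of the commit; it assumes the external crate names `candle_core`/`candle_nn` and only relies on items exported above (`gru`, `GRUConfig`, the `RNN` trait) plus the pre-existing `VarMap`/`VarBuilder` re-exports from candle-nn/src/lib.rs.

// Hypothetical usage sketch, not from the commit; crate names are assumptions.
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{GRUConfig, VarBuilder, VarMap, RNN};

fn main() -> Result<()> {
    let device = Device::Cpu;

    // Fresh, trainable parameters backed by a VarMap. To load pretrained
    // weights instead, build the VarBuilder with VarBuilder::from_tensors and
    // the weight_ih_l0 / weight_hh_l0 / bias_ih_l0 / bias_hh_l0 names, as the
    // test above does.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

    // A GRU mapping 2 input features to a hidden state of size 3 (single layer only).
    let gru = candle_nn::gru(2, 3, GRUConfig::default(), vb)?;

    // Process a whole [batch, seq_len, features] tensor in one call...
    let xs = Tensor::zeros((1, 7, 2), DType::F32, &device)?;
    let state = gru.zero_state(1)?;
    let (_outputs, state) = gru.seq_init(&xs, &state)?;

    // ...or step through time manually with [batch, features] inputs.
    let x_t = Tensor::zeros((1, 2), DType::F32, &device)?;
    let state = gru.step(&x_t, &state)?;
    println!("hidden state shape: {:?}", state.h().shape());
    Ok(())
}

Passing `GRUConfig::default_no_bias()` instead of the default drops both bias vectors, mirroring PyTorch's `bias=False` option.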