author     Laurent Mazare <laurent.mazare@gmail.com>  2023-08-31 09:43:10 +0200
committer  GitHub <noreply@github.com>                2023-08-31 08:43:10 +0100
commit     db598160876a841fcf372ae923e4c5328a47f653 (patch)
tree       8d927dad45d515859d8e5b8a5f4c5822fafcc636
parent     d210c71d77a6044c2a42c2e75487b6180e957158 (diff)
Add a GRU layer. (#688)
* Add a GRU layer.
* Fix the n gate computation.
-rw-r--r--  candle-nn/src/lib.rs      2
-rw-r--r--  candle-nn/src/rnn.rs    142
-rw-r--r--  candle-nn/tests/rnn.rs   44
3 files changed, 187 insertions(+), 1 deletion(-)
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index e9552e83..48046081 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -25,7 +25,7 @@ pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
pub use linear::{linear, linear_no_bias, Linear};
pub use ops::Dropout;
pub use optim::{AdamW, ParamsAdamW, SGD};
-pub use rnn::{lstm, LSTM, RNN};
+pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN};
pub use var_builder::VarBuilder;
pub use var_map::VarMap;
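
For context, here is a minimal usage sketch of the newly exported `gru` / `GRUConfig` API; it is not taken from this patch. It assumes the in-repo alias `candle` for the `candle-core` crate and a `VarMap` / `VarBuilder::from_varmap` pair for creating freshly initialized weights; the test at the end of this commit instead builds its `VarBuilder` from fixed tensors.

// Hedged usage sketch (not part of this patch): build a single-layer GRU with
// freshly initialized weights and apply one step.
use candle::{DType, Device, Result, Tensor};
use candle_nn::{gru, GRUConfig, VarBuilder, VarMap, RNN};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    // 2 input features, 3 hidden units; the default config uses
    // Kaiming-uniform weights and zero biases.
    let gru = gru(2, 3, GRUConfig::default(), vb)?;
    let mut state = gru.zero_state(1)?;
    // A single time step of shape (batch_size = 1, features = 2).
    let input = Tensor::new(&[[1f32, 0.5]], &dev)?;
    state = gru.step(&input, &state)?;
    println!("{:?}", state.h().to_vec2::<f32>()?);
    Ok(())
}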
diff --git a/candle-nn/src/rnn.rs b/candle-nn/src/rnn.rs
index 681f2b2b..06aaf190 100644
--- a/candle-nn/src/rnn.rs
+++ b/candle-nn/src/rnn.rs
@@ -184,3 +184,145 @@ impl RNN for LSTM {
Ok((output, state))
}
}
+
+/// The state for a GRU network; this contains a single tensor.
+#[allow(clippy::upper_case_acronyms)]
+#[derive(Debug, Clone)]
+pub struct GRUState {
+ h: Tensor,
+}
+
+impl GRUState {
+    /// The hidden state vector, which is also the output of the GRU.
+ pub fn h(&self) -> &Tensor {
+ &self.h
+ }
+}
+
+#[allow(clippy::upper_case_acronyms)]
+#[derive(Debug, Clone, Copy)]
+pub struct GRUConfig {
+ pub w_ih_init: super::Init,
+ pub w_hh_init: super::Init,
+ pub b_ih_init: Option<super::Init>,
+ pub b_hh_init: Option<super::Init>,
+}
+
+impl Default for GRUConfig {
+ fn default() -> Self {
+ Self {
+ w_ih_init: super::init::DEFAULT_KAIMING_UNIFORM,
+ w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM,
+ b_ih_init: Some(super::Init::Const(0.)),
+ b_hh_init: Some(super::Init::Const(0.)),
+ }
+ }
+}
+
+impl GRUConfig {
+ pub fn default_no_bias() -> Self {
+ Self {
+ w_ih_init: super::init::DEFAULT_KAIMING_UNIFORM,
+ w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM,
+ b_ih_init: None,
+ b_hh_init: None,
+ }
+ }
+}
+
+/// A Gated Recurrent Unit (GRU) layer.
+///
+/// <https://en.wikipedia.org/wiki/Gated_recurrent_unit>
+#[allow(clippy::upper_case_acronyms, unused)]
+#[derive(Debug)]
+pub struct GRU {
+ w_ih: Tensor,
+ w_hh: Tensor,
+ b_ih: Option<Tensor>,
+ b_hh: Option<Tensor>,
+ hidden_dim: usize,
+ config: GRUConfig,
+ device: Device,
+ dtype: DType,
+}
+
+/// Creates a GRU layer.
+pub fn gru(
+ in_dim: usize,
+ hidden_dim: usize,
+ config: GRUConfig,
+ vb: crate::VarBuilder,
+) -> Result<GRU> {
+ let w_ih = vb.get_with_hints(
+ (3 * hidden_dim, in_dim),
+ "weight_ih_l0", // Only a single layer is supported.
+ config.w_ih_init,
+ )?;
+ let w_hh = vb.get_with_hints(
+ (3 * hidden_dim, hidden_dim),
+ "weight_hh_l0", // Only a single layer is supported.
+ config.w_hh_init,
+ )?;
+ let b_ih = match config.b_ih_init {
+ Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_ih_l0", init)?),
+ None => None,
+ };
+ let b_hh = match config.b_hh_init {
+ Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_hh_l0", init)?),
+ None => None,
+ };
+ Ok(GRU {
+ w_ih,
+ w_hh,
+ b_ih,
+ b_hh,
+ hidden_dim,
+ config,
+ device: vb.device().clone(),
+ dtype: vb.dtype(),
+ })
+}
+
+impl RNN for GRU {
+ type State = GRUState;
+
+ fn zero_state(&self, batch_dim: usize) -> Result<Self::State> {
+ let h = Tensor::zeros((batch_dim, self.hidden_dim), self.dtype, &self.device)?;
+ Ok(Self::State { h })
+ }
+
+ fn step(&self, input: &Tensor, in_state: &Self::State) -> Result<Self::State> {
+ let w_ih = input.matmul(&self.w_ih.t()?)?;
+ let w_hh = in_state.h.matmul(&self.w_hh.t()?)?;
+ let w_ih = match &self.b_ih {
+ None => w_ih,
+ Some(b_ih) => w_ih.broadcast_add(b_ih)?,
+ };
+ let w_hh = match &self.b_hh {
+ None => w_hh,
+ Some(b_hh) => w_hh.broadcast_add(b_hh)?,
+ };
+ let chunks_ih = w_ih.chunk(3, 1)?;
+ let chunks_hh = w_hh.chunk(3, 1)?;
+ let r_gate = crate::ops::sigmoid(&(&chunks_ih[0] + &chunks_hh[0])?)?;
+ let z_gate = crate::ops::sigmoid(&(&chunks_ih[1] + &chunks_hh[1])?)?;
+ let n_gate = (&chunks_ih[2] + (r_gate * &chunks_hh[2])?)?.tanh();
+
+ let next_h = ((&z_gate * &in_state.h)? - ((&z_gate - 1.)? * n_gate)?)?;
+ Ok(GRUState { h: next_h })
+ }
+
+ /// The input should have dimensions [batch_size, seq_len, features].
+ fn seq_init(&self, input: &Tensor, in_state: &Self::State) -> Result<(Tensor, Self::State)> {
+ let (_b_size, seq_len, _features) = input.dims3()?;
+ let mut state = in_state.clone();
+ let mut output: Vec<Tensor> = Vec::with_capacity(seq_len);
+ for seq_index in 0..seq_len {
+ let input = input.i((.., seq_index, ..))?;
+ state = self.step(&input, &state)?;
+ output.push(state.h.clone());
+ }
+ let output = Tensor::cat(&output, 1)?;
+ Ok((output, state))
+ }
+}
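
For reference, the `step` implementation above follows the usual PyTorch-convention GRU update, with the three chunks of `w_ih` / `w_hh` ordered as (reset, update, new); the symbols below follow PyTorch's naming and are not identifiers from this patch.

\begin{aligned}
r_t &= \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr}) \\
z_t &= \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz}) \\
n_t &= \tanh\!\left(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{t-1} + b_{hn})\right) \\
h_t &= (1 - z_t) \odot n_t + z_t \odot h_{t-1} = z_t \odot h_{t-1} - (z_t - 1) \odot n_t
\end{aligned}

The reset gate multiplies the whole hidden-side term, including the bias chunk, which is what `r_gate * &chunks_hh[2]` computes; this appears to be the "Fix the n gate computation" mentioned in the commit message. The last equality is the algebraic rewrite used for `next_h`.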
diff --git a/candle-nn/tests/rnn.rs b/candle-nn/tests/rnn.rs
index eda1a381..498c9188 100644
--- a/candle-nn/tests/rnn.rs
+++ b/candle-nn/tests/rnn.rs
@@ -55,3 +55,47 @@ fn lstm() -> Result<()> {
assert_eq!(to_vec2_round(c, 4)?, &[[5.725, 0.4458, -0.2908]]);
Ok(())
}
+
+/* The following test can be verified against PyTorch using the following snippet.
+import torch
+from torch import nn
+gru = nn.GRU(2, 3, 1)
+gru.weight_ih_l0 = torch.nn.Parameter(torch.arange(0., 18.).reshape(9, 2).cos())
+gru.weight_hh_l0 = torch.nn.Parameter(torch.arange(0., 27.).reshape(9, 3).sin())
+gru.bias_ih_l0 = torch.nn.Parameter(torch.tensor([-1., 1., -0.5, 2, -1, 1, -0.5, 2, -1]))
+gru.bias_hh_l0 = torch.nn.Parameter(torch.tensor([-1., 1., -0.5, 2, -1, 1, -0.5, 2, -1]).cos())
+state = torch.zeros((1, 3))
+for inp in [3., 1., 4., 1., 5., 9., 2.]:
+ inp = torch.tensor([[inp, inp * 0.5]])
+ _out, state = gru(inp, state)
+print(state)
+# tensor([[ 0.0579, 0.8836, -0.9991]], grad_fn=<SqueezeBackward1>)
+*/
+#[test]
+fn gru() -> Result<()> {
+ let cpu = &Device::Cpu;
+ let w_ih = Tensor::arange(0f32, 18f32, cpu)?.reshape((9, 2))?;
+ let w_ih = w_ih.cos()?;
+ let w_hh = Tensor::arange(0f32, 27f32, cpu)?.reshape((9, 3))?;
+ let w_hh = w_hh.sin()?;
+ let b_ih = Tensor::new(&[-1f32, 1., -0.5, 2., -1., 1., -0.5, 2., -1.], cpu)?;
+ let b_hh = b_ih.cos()?;
+ let tensors: std::collections::HashMap<_, _> = [
+ ("weight_ih_l0".to_string(), w_ih),
+ ("weight_hh_l0".to_string(), w_hh),
+ ("bias_ih_l0".to_string(), b_ih),
+ ("bias_hh_l0".to_string(), b_hh),
+ ]
+ .into_iter()
+ .collect();
+ let vb = candle_nn::VarBuilder::from_tensors(tensors, DType::F32, cpu);
+ let gru = candle_nn::gru(2, 3, Default::default(), vb)?;
+ let mut state = gru.zero_state(1)?;
+ for inp in [3f32, 1., 4., 1., 5., 9., 2.] {
+ let inp = Tensor::new(&[[inp, inp * 0.5]], cpu)?;
+ state = gru.step(&inp, &state)?
+ }
+ let h = state.h();
+ assert_eq!(to_vec2_round(h, 4)?, &[[0.0579, 0.8836, -0.9991]]);
+ Ok(())
+}
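
As a follow-up sketch (not part of this patch), the same sequence can be processed in one call with `seq_init` instead of stepping manually; this reuses `gru`, `cpu`, and `to_vec2_round` from the test above and assumes `Tensor::from_vec` for building the (batch_size, seq_len, features) input.

// Hedged sketch: feed the whole sequence at once via seq_init.
let seq: Vec<f32> = [3f32, 1., 4., 1., 5., 9., 2.]
    .iter()
    .flat_map(|&x| [x, x * 0.5])
    .collect();
// Shape (batch_size = 1, seq_len = 7, features = 2).
let input = Tensor::from_vec(seq, (1, 7, 2), cpu)?;
let state0 = gru.zero_state(1)?;
let (output, final_state) = gru.seq_init(&input, &state0)?;
// `final_state.h()` matches the hidden state reached by the manual loop above;
// `output` concatenates the per-step hidden states along dimension 1.
assert_eq!(to_vec2_round(final_state.h(), 4)?, &[[0.0579, 0.8836, -0.9991]]);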