diff options
Diffstat (limited to 'candle-nn/src/rnn.rs')
-rw-r--r-- | candle-nn/src/rnn.rs | 142 |
1 files changed, 142 insertions, 0 deletions
diff --git a/candle-nn/src/rnn.rs b/candle-nn/src/rnn.rs index 681f2b2b..06aaf190 100644 --- a/candle-nn/src/rnn.rs +++ b/candle-nn/src/rnn.rs @@ -184,3 +184,145 @@ impl RNN for LSTM { Ok((output, state)) } } + +/// The state for a GRU network, this contains a single tensor. +#[allow(clippy::upper_case_acronyms)] +#[derive(Debug, Clone)] +pub struct GRUState { + h: Tensor, +} + +impl GRUState { + /// The hidden state vector, which is also the output of the LSTM. + pub fn h(&self) -> &Tensor { + &self.h + } +} + +#[allow(clippy::upper_case_acronyms)] +#[derive(Debug, Clone, Copy)] +pub struct GRUConfig { + pub w_ih_init: super::Init, + pub w_hh_init: super::Init, + pub b_ih_init: Option<super::Init>, + pub b_hh_init: Option<super::Init>, +} + +impl Default for GRUConfig { + fn default() -> Self { + Self { + w_ih_init: super::init::DEFAULT_KAIMING_UNIFORM, + w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM, + b_ih_init: Some(super::Init::Const(0.)), + b_hh_init: Some(super::Init::Const(0.)), + } + } +} + +impl GRUConfig { + pub fn default_no_bias() -> Self { + Self { + w_ih_init: super::init::DEFAULT_KAIMING_UNIFORM, + w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM, + b_ih_init: None, + b_hh_init: None, + } + } +} + +/// A Gated Recurrent Unit (GRU) layer. +/// +/// <https://en.wikipedia.org/wiki/Gated_recurrent_unit> +#[allow(clippy::upper_case_acronyms, unused)] +#[derive(Debug)] +pub struct GRU { + w_ih: Tensor, + w_hh: Tensor, + b_ih: Option<Tensor>, + b_hh: Option<Tensor>, + hidden_dim: usize, + config: GRUConfig, + device: Device, + dtype: DType, +} + +/// Creates a GRU layer. +pub fn gru( + in_dim: usize, + hidden_dim: usize, + config: GRUConfig, + vb: crate::VarBuilder, +) -> Result<GRU> { + let w_ih = vb.get_with_hints( + (3 * hidden_dim, in_dim), + "weight_ih_l0", // Only a single layer is supported. + config.w_ih_init, + )?; + let w_hh = vb.get_with_hints( + (3 * hidden_dim, hidden_dim), + "weight_hh_l0", // Only a single layer is supported. + config.w_hh_init, + )?; + let b_ih = match config.b_ih_init { + Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_ih_l0", init)?), + None => None, + }; + let b_hh = match config.b_hh_init { + Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_hh_l0", init)?), + None => None, + }; + Ok(GRU { + w_ih, + w_hh, + b_ih, + b_hh, + hidden_dim, + config, + device: vb.device().clone(), + dtype: vb.dtype(), + }) +} + +impl RNN for GRU { + type State = GRUState; + + fn zero_state(&self, batch_dim: usize) -> Result<Self::State> { + let h = Tensor::zeros((batch_dim, self.hidden_dim), self.dtype, &self.device)?; + Ok(Self::State { h }) + } + + fn step(&self, input: &Tensor, in_state: &Self::State) -> Result<Self::State> { + let w_ih = input.matmul(&self.w_ih.t()?)?; + let w_hh = in_state.h.matmul(&self.w_hh.t()?)?; + let w_ih = match &self.b_ih { + None => w_ih, + Some(b_ih) => w_ih.broadcast_add(b_ih)?, + }; + let w_hh = match &self.b_hh { + None => w_hh, + Some(b_hh) => w_hh.broadcast_add(b_hh)?, + }; + let chunks_ih = w_ih.chunk(3, 1)?; + let chunks_hh = w_hh.chunk(3, 1)?; + let r_gate = crate::ops::sigmoid(&(&chunks_ih[0] + &chunks_hh[0])?)?; + let z_gate = crate::ops::sigmoid(&(&chunks_ih[1] + &chunks_hh[1])?)?; + let n_gate = (&chunks_ih[2] + (r_gate * &chunks_hh[2])?)?.tanh(); + + let next_h = ((&z_gate * &in_state.h)? - ((&z_gate - 1.)? * n_gate)?)?; + Ok(GRUState { h: next_h }) + } + + /// The input should have dimensions [batch_size, seq_len, features]. + fn seq_init(&self, input: &Tensor, in_state: &Self::State) -> Result<(Tensor, Self::State)> { + let (_b_size, seq_len, _features) = input.dims3()?; + let mut state = in_state.clone(); + let mut output: Vec<Tensor> = Vec::with_capacity(seq_len); + for seq_index in 0..seq_len { + let input = input.i((.., seq_index, ..))?; + state = self.step(&input, &state)?; + output.push(state.h.clone()); + } + let output = Tensor::cat(&output, 1)?; + Ok((output, state)) + } +} |