author     Laurent Mazare <laurent.mazare@gmail.com>    2023-08-30 13:27:09 +0100
committer  GitHub <noreply@github.com>                  2023-08-30 13:27:09 +0100
commit     f35b9f6baa58e78bcd620025c2467abd53e2d2bc
tree       26be9360cbad12dc56500d2fe39d932f2a3976a1
parent     618f4e4c788959d3d3c471b1d0f92594176b7e1b
Add some recurrent neural networks (#674)
* Add the rnn module.
* More LSTM.
* Implement the RNN forward pass.
* More forward pass for LSTM.
-rw-r--r--  candle-nn/src/lib.rs |   2
-rw-r--r--  candle-nn/src/rnn.rs | 188
2 files changed, 190 insertions(+), 0 deletions(-)
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index 2e2c2545..8ab51070 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -10,6 +10,7 @@ pub mod linear;
 pub mod loss;
 pub mod ops;
 pub mod optim;
+pub mod rnn;
 pub mod var_builder;
 pub mod var_map;
@@ -23,6 +24,7 @@ pub use init::Init;
 pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
 pub use linear::{linear, linear_no_bias, Linear};
 pub use optim::{AdamW, ParamsAdamW, SGD};
+pub use rnn::{lstm, LSTM, RNN};
 pub use var_builder::VarBuilder;
 pub use var_map::VarMap;
diff --git a/candle-nn/src/rnn.rs b/candle-nn/src/rnn.rs
new file mode 100644
index 00000000..4b116081
--- /dev/null
+++ b/candle-nn/src/rnn.rs
@@ -0,0 +1,188 @@
+//! Recurrent Neural Networks
+use candle::{DType, Device, IndexOp, Result, Tensor};
+
+/// Trait for Recurrent Neural Networks.
+#[allow(clippy::upper_case_acronyms)]
+pub trait RNN {
+    type State;
+
+    /// A zero state from which the recurrent network is usually initialized.
+    fn zero_state(&self, batch_dim: usize) -> Result<Self::State>;
+
+    /// Applies a single step of the recurrent network.
+    ///
+    /// The input should have dimensions [batch_size, features].
+    fn step(&self, input: &Tensor, state: &Self::State) -> Result<Self::State>;
+
+    /// Applies multiple steps of the recurrent network.
+    ///
+    /// The input should have dimensions [batch_size, seq_len, features].
+    /// The initial state is the result of applying zero_state.
+    fn seq(&self, input: &Tensor) -> Result<(Tensor, Self::State)> {
+        let batch_dim = input.dim(0)?;
+        let state = self.zero_state(batch_dim)?;
+        self.seq_init(input, &state)
+    }
+
+    /// Applies multiple steps of the recurrent network.
+    ///
+    /// The input should have dimensions [batch_size, seq_len, features].
+    fn seq_init(&self, input: &Tensor, state: &Self::State) -> Result<(Tensor, Self::State)>;
+}
+
+/// The state for a LSTM network, this contains two tensors.
+#[allow(clippy::upper_case_acronyms)]
+#[derive(Debug, Clone)]
+pub struct LSTMState {
+    h: Tensor,
+    c: Tensor,
+}
+
+impl LSTMState {
+    /// The hidden state vector, which is also the output of the LSTM.
+    pub fn h(&self) -> &Tensor {
+        &self.h
+    }
+
+    /// The cell state vector.
+    pub fn c(&self) -> &Tensor {
+        &self.c
+    }
+}
+
+#[allow(clippy::upper_case_acronyms)]
+#[derive(Debug, Clone, Copy)]
+pub struct LSTMConfig {
+    pub w_ih_init: super::Init,
+    pub w_hh_init: super::Init,
+    pub b_ih_init: Option<super::Init>,
+    pub b_hh_init: Option<super::Init>,
+}
+
+impl Default for LSTMConfig {
+    fn default() -> Self {
+        Self {
+            w_ih_init: super::init::DEFAULT_KAIMING_UNIFORM,
+            w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM,
+            b_ih_init: Some(super::Init::Const(0.)),
+            b_hh_init: Some(super::Init::Const(0.)),
+        }
+    }
+}
+
+impl LSTMConfig {
+    pub fn default_no_bias() -> Self {
+        Self {
+            w_ih_init: super::init::DEFAULT_KAIMING_UNIFORM,
+            w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM,
+            b_ih_init: None,
+            b_hh_init: None,
+        }
+    }
+}
+
+/// A Long Short-Term Memory (LSTM) layer.
+///
+/// <https://en.wikipedia.org/wiki/Long_short-term_memory>
+#[allow(clippy::upper_case_acronyms, unused)]
+#[derive(Debug)]
+pub struct LSTM {
+    w_ih: Tensor,
+    w_hh: Tensor,
+    b_ih: Option<Tensor>,
+    b_hh: Option<Tensor>,
+    hidden_dim: usize,
+    config: LSTMConfig,
+    device: Device,
+    dtype: DType,
+}
+
+/// Creates a LSTM layer.
+pub fn lstm(
+    in_dim: usize,
+    hidden_dim: usize,
+    config: LSTMConfig,
+    vb: crate::VarBuilder,
+) -> Result<LSTM> {
+    let w_ih = vb.get_with_hints(
+        (4 * hidden_dim, in_dim),
+        "weight_ih_l0", // Only a single layer is supported.
+        config.w_ih_init,
+    )?;
+    let w_hh = vb.get_with_hints(
+        (4 * hidden_dim, in_dim),
+        "weight_hh_l0", // Only a single layer is supported.
+        config.w_hh_init,
+    )?;
+    let b_ih = match config.b_ih_init {
+        Some(init) => Some(vb.get_with_hints(4 * hidden_dim, "bias_ih_l0", init)?),
+        None => None,
+    };
+    let b_hh = match config.b_hh_init {
+        Some(init) => Some(vb.get_with_hints(4 * hidden_dim, "bias_hh_l0", init)?),
+        None => None,
+    };
+    Ok(LSTM {
+        w_ih,
+        w_hh,
+        b_ih,
+        b_hh,
+        hidden_dim,
+        config,
+        device: vb.device().clone(),
+        dtype: vb.dtype(),
+    })
+}
+
+impl RNN for LSTM {
+    type State = LSTMState;
+
+    fn zero_state(&self, batch_dim: usize) -> Result<Self::State> {
+        let zeros = Tensor::zeros((batch_dim, self.hidden_dim), self.dtype, &self.device)?;
+        Ok(Self::State {
+            h: zeros.clone(),
+            c: zeros.clone(),
+        })
+    }
+
+    fn step(&self, input: &Tensor, in_state: &Self::State) -> Result<Self::State> {
+        let w_ih = input.matmul(&self.w_ih.t()?)?;
+        let w_hh = in_state.h.matmul(&self.w_hh.t()?)?;
+        let w_ih = match &self.b_ih {
+            None => w_ih,
+            Some(b_ih) => w_ih.broadcast_add(b_ih)?,
+        };
+        let w_hh = match &self.b_hh {
+            None => w_hh,
+            Some(b_hh) => w_hh.broadcast_add(b_hh)?,
+        };
+        let chunks = (&w_ih + &w_hh)?.chunk(4, 1)?;
+        let in_gate = crate::ops::sigmoid(&chunks[0])?;
+        let forget_gate = crate::ops::sigmoid(&chunks[1])?;
+        // TODO: This should be a tanh
+        let cell_gate = crate::ops::sigmoid(&chunks[2])?;
+        let out_gate = crate::ops::sigmoid(&chunks[3])?;
+
+        let next_c = ((forget_gate * &in_state.c)? + (in_gate * cell_gate)?)?;
+        // TODO: This should be another tanh
+        let next_h = (out_gate * crate::ops::sigmoid(&next_c)?)?;
+        Ok(LSTMState {
+            c: next_c,
+            h: next_h,
+        })
+    }
+
+    /// The input should have dimensions [batch_size, seq_len, features].
+    fn seq_init(&self, input: &Tensor, in_state: &Self::State) -> Result<(Tensor, Self::State)> {
+        let (_b_size, seq_len, _features) = input.dims3()?;
+        let mut state = in_state.clone();
+        let mut output: Vec<Tensor> = Vec::with_capacity(seq_len);
+        for seq_index in 0..seq_len {
+            let input = input.i((.., seq_index, ..))?;
+            state = self.step(&input, &state)?;
+            output.push(state.h.clone());
+        }
+        let output = Tensor::cat(&output, 1)?;
+        Ok((output, state))
+    }
+}
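
For reference, writing sigma for the logistic sigmoid, the gate computation in `step` above corresponds to the following sketch of the math as committed (a textbook LSTM would apply tanh to the cell gate and to the cell state in the last line, which is what the two TODO comments refer to):

\begin{aligned}
[i_t,\, f_t,\, g_t,\, o_t] &= \mathrm{chunk}\big(x_t W_{ih}^\top + b_{ih} + h_{t-1} W_{hh}^\top + b_{hh},\ 4\big) \\
c_t &= \sigma(f_t) \odot c_{t-1} + \sigma(i_t) \odot \sigma(g_t) \\
h_t &= \sigma(o_t) \odot \sigma(c_t)
\end{aligned}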
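
As a usage illustration, the sketch below builds a layer through the new `lstm` constructor and drives it with the `RNN` trait methods. It is not part of the commit: it assumes the candle / candle-nn crates at roughly this revision, a `VarMap`-backed `VarBuilder` for the weights, and arbitrary example sizes. Since `weight_hh_l0` is created with shape `(4 * hidden_dim, in_dim)` in this version, the sketch keeps `in_dim` equal to `hidden_dim` so the shapes in `step` line up.

// A minimal usage sketch, not part of the commit. Sizes (batch 2, seq_len 5,
// dim 4) are arbitrary illustration values.
use candle::{DType, Device, Result, Tensor};
use candle_nn::rnn::LSTMConfig;
use candle_nn::{lstm, VarBuilder, VarMap, RNN};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

    // Single-layer LSTM; in_dim is kept equal to hidden_dim (4) so that the
    // (4 * hidden_dim, in_dim) shape used for weight_hh_l0 matches h.
    let lstm_layer = lstm(4, 4, LSTMConfig::default(), vb)?;

    // One explicit step: input is [batch_size, features], state starts at zero.
    let state = lstm_layer.zero_state(2)?;
    let x = Tensor::zeros((2, 4), DType::F32, &device)?;
    let state = lstm_layer.step(&x, &state)?;
    println!("h after one step: {:?}", state.h().shape());

    // A whole sequence: input is [batch_size, seq_len, features]; `seq` calls
    // zero_state and then seq_init, returning the per-step hidden states
    // (concatenated along dim 1 in this version) and the final state.
    let xs = Tensor::zeros((2, 5, 4), DType::F32, &device)?;
    let (output, final_state) = lstm_layer.seq(&xs)?;
    println!("seq output: {:?}", output.shape());
    println!("final c: {:?}", final_state.c().shape());
    Ok(())
}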