diff options
Diffstat (limited to 'candle-nn')
-rw-r--r-- | candle-nn/src/rnn.rs | 33 |
1 files changed, 25 insertions, 8 deletions
diff --git a/candle-nn/src/rnn.rs b/candle-nn/src/rnn.rs index 5934af85..b4b443c6 100644 --- a/candle-nn/src/rnn.rs +++ b/candle-nn/src/rnn.rs @@ -70,6 +70,12 @@ impl LSTMState { } } +#[derive(Debug, Clone, Copy)] +pub enum Direction { + Forward, + Backward, +} + #[allow(clippy::upper_case_acronyms)] #[derive(Debug, Clone, Copy)] pub struct LSTMConfig { @@ -78,6 +84,7 @@ pub struct LSTMConfig { pub b_ih_init: Option<super::Init>, pub b_hh_init: Option<super::Init>, pub layer_idx: usize, + pub direction: Direction, } impl Default for LSTMConfig { @@ -88,6 +95,7 @@ impl Default for LSTMConfig { b_ih_init: Some(super::Init::Const(0.)), b_hh_init: Some(super::Init::Const(0.)), layer_idx: 0, + direction: Direction::Forward, } } } @@ -100,6 +108,7 @@ impl LSTMConfig { b_ih_init: None, b_hh_init: None, layer_idx: 0, + direction: Direction::Forward, } } } @@ -128,26 +137,34 @@ pub fn lstm( vb: crate::VarBuilder, ) -> Result<LSTM> { let layer_idx = config.layer_idx; + let direction_str = match config.direction { + Direction::Forward => "", + Direction::Backward => "_reverse", + }; let w_ih = vb.get_with_hints( (4 * hidden_dim, in_dim), - &format!("weight_ih_l{layer_idx}"), // Only a single layer is supported. + &format!("weight_ih_l{layer_idx}{direction_str}"), // Only a single layer is supported. config.w_ih_init, )?; let w_hh = vb.get_with_hints( (4 * hidden_dim, hidden_dim), - &format!("weight_hh_l{layer_idx}"), // Only a single layer is supported. + &format!("weight_hh_l{layer_idx}{direction_str}"), // Only a single layer is supported. config.w_hh_init, )?; let b_ih = match config.b_ih_init { - Some(init) => { - Some(vb.get_with_hints(4 * hidden_dim, &format!("bias_ih_l{layer_idx}"), init)?) - } + Some(init) => Some(vb.get_with_hints( + 4 * hidden_dim, + &format!("bias_ih_l{layer_idx}{direction_str}"), + init, + )?), None => None, }; let b_hh = match config.b_hh_init { - Some(init) => { - Some(vb.get_with_hints(4 * hidden_dim, &format!("bias_hh_l{layer_idx}"), init)?) - } + Some(init) => Some(vb.get_with_hints( + 4 * hidden_dim, + &format!("bias_hh_l{layer_idx}{direction_str}"), + init, + )?), None => None, }; Ok(LSTM { |