diff options
author | Justin Sing <32938975+singjc@users.noreply.github.com> | 2024-09-30 16:44:07 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-09-30 22:44:07 +0200 |
commit | aa35bf2ff5edd9c3534fd7744b333a1abaed4406 (patch) | |
tree | 7a02224a3ff22a3acd8d0dd752285a0786e22c20 /candle-nn | |
parent | 724650446cc729c77f7c2f8c0162c526ebcc90c3 (diff) | |
download | candle-aa35bf2ff5edd9c3534fd7744b333a1abaed4406.tar.gz candle-aa35bf2ff5edd9c3534fd7744b333a1abaed4406.tar.bz2 candle-aa35bf2ff5edd9c3534fd7744b333a1abaed4406.zip |
Add/lstm direction (#2455)
* add: direction for lstm layer
* lint: remove unused Error import
* refactor: remove unnecessary int assignment to Direction enum:
* refactor: use &'static str type instead of String for direction_str:
* Run cargofmt.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-nn')
-rw-r--r-- | candle-nn/src/rnn.rs | 33 |
1 files changed, 25 insertions, 8 deletions
diff --git a/candle-nn/src/rnn.rs b/candle-nn/src/rnn.rs index 5934af85..b4b443c6 100644 --- a/candle-nn/src/rnn.rs +++ b/candle-nn/src/rnn.rs @@ -70,6 +70,12 @@ impl LSTMState { } } +#[derive(Debug, Clone, Copy)] +pub enum Direction { + Forward, + Backward, +} + #[allow(clippy::upper_case_acronyms)] #[derive(Debug, Clone, Copy)] pub struct LSTMConfig { @@ -78,6 +84,7 @@ pub struct LSTMConfig { pub b_ih_init: Option<super::Init>, pub b_hh_init: Option<super::Init>, pub layer_idx: usize, + pub direction: Direction, } impl Default for LSTMConfig { @@ -88,6 +95,7 @@ impl Default for LSTMConfig { b_ih_init: Some(super::Init::Const(0.)), b_hh_init: Some(super::Init::Const(0.)), layer_idx: 0, + direction: Direction::Forward, } } } @@ -100,6 +108,7 @@ impl LSTMConfig { b_ih_init: None, b_hh_init: None, layer_idx: 0, + direction: Direction::Forward, } } } @@ -128,26 +137,34 @@ pub fn lstm( vb: crate::VarBuilder, ) -> Result<LSTM> { let layer_idx = config.layer_idx; + let direction_str = match config.direction { + Direction::Forward => "", + Direction::Backward => "_reverse", + }; let w_ih = vb.get_with_hints( (4 * hidden_dim, in_dim), - &format!("weight_ih_l{layer_idx}"), // Only a single layer is supported. + &format!("weight_ih_l{layer_idx}{direction_str}"), // Only a single layer is supported. config.w_ih_init, )?; let w_hh = vb.get_with_hints( (4 * hidden_dim, hidden_dim), - &format!("weight_hh_l{layer_idx}"), // Only a single layer is supported. + &format!("weight_hh_l{layer_idx}{direction_str}"), // Only a single layer is supported. config.w_hh_init, )?; let b_ih = match config.b_ih_init { - Some(init) => { - Some(vb.get_with_hints(4 * hidden_dim, &format!("bias_ih_l{layer_idx}"), init)?) - } + Some(init) => Some(vb.get_with_hints( + 4 * hidden_dim, + &format!("bias_ih_l{layer_idx}{direction_str}"), + init, + )?), None => None, }; let b_hh = match config.b_hh_init { - Some(init) => { - Some(vb.get_with_hints(4 * hidden_dim, &format!("bias_hh_l{layer_idx}"), init)?) - } + Some(init) => Some(vb.get_with_hints( + 4 * hidden_dim, + &format!("bias_hh_l{layer_idx}{direction_str}"), + init, + )?), None => None, }; Ok(LSTM { |