summaryrefslogtreecommitdiff
path: root/candle-nn
diff options
context:
space:
mode:
authorJustin Sing <32938975+singjc@users.noreply.github.com>2024-09-30 16:44:07 -0400
committerGitHub <noreply@github.com>2024-09-30 22:44:07 +0200
commitaa35bf2ff5edd9c3534fd7744b333a1abaed4406 (patch)
tree7a02224a3ff22a3acd8d0dd752285a0786e22c20 /candle-nn
parent724650446cc729c77f7c2f8c0162c526ebcc90c3 (diff)
downloadcandle-aa35bf2ff5edd9c3534fd7744b333a1abaed4406.tar.gz
candle-aa35bf2ff5edd9c3534fd7744b333a1abaed4406.tar.bz2
candle-aa35bf2ff5edd9c3534fd7744b333a1abaed4406.zip
Add/lstm direction (#2455)
* add: direction for lstm layer * lint: remove unused Error import * refactor: remove unnecessary int assignment to Direction enum: * refactor: use &'static str type instead of String for direction_str: * Run cargofmt. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-nn')
-rw-r--r--candle-nn/src/rnn.rs33
1 files changed, 25 insertions, 8 deletions
diff --git a/candle-nn/src/rnn.rs b/candle-nn/src/rnn.rs
index 5934af85..b4b443c6 100644
--- a/candle-nn/src/rnn.rs
+++ b/candle-nn/src/rnn.rs
@@ -70,6 +70,12 @@ impl LSTMState {
}
}
+#[derive(Debug, Clone, Copy)]
+pub enum Direction {
+ Forward,
+ Backward,
+}
+
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy)]
pub struct LSTMConfig {
@@ -78,6 +84,7 @@ pub struct LSTMConfig {
pub b_ih_init: Option<super::Init>,
pub b_hh_init: Option<super::Init>,
pub layer_idx: usize,
+ pub direction: Direction,
}
impl Default for LSTMConfig {
@@ -88,6 +95,7 @@ impl Default for LSTMConfig {
b_ih_init: Some(super::Init::Const(0.)),
b_hh_init: Some(super::Init::Const(0.)),
layer_idx: 0,
+ direction: Direction::Forward,
}
}
}
@@ -100,6 +108,7 @@ impl LSTMConfig {
b_ih_init: None,
b_hh_init: None,
layer_idx: 0,
+ direction: Direction::Forward,
}
}
}
@@ -128,26 +137,34 @@ pub fn lstm(
vb: crate::VarBuilder,
) -> Result<LSTM> {
let layer_idx = config.layer_idx;
+ let direction_str = match config.direction {
+ Direction::Forward => "",
+ Direction::Backward => "_reverse",
+ };
let w_ih = vb.get_with_hints(
(4 * hidden_dim, in_dim),
- &format!("weight_ih_l{layer_idx}"), // Only a single layer is supported.
+ &format!("weight_ih_l{layer_idx}{direction_str}"), // Only a single layer is supported.
config.w_ih_init,
)?;
let w_hh = vb.get_with_hints(
(4 * hidden_dim, hidden_dim),
- &format!("weight_hh_l{layer_idx}"), // Only a single layer is supported.
+ &format!("weight_hh_l{layer_idx}{direction_str}"), // Only a single layer is supported.
config.w_hh_init,
)?;
let b_ih = match config.b_ih_init {
- Some(init) => {
- Some(vb.get_with_hints(4 * hidden_dim, &format!("bias_ih_l{layer_idx}"), init)?)
- }
+ Some(init) => Some(vb.get_with_hints(
+ 4 * hidden_dim,
+ &format!("bias_ih_l{layer_idx}{direction_str}"),
+ init,
+ )?),
None => None,
};
let b_hh = match config.b_hh_init {
- Some(init) => {
- Some(vb.get_with_hints(4 * hidden_dim, &format!("bias_hh_l{layer_idx}"), init)?)
- }
+ Some(init) => Some(vb.get_with_hints(
+ 4 * hidden_dim,
+ &format!("bias_hh_l{layer_idx}{direction_str}"),
+ init,
+ )?),
None => None,
};
Ok(LSTM {