diff options
-rw-r--r-- | candle-nn/src/rnn.rs | 52 |
1 files changed, 25 insertions, 27 deletions
diff --git a/candle-nn/src/rnn.rs b/candle-nn/src/rnn.rs index 10ba48f3..9f144cca 100644 --- a/candle-nn/src/rnn.rs +++ b/candle-nn/src/rnn.rs @@ -4,7 +4,7 @@ use candle::{DType, Device, IndexOp, Result, Tensor}; /// Trait for Recurrent Neural Networks. #[allow(clippy::upper_case_acronyms)] pub trait RNN { - type State; + type State: Clone; /// A zero state from which the recurrent network is usually initialized. fn zero_state(&self, batch_dim: usize) -> Result<Self::State>; @@ -18,7 +18,7 @@ pub trait RNN { /// /// The input should have dimensions [batch_size, seq_len, features]. /// The initial state is the result of applying zero_state. - fn seq(&self, input: &Tensor) -> Result<(Tensor, Self::State)> { + fn seq(&self, input: &Tensor) -> Result<Vec<Self::State>> { let batch_dim = input.dim(0)?; let state = self.zero_state(batch_dim)?; self.seq_init(input, &state) @@ -27,7 +27,23 @@ pub trait RNN { /// Applies multiple steps of the recurrent network. /// /// The input should have dimensions [batch_size, seq_len, features]. - fn seq_init(&self, input: &Tensor, state: &Self::State) -> Result<(Tensor, Self::State)>; + fn seq_init(&self, input: &Tensor, init_state: &Self::State) -> Result<Vec<Self::State>> { + let (_b_size, seq_len, _features) = input.dims3()?; + let mut output = Vec::with_capacity(seq_len); + for seq_index in 0..seq_len { + let input = input.i((.., seq_index, ..))?; + let state = if seq_index == 0 { + self.step(&input, init_state)? + } else { + self.step(&input, &output[seq_index - 1])? + }; + output.push(state); + } + Ok(output) + } + + /// Converts a sequence of state to a tensor. + fn states_to_tensor(&self, states: &[Self::State]) -> Result<Tensor>; } /// The state for a LSTM network, this contains two tensors. @@ -179,18 +195,9 @@ impl RNN for LSTM { }) } - /// The input should have dimensions [batch_size, seq_len, features]. - fn seq_init(&self, input: &Tensor, in_state: &Self::State) -> Result<(Tensor, Self::State)> { - let (_b_size, seq_len, _features) = input.dims3()?; - let mut state = in_state.clone(); - let mut output: Vec<Tensor> = Vec::with_capacity(seq_len); - for seq_index in 0..seq_len { - let input = input.i((.., seq_index, ..))?; - state = self.step(&input, &state)?; - output.push(state.h.clone()); - } - let output = Tensor::cat(&output, 1)?; - Ok((output, state)) + fn states_to_tensor(&self, states: &[Self::State]) -> Result<Tensor> { + let states = states.iter().map(|s| s.h.clone()).collect::<Vec<_>>(); + Tensor::cat(&states, 1) } } @@ -322,17 +329,8 @@ impl RNN for GRU { Ok(GRUState { h: next_h }) } - /// The input should have dimensions [batch_size, seq_len, features]. - fn seq_init(&self, input: &Tensor, in_state: &Self::State) -> Result<(Tensor, Self::State)> { - let (_b_size, seq_len, _features) = input.dims3()?; - let mut state = in_state.clone(); - let mut output: Vec<Tensor> = Vec::with_capacity(seq_len); - for seq_index in 0..seq_len { - let input = input.i((.., seq_index, ..))?; - state = self.step(&input, &state)?; - output.push(state.h.clone()); - } - let output = Tensor::cat(&output, 1)?; - Ok((output, state)) + fn states_to_tensor(&self, states: &[Self::State]) -> Result<Tensor> { + let states = states.iter().map(|s| s.h.clone()).collect::<Vec<_>>(); + Tensor::cat(&states, 1) } } |