summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-nn/src/rnn.rs52
1 files changed, 25 insertions, 27 deletions
diff --git a/candle-nn/src/rnn.rs b/candle-nn/src/rnn.rs
index 10ba48f3..9f144cca 100644
--- a/candle-nn/src/rnn.rs
+++ b/candle-nn/src/rnn.rs
@@ -4,7 +4,7 @@ use candle::{DType, Device, IndexOp, Result, Tensor};
/// Trait for Recurrent Neural Networks.
#[allow(clippy::upper_case_acronyms)]
pub trait RNN {
- type State;
+ type State: Clone;
/// A zero state from which the recurrent network is usually initialized.
fn zero_state(&self, batch_dim: usize) -> Result<Self::State>;
@@ -18,7 +18,7 @@ pub trait RNN {
///
/// The input should have dimensions [batch_size, seq_len, features].
/// The initial state is the result of applying zero_state.
- fn seq(&self, input: &Tensor) -> Result<(Tensor, Self::State)> {
+ fn seq(&self, input: &Tensor) -> Result<Vec<Self::State>> {
let batch_dim = input.dim(0)?;
let state = self.zero_state(batch_dim)?;
self.seq_init(input, &state)
@@ -27,7 +27,23 @@ pub trait RNN {
/// Applies multiple steps of the recurrent network.
///
/// The input should have dimensions [batch_size, seq_len, features].
- fn seq_init(&self, input: &Tensor, state: &Self::State) -> Result<(Tensor, Self::State)>;
+ fn seq_init(&self, input: &Tensor, init_state: &Self::State) -> Result<Vec<Self::State>> {
+ let (_b_size, seq_len, _features) = input.dims3()?;
+ let mut output = Vec::with_capacity(seq_len);
+ for seq_index in 0..seq_len {
+ let input = input.i((.., seq_index, ..))?;
+ let state = if seq_index == 0 {
+ self.step(&input, init_state)?
+ } else {
+ self.step(&input, &output[seq_index - 1])?
+ };
+ output.push(state);
+ }
+ Ok(output)
+ }
+
+ /// Converts a sequence of state to a tensor.
+ fn states_to_tensor(&self, states: &[Self::State]) -> Result<Tensor>;
}
/// The state for a LSTM network, this contains two tensors.
@@ -179,18 +195,9 @@ impl RNN for LSTM {
})
}
- /// The input should have dimensions [batch_size, seq_len, features].
- fn seq_init(&self, input: &Tensor, in_state: &Self::State) -> Result<(Tensor, Self::State)> {
- let (_b_size, seq_len, _features) = input.dims3()?;
- let mut state = in_state.clone();
- let mut output: Vec<Tensor> = Vec::with_capacity(seq_len);
- for seq_index in 0..seq_len {
- let input = input.i((.., seq_index, ..))?;
- state = self.step(&input, &state)?;
- output.push(state.h.clone());
- }
- let output = Tensor::cat(&output, 1)?;
- Ok((output, state))
+ fn states_to_tensor(&self, states: &[Self::State]) -> Result<Tensor> {
+ let states = states.iter().map(|s| s.h.clone()).collect::<Vec<_>>();
+ Tensor::cat(&states, 1)
}
}
@@ -322,17 +329,8 @@ impl RNN for GRU {
Ok(GRUState { h: next_h })
}
- /// The input should have dimensions [batch_size, seq_len, features].
- fn seq_init(&self, input: &Tensor, in_state: &Self::State) -> Result<(Tensor, Self::State)> {
- let (_b_size, seq_len, _features) = input.dims3()?;
- let mut state = in_state.clone();
- let mut output: Vec<Tensor> = Vec::with_capacity(seq_len);
- for seq_index in 0..seq_len {
- let input = input.i((.., seq_index, ..))?;
- state = self.step(&input, &state)?;
- output.push(state.h.clone());
- }
- let output = Tensor::cat(&output, 1)?;
- Ok((output, state))
+ fn states_to_tensor(&self, states: &[Self::State]) -> Result<Tensor> {
+ let states = states.iter().map(|s| s.h.clone()).collect::<Vec<_>>();
+ Tensor::cat(&states, 1)
}
}