summaryrefslogtreecommitdiff
path: root/candle-nn
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-02-27 22:59:40 +0100
committerGitHub <noreply@github.com>2024-02-27 22:59:40 +0100
commit0c49e95dfb5f25a17340ff1b690a4b4f6cd0e2d1 (patch)
treec732811778ea6e15c558dcbe35153cd110eb5959 /candle-nn
parent205767f9ded3d531822d3702442a52b4a320f72e (diff)
downloadcandle-0c49e95dfb5f25a17340ff1b690a4b4f6cd0e2d1.tar.gz
candle-0c49e95dfb5f25a17340ff1b690a4b4f6cd0e2d1.tar.bz2
candle-0c49e95dfb5f25a17340ff1b690a4b4f6cd0e2d1.zip
Encodec model. (#1771)
* Encodec model. * Fixes. * Add the padding functions. * Get the LSTM bit to work. * Get the encodec model to generate some tokens (decoder only for now). * Minor tweak. * Minor tweak.
Diffstat (limited to 'candle-nn')
-rw-r--r--candle-nn/src/rnn.rs2
1 files changed, 1 insertions, 1 deletions
diff --git a/candle-nn/src/rnn.rs b/candle-nn/src/rnn.rs
index 9f144cca..07795eda 100644
--- a/candle-nn/src/rnn.rs
+++ b/candle-nn/src/rnn.rs
@@ -197,7 +197,7 @@ impl RNN for LSTM {
fn states_to_tensor(&self, states: &[Self::State]) -> Result<Tensor> {
let states = states.iter().map(|s| s.h.clone()).collect::<Vec<_>>();
- Tensor::cat(&states, 1)
+ Tensor::stack(&states, 1)
}
}