diff options
author | shua <gpg@isthisa.email> | 2024-08-19 09:06:17 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-08-19 09:06:17 +0200 |
commit | 31a1075f4b4799a4922fa0f617ee982baa5baa81 (patch) | |
tree | bf292f3c3162fd0afdc41a01c68491358a69280a /candle-nn | |
parent | 236b29ff1555db82fdb78c1be8741c0ac37859d1 (diff) | |
download | candle-31a1075f4b4799a4922fa0f617ee982baa5baa81.tar.gz candle-31a1075f4b4799a4922fa0f617ee982baa5baa81.tar.bz2 candle-31a1075f4b4799a4922fa0f617ee982baa5baa81.zip |
onnx: implement LSTM op (#2268)
use candle-nn LSTM
Diffstat (limited to 'candle-nn')
-rw-r--r-- | candle-nn/src/rnn.rs | 4 |
1 files changed, 4 insertions, 0 deletions
diff --git a/candle-nn/src/rnn.rs b/candle-nn/src/rnn.rs index 30ad6ff5..5934af85 100644 --- a/candle-nn/src/rnn.rs +++ b/candle-nn/src/rnn.rs @@ -55,6 +55,10 @@ pub struct LSTMState { } impl LSTMState { + pub fn new(h: Tensor, c: Tensor) -> Self { + LSTMState { h, c } + } + /// The hidden state vector, which is also the output of the LSTM. pub fn h(&self) -> &Tensor { &self.h |