summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJustin Sing <32938975+singjc@users.noreply.github.com>2024-08-01 10:30:32 -0400
committerGitHub <noreply@github.com>2024-08-01 16:30:32 +0200
commit6991a37b94fdcfb6c1d69b7ac4b6d6b96654111d (patch)
tree5109004b6497562fc3b46693752e41fa11fe770e
parent9ca277a9d71d1919228a4b994750e3d811da6b0a (diff)
downloadcandle-6991a37b94fdcfb6c1d69b7ac4b6d6b96654111d.tar.gz
candle-6991a37b94fdcfb6c1d69b7ac4b6d6b96654111d.tar.bz2
candle-6991a37b94fdcfb6c1d69b7ac4b6d6b96654111d.zip
update: LSTMState and GRUState fields to be public (#2384)
-rw-r--r--candle-nn/src/rnn.rs6
1 files changed, 3 insertions, 3 deletions
diff --git a/candle-nn/src/rnn.rs b/candle-nn/src/rnn.rs
index dbfa639b..30ad6ff5 100644
--- a/candle-nn/src/rnn.rs
+++ b/candle-nn/src/rnn.rs
@@ -50,8 +50,8 @@ pub trait RNN {
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone)]
pub struct LSTMState {
- h: Tensor,
- c: Tensor,
+ pub h: Tensor,
+ pub c: Tensor,
}
impl LSTMState {
@@ -205,7 +205,7 @@ impl RNN for LSTM {
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone)]
pub struct GRUState {
- h: Tensor,
+ pub h: Tensor,
}
impl GRUState {