summaryrefslogtreecommitdiff
path: root/candle-nn/src/lib.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-31 09:43:10 +0200
committerGitHub <noreply@github.com>2023-08-31 08:43:10 +0100
commitdb598160876a841fcf372ae923e4c5328a47f653 (patch)
tree8d927dad45d515859d8e5b8a5f4c5822fafcc636 /candle-nn/src/lib.rs
parentd210c71d77a6044c2a42c2e75487b6180e957158 (diff)
downloadcandle-db598160876a841fcf372ae923e4c5328a47f653.tar.gz
candle-db598160876a841fcf372ae923e4c5328a47f653.tar.bz2
candle-db598160876a841fcf372ae923e4c5328a47f653.zip
Add a GRU layer. (#688)
* Add a GRU layer. * Fix the n gate computation.
Diffstat (limited to 'candle-nn/src/lib.rs')
-rw-r--r--candle-nn/src/lib.rs2
1 files changed, 1 insertions, 1 deletions
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index e9552e83..48046081 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -25,7 +25,7 @@ pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
pub use linear::{linear, linear_no_bias, Linear};
pub use ops::Dropout;
pub use optim::{AdamW, ParamsAdamW, SGD};
-pub use rnn::{lstm, LSTM, RNN};
+pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN};
pub use var_builder::VarBuilder;
pub use var_map::VarMap;