diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-31 09:43:10 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-31 08:43:10 +0100 |
commit | db598160876a841fcf372ae923e4c5328a47f653 (patch) | |
tree | 8d927dad45d515859d8e5b8a5f4c5822fafcc636 /candle-nn/src/lib.rs | |
parent | d210c71d77a6044c2a42c2e75487b6180e957158 (diff) | |
download | candle-db598160876a841fcf372ae923e4c5328a47f653.tar.gz candle-db598160876a841fcf372ae923e4c5328a47f653.tar.bz2 candle-db598160876a841fcf372ae923e4c5328a47f653.zip |
Add a GRU layer. (#688)
* Add a GRU layer.
* Fix the n gate computation.
Diffstat (limited to 'candle-nn/src/lib.rs')
-rw-r--r-- | candle-nn/src/lib.rs | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs index e9552e83..48046081 100644 --- a/candle-nn/src/lib.rs +++ b/candle-nn/src/lib.rs @@ -25,7 +25,7 @@ pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm}; pub use linear::{linear, linear_no_bias, Linear}; pub use ops::Dropout; pub use optim::{AdamW, ParamsAdamW, SGD}; -pub use rnn::{lstm, LSTM, RNN}; +pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN}; pub use var_builder::VarBuilder; pub use var_map::VarMap; |