path: root/candle-nn/src/rnn.rs
* Make the RNN configs accessible from the models. (#2541)
  Laurent Mazare · 2024-10-04 · 1 file changed, -72/+103
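
#2541 exposes each layer's configuration from the constructed model. A minimal sketch of reading it back, assuming the current candle-nn constructors and a `config()` accessor (the accessor name is an assumption; the log only says the configs were made accessible):

    use candle_core::{DType, Device, Result};
    use candle_nn::{lstm, LSTMConfig, VarBuilder, VarMap};

    fn main() -> Result<()> {
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
        let layer = lstm(16, 32, LSTMConfig::default(), vb)?;
        // Assumed accessor added by this change: read the config back
        // from the built layer instead of keeping a separate copy.
        let cfg = layer.config();
        println!("layer_idx = {}", cfg.layer_idx);
        Ok(())
    }
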
* Add/lstm direction (#2455)
  Justin Sing · 2024-09-30 · 1 file changed, -8/+25
    * add: direction for lstm layer
    * lint: remove unused Error import
    * refactor: remove unnecessary int assignment to the Direction enum
    * refactor: use the &'static str type instead of String for direction_str
    * Run cargo fmt.
    Co-authored-by: Laurent <laurent.mazare@gmail.com>
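
The direction PR lets an LSTM consume its input back to front. A sketch of how the config might be set up; the `direction` field name and the `candle_nn::rnn::Direction` path are assumptions based on the commit message above:

    use candle_core::{DType, Device, Result};
    use candle_nn::rnn::Direction;
    use candle_nn::{lstm, LSTMConfig, VarBuilder, VarMap};

    fn main() -> Result<()> {
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
        // Assumed field: `direction`, using the Direction enum from #2455.
        let config = LSTMConfig {
            direction: Direction::Backward,
            ..Default::default()
        };
        let _reverse_lstm = lstm(16, 32, config, vb)?;
        Ok(())
    }
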
* onnx: implement LSTM op (#2268)
  shua · 2024-08-19 · 1 file changed, -0/+4
    use candle-nn LSTM

* update: LSTMState and GRUState fields to be public (#2384)
  Justin Sing · 2024-08-01 · 1 file changed, -3/+3
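
With the state fields public, callers can move the hidden and cell tensors out of an `LSTMState` directly instead of going through accessors. A sketch, assuming the fields are named `h` and `c`:

    use candle_core::{DType, Device, Result, Tensor};
    use candle_nn::{lstm, LSTMConfig, VarBuilder, VarMap, RNN};

    fn main() -> Result<()> {
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
        let layer = lstm(4, 8, LSTMConfig::default(), vb)?;
        let state = layer.zero_state(1)?; // batch size 1
        let x = Tensor::zeros((1, 4), DType::F32, &Device::Cpu)?;
        let state = layer.step(&x, &state)?;
        // Direct field access, possible once the fields are public.
        let (h, c) = (state.h, state.c);
        println!("h: {:?}  c: {:?}", h.shape(), c.shape());
        Ok(())
    }
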
* Relax the contiguous check for cuda kernels. (#2000)
  Laurent Mazare · 2024-04-03 · 1 file changed, -1/+1
    * Relax the contiguous check for cuda kernels.
    * Ensure contiguity for RNNs.
    * Unrelated fix for segment anything.
    * Better error message + allow concatenating empty slices.
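
The RNN side of this fix makes sure tensors are laid out contiguously before they reach kernels that require it. A small illustration of the general pattern with candle tensors (not the PR's exact code):

    use candle_core::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let t = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
        let tt = t.t()?; // a transpose is a non-contiguous view
        assert!(!tt.is_contiguous());
        let tt = tt.contiguous()?; // copy into standard layout
        assert!(tt.is_contiguous());
        Ok(())
    }
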
* Encodec model. (#1771)
  Laurent Mazare · 2024-02-27 · 1 file changed, -1/+1
    * Encodec model.
    * Fixes.
    * Add the padding functions.
    * Get the LSTM bit to work.
    * Get the encodec model to generate some tokens (decoder only for now).
    * Minor tweak.
    * Minor tweak.

* More general seq forward functions for RNNs. (#1050)
  Laurent Mazare · 2023-10-07 · 1 file changed, -27/+25
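
The `RNN` trait's sequence helpers run the step function over a whole (batch, seq_len, features) input. A sketch of the forward path, assuming the trait's `seq` and `states_to_tensor` methods:

    use candle_core::{DType, Device, Result, Tensor};
    use candle_nn::{lstm, LSTMConfig, VarBuilder, VarMap, RNN};

    fn main() -> Result<()> {
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
        let layer = lstm(4, 8, LSTMConfig::default(), vb)?;
        // (batch, seq_len, features); seq starts from a zero state and
        // returns one state per timestep.
        let xs = Tensor::zeros((1, 5, 4), DType::F32, &Device::Cpu)?;
        let states = layer.seq(&xs)?;
        let hs = layer.states_to_tensor(&states)?; // stacked hidden states
        println!("{:?}", hs.shape());
        Ok(())
    }
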
* Configurable layer idx for the lstm layer. (#962)
  Laurent Mazare · 2023-09-25 · 1 file changed, -4/+12
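
`layer_idx` selects which weight names the layer loads, which is what makes stacked LSTMs loadable from a single VarBuilder. A sketch, assuming the PyTorch-style naming scheme (`weight_ih_l0`, `weight_ih_l1`, ...) is used for the lookups:

    use candle_core::{DType, Device, Result};
    use candle_nn::{lstm, LSTMConfig, VarBuilder, VarMap};

    fn main() -> Result<()> {
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
        // Second layer of a stack: looks up e.g. "weight_ih_l1"
        // rather than "weight_ih_l0".
        let config = LSTMConfig {
            layer_idx: 1,
            ..Default::default()
        };
        let _second_layer = lstm(32, 32, config, vb)?;
        Ok(())
    }
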
* Add clone to various nn layers. (#910)
  Laurent Mazare · 2023-09-20 · 1 file changed, -2/+2

* Fix the rnn tests for accelerate. (#704)
  Laurent Mazare · 2023-09-01 · 1 file changed, -2/+4

* Add a GRU layer. (#688)
  Laurent Mazare · 2023-08-31 · 1 file changed, -0/+142
    * Add a GRU layer.
    * Fix the n gate computation.
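
The GRU layer mirrors the LSTM API with a single hidden tensor per state. A sketch of one step, assuming the `gru` constructor and `GRUConfig` are exported alongside the LSTM items:

    use candle_core::{DType, Device, Result, Tensor};
    use candle_nn::{gru, GRUConfig, VarBuilder, VarMap, RNN};

    fn main() -> Result<()> {
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
        let layer = gru(4, 8, GRUConfig::default(), vb)?;
        let state = layer.zero_state(1)?;
        let x = Tensor::zeros((1, 4), DType::F32, &Device::Cpu)?;
        let state = layer.step(&x, &state)?;
        println!("h: {:?}", state.h().shape());
        Ok(())
    }
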
* Add a LSTM test. (#681)
  Laurent Mazare · 2023-08-30 · 1 file changed, -1/+1
    * Add a LSTM test.
    * Clippy.

* Add tanh. (#675)
  Laurent Mazare · 2023-08-30 · 1 file changed, -4/+2
    * Add tanh.
    * Use tanh in the lstm block.
    * Add a test for tanh forward and backward passes.
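
For context on where tanh enters the layer: the standard LSTM cell applies it to the cell candidate and to the cell state feeding the output gate. An illustrative sketch of that tail end of the update, not the crate's internal code:

    use candle_core::{Result, Tensor};

    // c_t = f * c_{t-1} + i * tanh(g)
    // h_t = o * tanh(c_t)
    // f, i, o are the (already sigmoided) gates, g the cell candidate.
    fn lstm_cell_tail(
        f: &Tensor,
        i: &Tensor,
        o: &Tensor,
        g: &Tensor,
        c_prev: &Tensor,
    ) -> Result<(Tensor, Tensor)> {
        let c = f.mul(c_prev)?.add(&i.mul(&g.tanh()?)?)?;
        let h = o.mul(&c.tanh()?)?;
        Ok((h, c))
    }
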
* Add some recurrent neural networks (#674)
  Laurent Mazare · 2023-08-30 · 1 file changed, -0/+188
    * Add the rnn module.
    * More LSTM.
    * Implement the RNN forward pass.
    * More forward pass for LSTM.
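
This initial commit is where the `RNN` trait and the LSTM landed. Basic usage today looks roughly like the following, stepping an LSTM one input at a time from a zero state (a sketch against the current candle-nn API):

    use candle_core::{DType, Device, Result, Tensor};
    use candle_nn::{lstm, LSTMConfig, VarBuilder, VarMap, RNN};

    fn main() -> Result<()> {
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
        let layer = lstm(2, 3, LSTMConfig::default(), vb)?;
        let mut state = layer.zero_state(1)?;
        for _t in 0..4 {
            let x = Tensor::zeros((1, 2), DType::F32, &Device::Cpu)?;
            state = layer.step(&x, &state)?;
        }
        println!("final h: {:?}", state.h().shape());
        Ok(())
    }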