path: root/candle-nn/src
Commit log, most recent first (subject, author, date; files changed, lines +added/-deleted):
...
* Add the AdamW optimizer. (#307) (Laurent Mazare, 2023-08-02; 2 files, +115/-3)
  - Add some AdamW test validated against PyTorch.
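AdamW (Loshchilov & Hutter) differs from classic Adam + L2 regularization in that the weight decay is applied directly to the parameter rather than folded into the gradient. A minimal single-parameter sketch of one step, in plain Rust with illustrative hyper-parameter names; this is not candle-nn's actual optimizer API:

```rust
/// Running state for one AdamW-managed parameter.
struct AdamWState {
    m: f64, // first-moment (mean) estimate
    v: f64, // second-moment (uncentered variance) estimate
    t: u64, // step count, used for bias correction
}

/// One AdamW step: decoupled weight decay, then the Adam update.
fn adamw_step(
    param: &mut f64,
    grad: f64,
    s: &mut AdamWState,
    lr: f64,
    beta1: f64,
    beta2: f64,
    eps: f64,
    weight_decay: f64,
) {
    s.t += 1;
    s.m = beta1 * s.m + (1.0 - beta1) * grad;
    s.v = beta2 * s.v + (1.0 - beta2) * grad * grad;
    // Bias-corrected moment estimates.
    let m_hat = s.m / (1.0 - beta1.powi(s.t as i32));
    let v_hat = s.v / (1.0 - beta2.powi(s.t as i32));
    // Decay is applied to the parameter itself (the "W" in AdamW) ...
    *param -= lr * weight_decay * *param;
    // ... and the gradient-based update is the usual Adam step.
    *param -= lr * m_hat / (v_hat.sqrt() + eps);
}

fn main() {
    // Minimize 0.5 * w^2 (gradient is w); w should move toward 0.
    let mut w = 1.0;
    let mut s = AdamWState { m: 0.0, v: 0.0, t: 0 };
    for _ in 0..10 {
        let g = w;
        adamw_step(&mut w, g, &mut s, 0.1, 0.9, 0.999, 1e-8, 0.01);
    }
    assert!(w < 0.5 && w.abs() < 1.0);
}
```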
* Use index-select for the embeddings as it supports backprop. (#298) (Laurent Mazare, 2023-08-01; 1 file, +1/-1)
* Llama more training (#297) (Laurent Mazare, 2023-08-01; 6 files, +195/-19)
  - Rework the var-builder to handle initializations.
  - Add some helper functions for layer creation.
  - Improve the layer initializations.
  - Get initialized variables.
  - Precompute the rotary embeddings when training llamas.
* Add some batcher variants that handle errors. (#294) (Laurent Mazare, 2023-08-01; 1 file, +75/-0)
* Add the batcher. (#293) (Laurent Mazare, 2023-08-01; 2 files, +97/-0)
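The batcher added in #293 groups an iterator's items into fixed-size batches (with #294 adding variants for iterators over `Result`s). A dependency-free sketch of the core idea, not candle-nn's actual `Batcher` type, which works on tensors:

```rust
/// Minimal batcher sketch: wraps any iterator and yields fixed-size
/// batches; a final, possibly smaller, batch is yielded for leftovers.
struct Batcher<I: Iterator> {
    inner: I,
    batch_size: usize,
}

impl<I: Iterator> Iterator for Batcher<I> {
    type Item = Vec<I::Item>;
    fn next(&mut self) -> Option<Self::Item> {
        let mut batch = Vec::with_capacity(self.batch_size);
        for _ in 0..self.batch_size {
            match self.inner.next() {
                Some(item) => batch.push(item),
                None => break, // underlying iterator exhausted
            }
        }
        if batch.is_empty() { None } else { Some(batch) }
    }
}

fn main() {
    let batches: Vec<Vec<i32>> =
        Batcher { inner: 0..7, batch_size: 3 }.collect();
    assert_eq!(batches, vec![vec![0, 1, 2], vec![3, 4, 5], vec![6]]);
}
```

An error-handling variant would yield `Result<Vec<T>, E>` per batch, stopping at the first failed item instead of silently dropping it.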
* Add the cross-entropy loss. (#287) (Laurent Mazare, 2023-07-31; 1 file, +17/-0)
* Make the nll op closer to the PyTorch version + add a test. (#286) (Laurent Mazare, 2023-07-31; 1 file, +22/-2)
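The two commits above are related: cross-entropy is log-softmax followed by negative log-likelihood, averaged over the batch (PyTorch's default reduction). A self-contained sketch of that decomposition, not candle-nn's actual loss functions:

```rust
/// Log-softmax over a single logits vector, with the usual max
/// subtraction for numerical stability.
fn log_softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let log_sum_exp =
        logits.iter().map(|x| (x - max).exp()).sum::<f64>().ln() + max;
    logits.iter().map(|x| x - log_sum_exp).collect()
}

/// Cross-entropy = mean over the batch of -log_softmax(logits)[target],
/// i.e. log-softmax followed by the nll reduction.
fn cross_entropy(batch_logits: &[Vec<f64>], targets: &[usize]) -> f64 {
    let total: f64 = batch_logits
        .iter()
        .zip(targets)
        .map(|(logits, &t)| -log_softmax(logits)[t])
        .sum();
    total / batch_logits.len() as f64
}

fn main() {
    // Uniform logits over 4 classes: the loss is ln(4) for any target.
    let loss = cross_entropy(&[vec![0.0; 4], vec![0.0; 4]], &[1, 3]);
    assert!((loss - 4.0f64.ln()).abs() < 1e-12);
}
```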
* Improve the mnist training example. (#276) (Laurent Mazare, 2023-07-29; 2 files, +37/-4)
  - Add some initialization routine that can be used for nn.
  - Proper initialization in the mnist example.
* More mnist training. (#275) (Laurent Mazare, 2023-07-29; 1 file, +1/-0)
* Softmax numerical stability. (#267) (Laurent Mazare, 2023-07-28; 1 file, +24/-0)
  - Fix the flash-attn test.
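The standard numerical-stability trick behind this commit: subtracting the row maximum before exponentiating leaves softmax mathematically unchanged but prevents `exp` from overflowing on large logits. A minimal sketch (not candle's tensor implementation):

```rust
/// Numerically stable softmax: softmax(x) == softmax(x - max(x)),
/// but the shifted form never feeds exp() a large positive argument.
fn softmax(xs: &[f32]) -> Vec<f32> {
    let max = xs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = xs.iter().map(|x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

fn main() {
    // Naively, exp(1000.0f32) overflows to infinity and the result
    // becomes NaN; the shifted version stays finite.
    let probs = softmax(&[1000.0, 1000.0]);
    assert!((probs[0] - 0.5).abs() < 1e-6);
    assert!(probs.iter().all(|p| p.is_finite()));
}
```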
* Added comment about offsets. (Nicolas Patry, 2023-07-27; 1 file, +3/-0)
* Fixing slice errors + comments. (Nicolas Patry, 2023-07-27; 1 file, +22/-3)
* Removing inner dependency on safetensors. (Nicolas Patry, 2023-07-27; 1 file, +6/-4)
* TP sharding v2 (Nicolas Patry, 2023-07-27; 1 file, +53/-5)
* Move some shared functions to the nn module. (#221) (Laurent Mazare, 2023-07-22; 3 files, +20/-0)
* Rename the .r functions to .dims so as to be a bit more explicit. (#220) (Laurent Mazare, 2023-07-22; 2 files, +2/-2)
* [Proposal] Remove SafeTensor wrapper (allows finer control for users). (Nicolas Patry, 2023-07-19; 1 file, +6/-2)
* Vision dataset (#179) (Laurent Mazare, 2023-07-16; 4 files, +140/-0)
  - Add some readers for the mnist dataset.
  - Import the cifar and mnist dataset.
* Add backtrace information to errors where relevant. (#166) (Laurent Mazare, 2023-07-14; 1 file, +17/-7)
  - More backtrace information.
  - Add to the FAQ.
* Simplify the parameters used by sum and sum_keepdim. (#165) (Laurent Mazare, 2023-07-14; 1 file, +2/-2)
* Use the same default as PyTorch for sum. (#164) (Laurent Mazare, 2023-07-13; 1 file, +2/-2)
* Add the gradient for reduce-sum. (#162) (Laurent Mazare, 2023-07-13; 1 file, +1/-1)
  - And add the gradient for the broadcast ops.
  - Add some backprop tests.
  - Add some linear regression example.
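The gradient this commit adds is simple: for y = sum(x), dy/dx_i = 1 for every i, so the backward pass just broadcasts the upstream gradient back to the input's shape (which is also why sum and broadcast gradients land together). A sketch with a finite-difference sanity check, independent of candle's autograd:

```rust
/// Forward: reduce-sum over a slice.
fn sum(xs: &[f64]) -> f64 {
    xs.iter().sum()
}

/// Backward: broadcast the scalar upstream gradient to the input shape,
/// since d sum(x) / d x_i = 1 for every element.
fn sum_backward(upstream: f64, input_len: usize) -> Vec<f64> {
    vec![upstream; input_len]
}

fn main() {
    let x = [1.0, 2.0, 3.0];
    let grad = sum_backward(1.0, x.len());
    // Finite-difference check on the first coordinate.
    let eps = 1e-6;
    let mut x2 = x;
    x2[0] += eps;
    let fd = (sum(&x2) - sum(&x)) / eps;
    assert!((fd - grad[0]).abs() < 1e-4);
}
```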
* Add the SGD optimizer (#160) (Laurent Mazare, 2023-07-13; 2 files, +49/-0)
  - Add the nn::optim and some conversion traits.
  - Add the backward_step function for SGD.
  - Get the SGD optimizer to work and add a test.
  - Make the test slightly simpler.
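Plain SGD is the simplest optimizer: each parameter moves against its gradient, scaled by the learning rate. A minimal element-wise sketch (candle's SGD operates on tensors behind a backward_step that also runs backprop; this shows only the update rule):

```rust
/// One SGD step: p <- p - lr * grad, element-wise.
fn sgd_step(params: &mut [f64], grads: &[f64], lr: f64) {
    for (p, g) in params.iter_mut().zip(grads) {
        *p -= lr * g;
    }
}

fn main() {
    // Minimize f(w) = (w - 3)^2, whose gradient is 2 * (w - 3).
    let mut w = [0.0];
    for _ in 0..100 {
        let g = [2.0 * (w[0] - 3.0)];
        sgd_step(&mut w, &g, 0.1);
    }
    // The error contracts by 0.8 each step, so w converges to 3.
    assert!((w[0] - 3.0).abs() < 1e-6);
}
```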
* Add some documentation and test to the linear layer. (#151) (Laurent Mazare, 2023-07-12; 4 files, +51/-0)
  - Layer norm doc.
  - Minor tweaks.
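The linear layer being documented computes y = x W^T + b, with the weight stored as (out_features, in_features). A nested-Vec sketch of a single-sample forward pass; candle's Linear stores tensors and batches the matmul, so this is illustrative only:

```rust
/// A linear (fully-connected) layer: y = x W^T + b.
struct Linear {
    weight: Vec<Vec<f64>>, // shape (out_features, in_features)
    bias: Vec<f64>,        // shape (out_features,)
}

impl Linear {
    /// Forward pass for a single input vector of length in_features.
    fn forward(&self, x: &[f64]) -> Vec<f64> {
        self.weight
            .iter()
            .zip(&self.bias)
            .map(|(row, b)| {
                row.iter().zip(x).map(|(w, xi)| w * xi).sum::<f64>() + b
            })
            .collect()
    }
}

fn main() {
    let layer = Linear {
        weight: vec![vec![1.0, 2.0], vec![0.0, 1.0]],
        bias: vec![0.5, -0.5],
    };
    // [1*1 + 2*1 + 0.5, 0*1 + 1*1 - 0.5] = [3.5, 0.5]
    assert_eq!(layer.forward(&[1.0, 1.0]), vec![3.5, 0.5]);
}
```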
* Cleanup the main crate error and add a couple dedicated ones (#142) (Laurent Mazare, 2023-07-12; 1 file, +3/-2)
  - Cosmetic cleanups to the error enum.
  - More error cleanup.
  - Proper error handling rather than panicking.
  - Add some conv1d dedicated error.
* Allow for lazy loading of npz files, use it in llama to reduce memory usage in the cpu version. (#141) (Laurent Mazare, 2023-07-11; 1 file, +27/-2)
* Resurrect the llama npy support. (#140) (Laurent Mazare, 2023-07-11; 1 file, +55/-28)
* Sketch the tensor initialization module. (#134) (Laurent Mazare, 2023-07-11; 2 files, +116/-6)
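A common scheme such an initialization module provides is fan-in-scaled uniform init (Kaiming/He style): weights drawn from U(-bound, bound) with bound = sqrt(1 / fan_in). The sketch below uses a toy LCG just to stay dependency-free; the function name, the RNG, and the exact scheme candle implements are all assumptions here:

```rust
/// Fan-in-based uniform initialization sketch:
/// weights ~ U(-bound, bound) with bound = sqrt(1 / fan_in).
fn init_uniform(fan_in: usize, n: usize, seed: &mut u64) -> Vec<f64> {
    let bound = (1.0 / fan_in as f64).sqrt();
    (0..n)
        .map(|_| {
            // Toy linear congruential generator; illustrative only.
            *seed = seed
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            // Top 53 bits -> uniform f64 in [0, 1).
            let u = (*seed >> 11) as f64 / (1u64 << 53) as f64;
            (2.0 * u - 1.0) * bound
        })
        .collect()
}

fn main() {
    let mut seed = 42;
    let ws = init_uniform(128, 1000, &mut seed);
    let bound = (1.0f64 / 128.0).sqrt();
    // Every sample lands inside the fan-in-scaled interval.
    assert!(ws.iter().all(|w| w.abs() <= bound));
}
```

The fan-in scaling keeps the variance of a layer's pre-activations roughly independent of its width, which is why the mnist and llama training commits above lean on proper initialization.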
* VarBuilder path creation (#131) (Laurent Mazare, 2023-07-10; 1 file, +84/-19)
  - Use a struct for the safetensor + routing.
  - Group the path and the var-builder together.
  - Fix for the empty path case.
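The path machinery here names variables hierarchically, joining nested prefixes with '.' so tensor names line up with checkpoint keys; the "empty path case" is the root prefix, which must not produce a leading separator. A sketch of that idea with illustrative names, not candle's actual VarBuilder API:

```rust
/// Hierarchical variable path: a stack of prefixes joined with '.'.
#[derive(Clone)]
struct VarPath(Vec<String>);

impl VarPath {
    fn root() -> Self {
        VarPath(Vec::new())
    }
    /// Descend into a sub-scope, returning a new path.
    fn push(&self, s: &str) -> Self {
        let mut p = self.0.clone();
        p.push(s.to_string());
        VarPath(p)
    }
    /// Full name for a variable under this path.
    fn name(&self, var: &str) -> String {
        if self.0.is_empty() {
            var.to_string() // empty-path case: no leading '.'
        } else {
            format!("{}.{}", self.0.join("."), var)
        }
    }
}

fn main() {
    let root = VarPath::root();
    assert_eq!(root.name("weight"), "weight");
    let layer = root.push("model").push("layers").push("0");
    assert_eq!(layer.name("weight"), "model.layers.0.weight");
}
```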
* Move the var-builder in a central place. (#130) (Laurent Mazare, 2023-07-10; 2 files, +61/-0)
* Move the conv1d layer to candle_nn. (#117) (Laurent Mazare, 2023-07-10; 2 files, +51/-0)
* [nn] Move the Embedding and Activation parts. (#116) (Laurent Mazare, 2023-07-10; 3 files, +53/-0)
  - Share the Embedding and Activation parts.
  - Tweak some activations.
* Sketch the candle-nn crate. (#115) (Laurent Mazare, 2023-07-10; 3 files, +64/-0)
  - Tweak the cuda dependencies.
  - More cuda tweaks.