path: root/candle-nn/src
Commit message | Author | Age | Files | Lines
* Simplify the one-hot implementation, support arbitrary rank. (#1514) | Laurent Mazare | 2024-01-01 | 1 | -181/+38
    * Simplify the one-hot implementation, support arbitrary rank.
    * More cleanup.
* Add one-hot/cold encoding (#1489) | Ryan Tate | 2024-01-01 | 2 | -0/+294
    * add one-hot encoding
    * one_hot: improve error handling, use generic to_vecN::<D>
      Bails if the index value is equal to or greater than the depth value, which would result in an
      out-of-bounds error. A redundant check is added to ensure the index value does not exceed the
      length of the one-hot matrix size, which would also result in an out-of-bounds error. Bails if
      the index value is less than -1. If the index value is -1, the on_value is simply not set for
      that index. Only values that are less than -1 are considered errors.
    * one-hot: use two generics, one_hot::<I, O>, for input and output data types
      Separating the input and output data types allows the input tensor indices to be a different
      data type than the output encoded tensor data type. For example, one_hot::<i64, u8>(...) will
      take an input tensor of i64 values and encode the output tensor using u8 values. The generic
      I::DTYPE must match the data type of the input indices, otherwise the method will bail.
      Additionally, this method adds an `allow_f64` option to enable the input indices data type to
      be f64 values. f64 values are disabled by default.
      TODO: the indices data type and the generic I data type are currently not compile-time checked.
    * one_hot: remove input generic, use indices dtype matching
      This commit removes the to_f64() type cast and explicitly matches the DType of the input
      tensor. Currently, only U8, U32 and I64 are supported for input tensors. The match arms on the
      dtype are verbose. It would be nice to use a generic type with the WithDType trait bound to
      pass to the to_vecN method and then return an inner value. Open to suggestions for better
      approaches here to reduce the match arm verbosity.
    * one_hot: use flat_map iterator over dims instead of nested for loop
      This commit replaces the nested for loops with a flat_map iterator over the dimensions of the
      input tensor. It also adds a test for a rank 3 input tensor.
    * one_hot: use mandatory on/off-values, remove const msgs
      This commit also updates doc tests, comments and test cases.
    * Small cleanups.
    Co-authored-by: laurent <laurent.mazare@gmail.com>
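The squashed history above describes the final API shape; a minimal usage sketch, assuming the one_hot helper ends up taking an indices tensor, a depth, and mandatory on/off values, with negative indices left at the off value (inside a function returning candle::Result<()>):

    use candle::{Device, Tensor};
    use candle_nn::encoding::one_hot;

    let dev = Device::Cpu;
    // Rank-2 index tensor; a -1 entry keeps the off value for that position.
    let indices = Tensor::new(&[[0i64, 2], [1, -1]], &dev)?;
    // Depth 4; the output dtype follows the on/off values (u8 here).
    let encoded = one_hot(indices, 4, 1u8, 0u8)?; // shape (2, 2, 4)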
* Do not implement Module for BatchNorm. (#1513) | Laurent Mazare | 2024-01-01 | 1 | -13/+13
* Small tweaks to batch-norm. (#1505) | Laurent Mazare | 2023-12-30 | 1 | -19/+16
* [Breaking] Add training to batchnorm with exponential moving average (#1504) | nkoppel | 2023-12-30 | 1 | -50/+158
    * Add training to batchnorm with exponential moving average
    * Add more checks to batch norm
    * Resolve some review comments
    * Add with_momentum variants of `new` methods
    * Add check for range of momentum variable; update batch norm test
    * Run cargo fmt
    * Add back num_features parameter
    * Format; tiny simplification
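Taken together with #1513 above, this moves BatchNorm onto a train-aware forward path where the running statistics are updated with an exponential moving average during training. A rough sketch, assuming a momentum field on BatchNormConfig and a ModuleT-style forward; exact field and method names may differ:

    use candle_nn::{batch_norm, BatchNormConfig, ModuleT, VarBuilder};

    // vb: VarBuilder and xs: (N, C, L) input tensor assumed to be set up elsewhere.
    let cfg = BatchNormConfig { eps: 1e-5, remove_mean: true, affine: true, momentum: 0.1 };
    let bn = batch_norm(16, cfg, vb.pp("bn"))?;
    let ys_train = bn.forward_t(&xs, true)?;  // batch statistics; running stats updated via EMA
    let ys_eval = bn.forward_t(&xs, false)?;  // uses the accumulated running statistics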
* Merge pull request #1318 from huggingface/metal4 | Nicolas Patry | 2023-12-20 | 1 | -0/+41
    Starting to fix some tests.
  * Clippy pass. | Nicolas Patry | 2023-12-18 | 1 | -3/+3
  * Addressing a lot of comments. | Nicolas Patry | 2023-12-15 | 1 | -1/+2
  * Remove `unwrap()`. | Nicolas Patry | 2023-12-15 | 1 | -2/+2
  * Renamed all kernel names. | Nicolas Patry | 2023-12-15 | 1 | -3/+3
  * Fixing softmax. | Nicolas Patry | 2023-12-15 | 1 | -1/+1
  * Working with merging encoders and using fences. | Nicolas Patry | 2023-12-14 | 1 | -2/+0
  * Lots of updates including some stack of command buffers. | nicolas | 2023-12-12 | 1 | -1/+3
  * Starting to fix some tests. | Nicolas Patry | 2023-11-30 | 1 | -0/+40
      Few fixes.
      Going back on remote metal-rs.
      Reusing a single buffer (for now) to speed things up.
      Adding some half kernels.
      All tests are panicking instead of random failure.
      Putting back f16 index select.
      Add erf.
      Working version for llama2-c.
      Fixes + cache compute_pipeline_state.
      BF16 metal fix.
      Remove some prints.
      new_owned -> new()..to_owned().
      Better batched matmul.
      Metal operational.
      Reuse buffers on our own reference counts.
      Tmp gemm.
      Revert "Tmp gemm." This reverts commit c65f68e98814b65daa596696bda076a73303dd82.
      Interleave committing.
      Speeding up copies using blit.
      Fmt.
      Fmt.
      Remove the assert!
      Fmt all.
      Fixes after big rebase.
      Add softmax for half and bfloat + tests.
      Fixing Llama example + accumulate softmax in float.
* Fix a couple typos (#1451) | Laurent Mazare | 2023-12-17 | 1 | -1/+1
    * Mixtral quantized instruct.
    * Fix a couple typos.
* Expose AdamW parameters (#1449) | Dave Lage | 2023-12-16 | 1 | -0/+8
    * Expose AdamW parameters
    * Use reference
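AdamW already takes its hyper-parameters through ParamsAdamW; this change exposes them back from a built optimizer. A hedged sketch of how the optimizer is typically constructed and adjusted (the accessor names added by the PR are not shown verbatim in the log):

    use candle_nn::{AdamW, Optimizer, ParamsAdamW};

    // varmap: candle_nn::VarMap holding the trainable variables, assumed set up elsewhere.
    let params = ParamsAdamW { lr: 1e-3, weight_decay: 0.01, ..Default::default() };
    let mut opt = AdamW::new(varmap.all_vars(), params)?;
    // With the parameters exposed, a schedule can tweak them mid-training,
    // e.g. via the Optimizer trait's learning-rate setter.
    opt.set_learning_rate(1e-4);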
* Speedup ShardedSafeTensors to load Tensors with default hints (#1384) | YiiSh | 2023-12-14 | 1 | -1/+7
    * Speedup ShardedSafeTensors to load Tensors with default hints
    * Tweaks.
    Co-authored-by: Laurent <laurent.mazare@gmail.com>
* Another prelu bugfix. (#1407) | Laurent Mazare | 2023-12-06 | 1 | -1/+1
* Use the proper broadcasting for prelu. (#1406) | Laurent Mazare | 2023-12-05 | 1 | -5/+16
* Add the prelu layer. (#1402) | Laurent Mazare | 2023-12-03 | 3 | -4/+51
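The broadcasting fix in #1406 matters because the per-channel PReLU weight has shape (C) while the input is typically (N, C, ...), so the weight has to be expanded along the channel axis rather than the last one. A rough sketch of the intended computation, not the layer's actual code:

    // prelu(x) = relu(x) + a * min(x, 0), with the slope `a` broadcast per channel.
    // weight: Tensor of shape (c,), xs: Tensor of shape (n, c, h, w), both assumed in scope.
    let a = weight.reshape((1, c, 1, 1))?;        // line the slope up with the channel dim
    let neg = xs.minimum(&xs.zeros_like()?)?;     // keep only the negative part of xs
    let ys = (xs.relu()? + neg.broadcast_mul(&a)?)?;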
* Add support to UL2 model family (#1300) | Juarez Bochi | 2023-11-09 | 1 | -1/+0
    * Add support to UL2 model family
    * Update docs with UL2
    * Create ActivationWithOptionalGating to avoid polluting activations
    * Also refactor quantized t5
    * Remove useless conversion
    * Revert Activation::NewGelu name change
    * Remove useless return
    * Apply rustfmt and clippy recommendations
    * Reuse t5::ActivationWithOptionalGating in quantized version
    * (cosmetic change) use a match rather than ifs + avoid early returns.
    Co-authored-by: Laurent <laurent.mazare@gmail.com>
* Add weight and bias functions to LayerNorm (#1306) | jwnz | 2023-11-09 | 1 | -0/+8
* Transposed conv1d in candle-nn. (#1252) | Laurent Mazare | 2023-11-03 | 1 | -0/+94
* Add the swiglu activation from the chatglm PR. (#1246) | Laurent Mazare | 2023-11-02 | 2 | -0/+7
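SwiGLU-style gating splits the last dimension into two halves and gates one with silu of the other; a hedged sketch of the equivalent manual computation (the helper added here presumably lives in candle_nn::ops):

    use candle::D;
    use candle_nn::ops::silu;

    // xs: Tensor whose last dimension has an even size, assumed in scope.
    let chunks = xs.chunk(2, D::Minus1)?;
    let ys = silu(&chunks[0])?.mul(&chunks[1])?;  // silu(a) * b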
* Add hard-sigmoid and hard-swish activations (#1244) | jamjamjon | 2023-11-02 | 2 | -0/+9
    * Add hard-sigmoid and hard-swish activations
    * Update ops.rs
    * Use / rather than div.
    Co-authored-by: Laurent <laurent.mazare@gmail.com>
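These are the usual piecewise-linear approximations from MobileNetV3; a sketch of one common formulation (the exact expression in ops.rs may differ slightly):

    // hard_sigmoid(x) = clamp(x / 6 + 0.5, 0, 1)
    // hard_swish(x)   = x * hard_sigmoid(x)
    // xs: input tensor assumed in scope.
    let hs = ((&xs / 6.)? + 0.5)?.clamp(0f64, 1f64)?;
    let hsw = (&xs * &hs)?;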
* Add support for the marian base model. (#1221) | Laurent Mazare | 2023-10-30 | 1 | -0/+2
* Allow for different behavior between training and eval (#1213) | Laurent Mazare | 2023-10-29 | 3 | -2/+43
    * Forward with training.
    * Do not use dropout on vgg evaluation.
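This is the train/eval split that layers such as Dropout need; a minimal sketch assuming the ModuleT-style forward that threads an explicit `train` flag through the pass:

    use candle_nn::{Dropout, ModuleT};

    // xs: activation tensor assumed in scope.
    let dropout = Dropout::new(0.5);
    let ys_train = dropout.forward_t(&xs, true)?;   // randomly zeroes activations
    let ys_eval = dropout.forward_t(&xs, false)?;   // plain pass-through at eval time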
* Add the relu2 and relu6 activations. (#1201) | Laurent Mazare | 2023-10-27 | 1 | -0/+4
* Add fuse-conv-bn method for Conv2d (#1196) | jamjamjon | 2023-10-27 | 2 | -0/+25
    * Add fuse-conv-bn method for Conv2d
    * no unwrap
    * run rustfmt and clippy
* Expose the fields from batch-norm. (#1176) | Laurent Mazare | 2023-10-25 | 1 | -2/+12
* Add Binary Cross Entropy With Logit Loss to nn crate (#1157) | Ogundepo Odunayo | 2023-10-23 | 1 | -0/+22
    * add bce with logit loss
    * add bce with logit loss
    * remove imports
    * fix tiny bug
    * add test documentation and refactor function
    * fix test cases and formatting
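The loss applies the sigmoid internally, so it takes raw scores rather than probabilities; a hedged usage sketch (function name as added to candle_nn::loss in this PR, exact signature assumed):

    use candle_nn::loss::binary_cross_entropy_with_logit;

    // inp: raw pre-sigmoid scores, target: 0/1 labels of the same shape, assumed in scope.
    let loss = binary_cross_entropy_with_logit(&inp, &target)?;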
* Make func cloneable. (#1137) | Laurent Mazare | 2023-10-20 | 2 | -6/+8
* Add the sequential layer. (#1136) | Laurent Mazare | 2023-10-20 | 2 | -0/+64
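The sequential container simply chains modules in order; a minimal sketch assuming the usual builder with add/add_fn, with vb a VarBuilder and xs an input tensor set up elsewhere:

    use candle_nn::{linear, seq, Module};

    let model = seq()
        .add(linear(784, 128, vb.pp("l1"))?)
        .add_fn(|xs| xs.relu())
        .add(linear(128, 10, vb.pp("l2"))?);
    let logits = model.forward(&xs)?;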
* Experiment with resnet (#1128) | Laurent Mazare | 2023-10-19 | 1 | -0/+9
    * Add some preliminary support for resnet.
    * Add an actual resnet example.
* feat: add pth varbuilder (#1108) | OlivierDehaene | 2023-10-16 | 1 | -0/+41
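This lets a VarBuilder read weights directly from a PyTorch checkpoint; a hedged sketch (constructor name per this commit, path and tensor names purely illustrative):

    use candle::{DType, Device};
    use candle_nn::VarBuilder;

    let vb = VarBuilder::from_pth("model.pth", DType::F32, &Device::Cpu)?;
    let dense = candle_nn::linear(768, 768, vb.pp("encoder.layer.0.dense"))?;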
* Only optimize float tensors. (#1069) | Laurent Mazare | 2023-10-10 | 1 | -0/+5
* More general seq forward functions for RNNs. (#1050) | Laurent Mazare | 2023-10-07 | 1 | -27/+25
* Use AsRef<str> for set_one. (#1033) | Laurent Mazare | 2023-10-05 | 1 | -1/+1
* Bump the version to 0.3.0. (#1014) | Laurent Mazare | 2023-10-01 | 1 | -20/+0
    * Bump the version to 0.3.0.
    * Changelog update.
* Use a silu activation in mistral. (#991) | Laurent Mazare | 2023-09-29 | 1 | -0/+4
* Use the gelu-erf activation. (#969) | Laurent Mazare | 2023-09-26 | 1 | -3/+1
* Configurable layer idx for the lstm layer. (#962) | Laurent Mazare | 2023-09-25 | 1 | -4/+12
* Deprecate the VarBuilder::from_safetensors function. (#951) | Laurent Mazare | 2023-09-24 | 1 | -2/+6
* Self-contained safetensors for the multiprocess llama example. (#950) | Laurent Mazare | 2023-09-24 | 1 | -31/+17
* Add the buffered safetensor wrapper. (#948) | Laurent Mazare | 2023-09-23 | 1 | -0/+32
* Self-contained safetensor wrappers (#946) | Laurent Mazare | 2023-09-23 | 1 | -1/+42
    * Self-contained safetensor wrappers.
    * Use the new safetensor container in varbuilders.
* Use yoke to provide a self-referential container for mmaped safetenso… (#939) | Laurent Mazare | 2023-09-23 | 1 | -11/+5
    * Use yoke to provide a self-referential container for mmaped safetensor files.
    * Add the new self-owned type for safetensor files without removing the previous version.
    * Add routing.
    * Add an initializer for the case of multiple files.
* VarMap setter functions (#938) | Laurent Mazare | 2023-09-23 | 1 | -0/+38
    * Add some setter helper functions for varmap.
    * Add more comments.
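The setters overwrite the value of an already-registered variable, which is handy for loading or patching weights after the model is built; a hedged sketch (set_one per this commit and #1033 above, exact signature and mutability assumed):

    use candle::{DType, Device, Tensor};
    use candle_nn::VarMap;

    // varmap: a mutable VarMap whose variables were created through
    // VarBuilder::from_varmap, assumed in scope; the name must already exist.
    let new_w = Tensor::zeros((128, 784), DType::F32, &Device::Cpu)?;
    varmap.set_one("l1.weight", &new_w)?;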
* Add clone to various nn layers. (#910) | Laurent Mazare | 2023-09-20 | 7 | -11/+11
* Fix the leaky relu. (#898) | Laurent Mazare | 2023-09-19 | 1 | -1/+2