path: root/candle-nn
Commit message (#PR) · Author · Date · Files changed · Lines (-/+)
...
* feat: add silu activation function (#1706) · OlivierDehaene · 2024-02-14 · 2 files · -4/+3
  - feat: add silu activation function
  - use silu/arg in grad
  - update candle-nn
  - use node
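A minimal sketch of calling the new activation, assuming the `candle_nn::ops::silu` helper this PR introduces:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // silu(x) = x * sigmoid(x), also known as the swish activation.
    let xs = Tensor::new(&[-1.0f32, 0.0, 1.0], &Device::Cpu)?;
    let ys = candle_nn::ops::silu(&xs)?;
    println!("{ys}");
    Ok(())
}
```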
* Detach the tensors on batch-norm eval. (#1702) · Laurent Mazare · 2024-02-13 · 1 file · -2/+12
  - Detach the tensors on batch-norm eval.
  - Fix pyo3 bindings.
  - Black tweak.
  - Formatting.
  - Also update the pyo3-onnx formatting.
  - Apply black.
* Fix clippy lints for 1.76. (#1682) · Laurent Mazare · 2024-02-08 · 1 file · -1/+1
* Enhance pickle to retrieve state_dict with a given key (#1671) · Dilshod Tadjibaev · 2024-02-06 · 1 file · -1/+1
* Add `VarBuilder::from_backend` (#1670) · Daniël de Kok · 2024-02-06 · 1 file · -8/+17
  `candle-nn` already exposes a trait to define custom backends. However, it was not possible to actually construct a `VarBuilder` with a custom backend, because the constructor was not exposed. This change makes the constructor public and renames it from `new` to `from_backend`, so that it is not mistaken for the primary constructor (which could confuse users).
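A sketch of what the exposed constructor enables. The `SimpleBackend` trait surface and the exact `from_backend` signature shown here are assumptions based on the description above, and `ZeroBackend` is a hypothetical backend:

```rust
use candle_core::{DType, Device, Result, Shape, Tensor};
use candle_nn::var_builder::SimpleBackend;
use candle_nn::{Init, VarBuilder};

// Hypothetical custom backend that materializes every requested tensor as zeros.
struct ZeroBackend;

impl SimpleBackend for ZeroBackend {
    fn get(&self, s: Shape, _name: &str, _h: Init, dtype: DType, dev: &Device) -> Result<Tensor> {
        Tensor::zeros(s, dtype, dev)
    }
    fn contains_tensor(&self, _name: &str) -> bool {
        true
    }
}

fn main() -> Result<()> {
    // The newly public constructor: a boxed backend plus a dtype and a device.
    let vb = VarBuilder::from_backend(Box::new(ZeroBackend), DType::F32, Device::Cpu);
    let w = vb.get((2, 3), "weight")?;
    println!("{w}");
    Ok(())
}
```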
* Quantized GGUF style (#1523) · Nicolas Patry · 2024-01-17 · 1 file · -1/+4
  - Metal quantized modifications proposal: add a device param wherever needed; create a new QMetal storage type that implements QuantizedType; update everything that depends on it and fix the Python bindings, the examples, and fmt/clippy/stubs.
  - Move everything around; add the remaining implementations plus dequantized kernels; fix matmul.
  - Bugged kernels found along the way: Q2K Metal (also present in GGML, fixed here), Q4K CPU (pre-existing, caught by a new test), Q5K CPU (pre-existing), Q8_1 on both backends (never really implemented, it seems), Q8K Metal (never implemented in Metal).
  - Cleanup; fix the rebase.
  - Removing the fences speeds everything up and *is* correct this time; clean up the fence.
  - Rebase after the phi2 merge; fix replit defaulting to CPU.
  - Make the CI happy; more happy tests.
  - Co-authored-by: Nicolas Patry <nicolas@Nicolass-MacBook-Pro.local>
* Update the Phi model to use the updated architecture. (#1580) · Laurent Mazare · 2024-01-13 · 1 file · -0/+1
  - Update the Phi model to use the updated architecture.
  - Add more of the phi model.
  - Repeat KV + caching.
  - Apply the rotary embeddings.
  - Add support for the new phi model in the phi example.
  - Fix a couple glitches.
  - Fix a couple more glitches.
* Simplifying our internal cargo dependencies. (#1529) · Nicolas Patry · 2024-01-07 · 1 file · -2/+2
* Simplify the one-hot implementation, support arbitrary rank. (#1514) · Laurent Mazare · 2024-01-01 · 1 file · -181/+38
  - Simplify the one-hot implementation, support arbitrary rank.
  - More cleanup.
* Add one-hot/cold encoding (#1489) · Ryan Tate · 2024-01-01 · 3 files · -0/+414
  - add one-hot encoding
  - one_hot: improve error handling, use generic to_vecN::<D>. Bails if an index value is equal to or greater than the depth value, which would result in an out-of-bounds error. A redundant check ensures the index value does not exceed the length of the one-hot matrix, which would also be out of bounds. Bails if an index value is less than -1; an index of -1 simply skips setting the on_value for that position, and only values below -1 are treated as errors.
  - one-hot: use two generics, one_hot::<I, O>, for the input and output data types. Separating the two allows the input tensor indices to have a different data type than the output encoded tensor. For example, one_hot::<i64, u8>(...) takes an input tensor of i64 values and encodes the output tensor using u8 values. The generic I::DTYPE must match the data type of the input indices, otherwise the method bails. Additionally, an `allow_f64` option enables f64 input indices (disabled by default). TODO: the indices data type and the generic I data type are not yet compile-time checked.
  - one_hot: remove the input generic and match on the indices dtype instead. This removes the to_f64() cast and explicitly matches the DType of the input tensor; only U8, U32 and I64 input tensors are currently supported. The match arms on the dtype are verbose; it would be nice to use a generic with a WithDType trait bound to pass to the to_vecN method. Open to suggestions for reducing the match-arm verbosity.
  - one_hot: use a flat_map iterator over the dims instead of nested for loops; also add a test for a rank-3 input tensor.
  - one_hot: make the on/off values mandatory, remove the const messages; update doc tests, comments and test cases accordingly.
  - Small cleanups.
  - Co-authored-by: laurent <laurent.mazare@gmail.com>
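A usage sketch of the final API, assuming `candle_nn::encoding::one_hot` with the mandatory on/off values described above:

```rust
use candle_core::{Device, Result, Tensor};
use candle_nn::encoding::one_hot;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // An index of -1 leaves its row entirely at the off value.
    let indices = Tensor::new(&[0i64, 2, -1, 1], &device)?;
    // depth = 3; the output dtype (here u8) follows the on/off values.
    let encoded = one_hot(indices, 3, 1u8, 0u8)?;
    println!("{encoded}");
    Ok(())
}
```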
* Do not implement Module for BatchNorm. (#1513) · Laurent Mazare · 2024-01-01 · 2 files · -15/+15
* Small tweaks to batch-norm. (#1505) · Laurent Mazare · 2023-12-30 · 1 file · -19/+16
* [Breaking] Add training to batchnorm with exponential moving average (#1504) · nkoppel · 2023-12-30 · 2 files · -50/+169
  - Add training to batchnorm with exponential moving average
  - Add more checks to batch norm
  - Resolve some review comments
  - Add with_momentum variants of the `new` methods
  - Add a range check for the momentum variable; update the batch norm test
  - Run cargo fmt
  - Add back the num_features parameter
  - Format; tiny simplification
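The moving average that training mode applies to the running statistics is the standard exponential one; a sketch follows (reading this as the implementation's exact rule is an assumption):

```rust
// Sketch of the standard EMA update for the running batch-norm statistics.
fn ema_update(running: f64, batch_stat: f64, momentum: f64) -> f64 {
    (1.0 - momentum) * running + momentum * batch_stat
}

fn main() {
    // With momentum 0.1, a batch mean of 0.5 nudges a zero running mean to 0.05.
    let updated = ema_update(0.0, 0.5, 0.1);
    assert!((updated - 0.05).abs() < 1e-12);
}
```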
* Bump the crate version to 0.3.3. (#1490) · Laurent Mazare · 2023-12-28 · 1 file · -1/+1
* Merge pull request #1318 from huggingface/metal4 · Nicolas Patry · 2023-12-20 · 2 files · -0/+44
  Starting to fix some tests.
  (The indented entries below are the commits merged from the metal4 branch.)
  * Clippy pass. · Nicolas Patry · 2023-12-18 · 1 file · -3/+3
  * Addressing a lot of comments. · Nicolas Patry · 2023-12-15 · 1 file · -1/+2
  * Remove `unwrap()`. · Nicolas Patry · 2023-12-15 · 1 file · -2/+2
  * Renamed all kernel names. · Nicolas Patry · 2023-12-15 · 1 file · -3/+3
  * Fixing softmax. · Nicolas Patry · 2023-12-15 · 1 file · -1/+1
  * Working with merging encoders and using fences. · Nicolas Patry · 2023-12-14 · 1 file · -2/+0
  * Lots of updates including some stack of command buffers. · nicolas · 2023-12-12 · 2 files · -2/+5
  * Starting to fix some tests. · Nicolas Patry · 2023-11-30 · 2 files · -0/+42
    - Few fixes; going back on remote metal-rs.
    - Reuse a single buffer (for now) to speed things up.
    - Add some half kernels; put back the f16 index select; add erf.
    - Make all tests panic instead of failing randomly.
    - Working version for llama2-c.
    - Fixes + cache the compute_pipeline_state.
    - BF16 metal fix; remove some prints.
    - new_owned -> new()..to_owned().
    - Better batched matmul; Metal operational.
    - Reuse buffers based on our own reference counts.
    - Tmp gemm, then Revert "Tmp gemm." (reverts commit c65f68e98814b65daa596696bda076a73303dd82).
    - Interleave committing; speed up copies using blit.
    - Fmt; remove the assert!; fmt all.
    - Fixes after the big rebase.
    - Add softmax for half and bfloat + tests.
    - Fix the Llama example + accumulate softmax in float.
* Bump the crate version to 0.3.2. (#1452) · Laurent Mazare · 2023-12-17 · 1 file · -1/+1
* Fix a couple typos (#1451) · Laurent Mazare · 2023-12-17 · 1 file · -1/+1
  - Mixtral quantized instruct.
  - Fix a couple typos.
* Expose AdamW parameters (#1449) · Dave Lage · 2023-12-16 · 1 file · -0/+8
  - Expose AdamW parameters
  - Use reference
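With the parameters exposed, the optimizer can be configured field by field; a sketch assuming the `ParamsAdamW` struct and the `Optimizer` trait as candle-nn exposes them:

```rust
use candle_core::{Result, Var};
use candle_nn::{AdamW, Optimizer, ParamsAdamW};

fn build_optimizer(vars: Vec<Var>) -> Result<AdamW> {
    // Each hyperparameter is now settable individually.
    let params = ParamsAdamW {
        lr: 1e-3,
        beta1: 0.9,
        beta2: 0.999,
        eps: 1e-8,
        weight_decay: 0.01,
    };
    AdamW::new(vars, params)
}
```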
* Speedup ShardedSafeTensors to load Tensors with default hints (#1384) · YiiSh · 2023-12-14 · 1 file · -1/+7
  - Speedup ShardedSafeTensors to load Tensors with default hints
  - Tweaks.
  - Co-authored-by: Laurent <laurent.mazare@gmail.com>
* Another prelu bugfix. (#1407) · Laurent Mazare · 2023-12-06 · 1 file · -1/+1
* Use the proper broadcasting for prelu. (#1406) · Laurent Mazare · 2023-12-05 · 1 file · -5/+16
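The fix makes a per-channel weight line up with the channel dimension instead of multiplying elementwise; a sketch of the math (the NCHW reshape below is an illustration, not the in-crate code):

```rust
use candle_core::{Result, Tensor};

// prelu(x) = relu(x) - w * relu(-x), with the [C] weight broadcast over the
// channel dimension of an NCHW input.
fn prelu_sketch(xs: &Tensor, w: &Tensor) -> Result<Tensor> {
    let c = w.dims1()?;
    let w = w.reshape((1, c, 1, 1))?; // [C] -> [1, C, 1, 1]
    xs.relu()? - w.broadcast_mul(&xs.neg()?.relu()?)?
}
```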
* Add the prelu layer. (#1402) · Laurent Mazare · 2023-12-03 · 3 files · -4/+51
* Implement the module trait directly for QMatMul. (#1372) · Laurent Mazare · 2023-11-25 · 1 file · -1/+1
* Update for 0.3.1. (#1324) · Laurent Mazare · 2023-11-11 · 1 file · -1/+1
* Add support to UL2 model family (#1300) · Juarez Bochi · 2023-11-09 · 1 file · -1/+0
  - Add support to UL2 model family
  - Update docs with UL2
  - Create ActivationWithOptionalGating to avoid polluting activations
  - Also refactor quantized t5
  - Remove useless conversion
  - Revert Activation::NewGelu name change
  - Remove useless return
  - Apply rustfmt and clippy recommendations
  - Reuse t5::ActivationWithOptionalGating in the quantized version
  - (cosmetic change) use a match rather than ifs + avoid early returns
  - Co-authored-by: Laurent <laurent.mazare@gmail.com>
* Add weight and bias functions to LayerNorm (#1306) · jwnz · 2023-11-09 · 1 file · -0/+8
* Transposed conv1d in candle-nn. (#1252) · Laurent Mazare · 2023-11-03 · 1 file · -0/+94
* Add the swiglu activation from the chatglm PR. (#1246) · Laurent Mazare · 2023-11-02 · 2 files · -0/+7
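SwiGLU gates one half of the input with the silu of the other half; a sketch of the formula (the split-on-last-dim layout is an assumption):

```rust
use candle_core::{Result, Tensor, D};

// swiglu(x) = silu(x1) * x2, where [x1, x2] is x split in half on the last dim.
fn swiglu_sketch(xs: &Tensor) -> Result<Tensor> {
    let chunks = xs.chunk(2, D::Minus1)?;
    &candle_nn::ops::silu(&chunks[0])? * &chunks[1]
}
```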
* Add hard-sigmoid and hard-swish activations (#1244) · jamjamjon · 2023-11-02 · 2 files · -0/+9
  - Add hard-sigmoid and hard-swish activations
  - Update ops.rs
  - Use / rather than div.
  - Co-authored-by: Laurent <laurent.mazare@gmail.com>
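Both are cheap piecewise-linear approximations; a sketch of the usual formulations (treating them as the in-crate definitions is an assumption):

```rust
use candle_core::{Result, Tensor};

// hard_sigmoid(x) = clamp((x + 3) / 6, 0, 1)
fn hard_sigmoid_sketch(xs: &Tensor) -> Result<Tensor> {
    ((xs + 3.0)? / 6.0)?.clamp(0f32, 1f32)
}

// hard_swish(x) = x * hard_sigmoid(x)
fn hard_swish_sketch(xs: &Tensor) -> Result<Tensor> {
    xs * hard_sigmoid_sketch(xs)?
}
```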
* Add support for the marian base model. (#1221) · Laurent Mazare · 2023-10-30 · 1 file · -0/+2
* Allow for different behavior between training and eval (#1213) · Laurent Mazare · 2023-10-29 · 3 files · -2/+43
  - Forward with training.
  - Do not use dropout on vgg evaluation.
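Concretely, this introduces a forward pass that carries a training flag; a paraphrased sketch of the trait shape (the exact definition in candle-nn may differ):

```rust
use candle_core::{Result, Tensor};

// A module trait whose forward pass knows whether the model is training,
// so layers like dropout can become no-ops at eval time.
pub trait ModuleT {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor>;
}
```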
* Add the relu2 and relu6 activations. (#1201) · Laurent Mazare · 2023-10-27 · 1 file · -0/+4
* Add fuse-conv-bn method for Conv2d (#1196) · jamjamjon · 2023-10-27 · 2 files · -0/+25
  - Add fuse-conv-bn method for Conv2d
  - no unwrap
  - run rustfmt and clippy
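Fusing folds the batch-norm statistics into the convolution's weight and bias so inference runs a single op; the per-channel algebra, as a sketch (the standard fold, not necessarily the in-crate code):

```rust
// Per output channel: scale = gamma / sqrt(var + eps), then
//   w' = w * scale
//   b' = (b - mean) * scale + beta
// A scalar sketch of the standard conv/batch-norm fold.
fn fuse_scalar(w: f64, b: f64, gamma: f64, beta: f64, mean: f64, var: f64, eps: f64) -> (f64, f64) {
    let scale = gamma / (var + eps).sqrt();
    (w * scale, (b - mean) * scale + beta)
}
```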
* Expose the fields from batch-norm. (#1176) · Laurent Mazare · 2023-10-25 · 1 file · -2/+12
* Add Binary Cross Entropy With Logit Loss to nn crate (#1157) · Ogundepo Odunayo · 2023-10-23 · 2 files · -0/+69
  - add bce with logit loss
  - remove imports
  - fix tiny bug
  - add test documentation and refactor function
  - fix test cases and formatting
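Taking logits rather than probabilities lets the loss fold the sigmoid into a numerically stable form; a usage sketch, assuming `candle_nn::loss::binary_cross_entropy_with_logit` as the exported name:

```rust
use candle_core::{Device, Result, Tensor};
use candle_nn::loss::binary_cross_entropy_with_logit;

fn main() -> Result<()> {
    let device = Device::Cpu;
    let logits = Tensor::new(&[0.5f32, -1.0, 2.0], &device)?;
    let targets = Tensor::new(&[1.0f32, 0.0, 1.0], &device)?;
    // Mathematically -mean(t*log(sigmoid(x)) + (1-t)*log(1-sigmoid(x))),
    // computed without materializing the raw sigmoid.
    let loss = binary_cross_entropy_with_logit(&logits, &targets)?;
    println!("{loss}");
    Ok(())
}
```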
* Make func cloneable. (#1137) · Laurent Mazare · 2023-10-20 · 2 files · -6/+8
* Add the sequential layer. (#1136) · Laurent Mazare · 2023-10-20 · 2 files · -0/+64
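A usage sketch of the combinator, assuming the `seq()` constructor with the chaining `add`/`add_fn` API as candle-nn exposes them:

```rust
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{seq, Module, VarBuilder, VarMap};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
    // Layers run in order; add_fn wraps a closure as a layer.
    let model = seq()
        .add(candle_nn::linear(4, 8, vb.pp("l1"))?)
        .add_fn(|xs| xs.relu())
        .add(candle_nn::linear(8, 2, vb.pp("l2"))?);
    let ys = model.forward(&Tensor::zeros((1, 4), DType::F32, &device)?)?;
    println!("{ys}");
    Ok(())
}
```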
* Experiment with resnet (#1128) · Laurent Mazare · 2023-10-19 · 1 file · -0/+9
  - Add some preliminary support for resnet.
  - Add an actual resnet example.
* feat: add pth varbuilder (#1108) · OlivierDehaene · 2023-10-16 · 1 file · -0/+41
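A sketch of loading PyTorch weights through the new builder; the `VarBuilder::from_pth` name and signature are assumptions based on this feature, and the path is illustrative:

```rust
use candle_core::{DType, Device, Result};
use candle_nn::VarBuilder;

fn main() -> Result<()> {
    // Load a PyTorch checkpoint as a VarBuilder backend.
    let vb = VarBuilder::from_pth("model.pth", DType::F32, &Device::Cpu)?;
    let w = vb.get((10, 10), "layer.weight")?;
    println!("{w}");
    Ok(())
}
```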
* Add a matvec cpu benchmark. (#1076) · Laurent Mazare · 2023-10-12 · 1 file · -3/+22
* Convmixer (#1073) · Laurent Mazare · 2023-10-11 · 1 file · -2/+2
  - Only optimize float tensors.
  - Use full tensors for zeros and ones.
  - Add a benchmark for the matmul slowness.
  - Add the convmixer model.
  - Proper adaptive pooling.
* Only optimize float tensors. (#1069) · Laurent Mazare · 2023-10-10 · 1 file · -0/+5