path: root/candle-nn
Commit message | Author | Age | Files | Lines
* Sync upstream MLX sdpa vector kernels with mask (#2718) [HEAD, main] (Eric Buehler, 2025-01-16, 1 file, -21/+74)
    - Sync upstream mlx sdpa vector kernels with mask
    - Dispatch to the 2pass kernel
    - Format
* ModernBERT model (#2713) (Jani Monoses, 2025-01-13, 2 files, -1/+12)
    - layer_norm_no_bias
    - Modernbert model.
    - Format + cleanup error.
    Co-authored-by: laurent <laurent.mazare@gmail.com>
* Clippy fixes for 1.84. (#2710) (Laurent Mazare, 2025-01-10, 1 file, -2/+2)
* Lint fixes introduced with Rust 1.83 (#2646) (Anubhab Bandyopadhyay, 2024-11-28, 2 files, -10/+10)
    - Fixes for lint errors introduced with Rust 1.83
    - rustfmt
    - Fix more lints.
    Co-authored-by: Laurent <laurent.mazare@gmail.com>
* Provide a method to allow PTH files with state maps to be loaded. (#2639) (zachcp, 2024-11-26, 1 file, -1/+11)
    - Provide a method to allow PTH files with state maps to be loaded.
    - add a line to the doc
    - String -> &str
* Update docs (#2553) (zachcp, 2024-11-11, 9 files, -0/+34)
    - add module docs for candle-core
    - doc each of the candle-nn modules and add the links to the doc page
* Add some fast Metal MLX SDPA kernels (#2584) (Eric Buehler, 2024-11-05, 2 files, -0/+396)
    - Add some fast Metal MLX SDPA kernels (#32)
    - Sketch the sdpa kernel
    - Add full sdpa kernel
    - Add test
    - Add vectorized kernel for decoding
    - Update tests
    - Add some docs
    - Fix sdpa_vector names
    - Add softcapping for vectorized sdpa
    - Add softcapping for full sdpa
    - Add support for head dim 32, 96, 256
    - Add support for head dim 32, 96, 256
    - Update docs
    - Add update notice
    - Clippy and format
    - Conditional compilation for bf16
    - Use it in quantized llama
    - Some review comments
    - Use set_params!
    - Remove unused
    - Remove feature
    - Fix metal sdpa for v stride
    - Remove comma
    - Add the dim method to layout and shape.
    Co-authored-by: Laurent <laurent.mazare@gmail.com>
* Improved launch config for layer-norm/rms-norm. (#2591) (Laurent Mazare, 2024-11-04, 2 files, -4/+66)
    - Improved launch config for layer-norm/rms-norm.
    - Add more testing for the fused layer/rms norm kernels.
* Make the RNN configs accessible from the models. (#2541) (Laurent Mazare, 2024-10-04, 1 file, -72/+103)
* Add/lstm direction (#2455) (Justin Sing, 2024-09-30, 1 file, -8/+25)
    - add: direction for lstm layer
    - lint: remove unused Error import
    - refactor: remove unnecessary int assignment to Direction enum
    - refactor: use &'static str type instead of String for direction_str
    - Run cargofmt.
    Co-authored-by: Laurent <laurent.mazare@gmail.com>
* Add Pixtral. (#2521) (Laurent Mazare, 2024-09-30, 1 file, -14/+22)
    - Add Pixtral.
    - More pixtral vision encoder.
    - Sketch a pixtral example.
    - Sketch a pixtral example.
    - Better image loading.
    - Support loading images embedded in safetensor files.
    - Clippy fixes.
    - Add the llava multimodal adapter.
    - Add more of the llava bits.
    - Add the pixtral config.
    - More pixtral inference.
    - Add the text generation bits.
    - Get the example to work.
    - Bugfix.
    - Run some bits of the model in f32.
    - Blessed version :)
    - Better rope frequency computations.
    - README update.
* Add a RotatingKVCache. (#2493) (Laurent Mazare, 2024-09-23, 2 files, -1/+333)
    - Add a RotatingKVCache.
    - Add some KvCache tests.
    - Test the reset too.
    - More kv-cache testing.
    - More tests for the rotating kv-cache.
    - Improve the api for the rotating cache so that the whole src tensor gets returned when it's overlarge.
    - Handle contiguity + bugfix + use in mimi.
    - Add a way to test the mimi streaming mode.
    - Mimi streaming fixes.
    - More rotating kv-cache.
    - Fix the attn mask generation.
    - Handle the abs case.
    - Add some tests for the generated mask.
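The rotating kv-cache above bounds memory by overwriting the oldest entries once the cache is full. A minimal plain-Rust sketch of the ring-buffer idea (the `RotatingCache` type and its methods here are illustrative only; the actual candle-nn API operates on tensors and has a different signature):

```rust
// Minimal ring-buffer cache sketch: once `max_len` entries have been
// appended, new entries overwrite the oldest ones in place.
struct RotatingCache {
    data: Vec<f32>,
    max_len: usize,
    next: usize,   // next write position
    filled: usize, // number of valid entries (saturates at max_len)
}

impl RotatingCache {
    fn new(max_len: usize) -> Self {
        Self { data: vec![0.0; max_len], max_len, next: 0, filled: 0 }
    }

    fn append(&mut self, v: f32) {
        self.data[self.next] = v;
        self.next = (self.next + 1) % self.max_len;
        self.filled = (self.filled + 1).min(self.max_len);
    }

    // Return the cached entries in insertion order, oldest first.
    fn ordered(&self) -> Vec<f32> {
        if self.filled < self.max_len {
            self.data[..self.filled].to_vec()
        } else {
            let mut out = self.data[self.next..].to_vec();
            out.extend_from_slice(&self.data[..self.next]);
            out
        }
    }
}

fn main() {
    let mut c = RotatingCache::new(3);
    for v in [1.0, 2.0, 3.0, 4.0, 5.0] {
        c.append(v);
    }
    // Capacity 3, so only the last three values survive.
    println!("{:?}", c.ordered()); // [3.0, 4.0, 5.0]
}
```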
* onnx: implement LSTM op (#2268) (shua, 2024-08-19, 1 file, -0/+4)
    - use candle-nn LSTM
* update: LSTMState and GRUState fields to be public (#2384) (Justin Sing, 2024-08-01, 1 file, -3/+3)
* Add support for Llama 3.1 (#2359) (Eric Buehler, 2024-07-26, 2 files, -3/+4)
    - Add Llama 3.1 rope
    - Clippy
    - Format
    - Clippy
    - Add support for multiple eos tokens
    - Untagged either
    - Remove either dep and fix settings.json
    - Make the max positional embeddings configurable
* Depth Anything v2 (#2279) (Jeroen Vlek, 2024-06-24, 1 file, -1/+22)
    - define structs
    - construct ResidualConvUnit
    - forward() for ResidualConvUnit
    - implement FeatureFusionBlock
    - implement Scratch
    - implement DPTHead
    - add identity module
    - implement forward for DPTHead
    - add get_intermediate_layers to DinoVisionTransformer
    - implement DepthAnythingV2
    - some minor tweaks
    - fix compile errors
    - fix var builder prefixes
    - setup initial example
    - use fixed patch size of 37 (518 / 14)
    - debugged until output
    - print min and max values
    - add some dynamism to the output location
    - scale input image
    - extract prep function
    - extract output path function
    - normalize image with magic mean and std
    - add spectral coloring
    - squeeze in the right place
    - make interpolation optional
    - use bail instead of panic
    - omit unnecessary Shape call
    - remove empty curly braces
    - use bail instead of assert
    - use vb and pp
    - remove closures
    - extract config object
    - Apply rustfmt.
    - Fix some clippy lints.
    - More lints.
    - Use the array methods.
    Co-authored-by: laurent <laurent.mazare@gmail.com>
* add where_cond f32 for metal (#2236) (Lionel Touati, 2024-06-02, 1 file, -1/+1)
* Enable the new layer-norm. (#2213) (Laurent Mazare, 2024-05-24, 2 files, -5/+19)
    - Enable the new layer-norm.
    - Shape fixes.
* Add the layernorm specialized op. (#2212) (Laurent Mazare, 2024-05-24, 2 files, -5/+280)
    - Add the layernorm cuda kernels.
    - Dedicated layer norm op.
    - Add the slower variant.
    - Plug the cuda implementation.
    - Add the metal variant.
    - Add a dedicated test.
    - Bugfix.
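For reference, the computation the specialized layernorm op fuses is the standard one: normalize by mean and variance, then scale and shift. A plain-Rust sketch of the per-row math (the candle kernels operate on whole tensors; the function here is illustrative):

```rust
// Layer norm over one row: subtract the mean, divide by the std-dev
// (with eps for stability), then apply the learned scale w and bias b.
fn layer_norm(x: &[f32], w: &[f32], b: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / n;
    let inv = 1.0 / (var + eps).sqrt();
    x.iter()
        .zip(w.iter().zip(b))
        .map(|(v, (wi, bi))| (v - mean) * inv * wi + bi)
        .collect()
}

fn main() {
    // mean = 2, var = 2/3; the middle value normalizes to exactly 0.
    let y = layer_norm(&[1.0, 2.0, 3.0], &[1.0, 1.0, 1.0], &[0.0, 0.0, 0.0], 0.0);
    println!("{:?}", y);
}
```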
* Simplify the KvCache api. (#2207) (Laurent Mazare, 2024-05-23, 1 file, -36/+53)
* Add a couple kv-cache helper functions. (#2206) (Laurent Mazare, 2024-05-23, 1 file, -0/+29)
* Add a slice_set op. (#2193) (Laurent Mazare, 2024-05-18, 2 files, -0/+102)
    - Add a slice_set op.
    - Add some testing.
    - Add the dedicated kv-cache module.
    - Derive debug and clone.
    - Expose more kv-cache functions.
    - Return the current data when appending.
    - Use the new cache in the quantized phi3 model.
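The slice_set op writes a smaller tensor into a region of a larger pre-allocated one, which is what lets a kv-cache append new keys/values without reallocating. A plain-Rust sketch of the semantics on a row-major 2D buffer (function name and signature are illustrative, not the candle API):

```rust
// Copy `src` (src_rows x cols) into `dst` (dst_rows x cols) starting at
// row `offset`, mirroring the semantics of a slice_set along dim 0.
fn slice_set(dst: &mut [f32], cols: usize, src: &[f32], offset: usize) {
    let start = offset * cols;
    dst[start..start + src.len()].copy_from_slice(src);
}

fn main() {
    let cols = 2;
    let mut dst = vec![0.0f32; 4 * cols]; // 4 x 2, zero-initialized
    let src = vec![1.0, 2.0, 3.0, 4.0];   // 2 x 2
    slice_set(&mut dst, cols, &src, 1);   // write into rows 1..3
    println!("{:?}", dst); // [0.0, 0.0, 1.0, 2.0, 3.0, 4.0, 0.0, 0.0]
}
```

A cache built this way keeps a write offset and bumps it by `src_rows` after each append, so each decode step is a single contiguous copy.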
* Fix VarBuilder::from_slice_safetensors (#2180) (Harry Stern, 2024-05-12, 1 file, -4/+30)
    - Also implement SimpleBackend for SliceSafetensors
    Signed-off-by: Harry Stern <harry@harrystern.net>
* Add SliceSafetensors. (#2179) (Laurent Mazare, 2024-05-11, 1 file, -0/+6)
    - Add SlicedSafetensors.
    - And add some testing.
* Bump the version number to 0.5.1. (#2155) (Laurent Mazare, 2024-05-03, 1 file, -1/+1)
    - Bump the version number to 0.5.1.
    - Fix clippy lints for 1.78.
    - More clippy fixes.
* Bug Fix: When converting a tensor to a variable, clone if the tensor is already a variable. (#2124) (Jeffrey Dallatezza, 2024-04-29, 1 file, -2/+44)
    - When converting a tensor to a variable, clone if the tensor is already a variable.
    - Add a test to ensure training a batch norm works with VarMaps
    Co-authored-by: Jeffrey Dallatezza <jeffreydallatezza@Jeffreys-Laptop.local>
* Fix sigmoid gradient calculation and move sigmoid into a specialized op (#2114) (MilkFather, 2024-04-29, 2 files, -2/+197)
    - add sigmoid op
    - small fix
    - add as a method on `Tensor`
    - implement gradient calculation for sigmoid
    - add sigmoid tests
    - we should have a specialized op for this
    - fix clippy
    - fix clippy 2
    - Revert all previous commits in favor of a `CustomOp` based solution
    - use `CustomOp1` implementation
    - fix rustfmt
    - experimental add metal impl
    - add cuda kernel impl
    - fix fmt
    - Add a test + reduce some cuda duplication.
    Co-authored-by: laurent <laurent.mazare@gmail.com>
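The gradient fix above relies on the identity that makes a specialized sigmoid op attractive: the derivative can be computed from the forward output alone, sigma'(x) = sigma(x) * (1 - sigma(x)), so the backward pass needs no extra exp. A small self-contained check in plain Rust:

```rust
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

// Closed-form gradient: s * (1 - s), reusing the forward output
// instead of recomputing the exponential.
fn sigmoid_grad(x: f64) -> f64 {
    let s = sigmoid(x);
    s * (1.0 - s)
}

fn main() {
    // Sanity-check the closed form against a central finite difference.
    let x = 0.7;
    let h = 1e-6;
    let numeric = (sigmoid(x + h) - sigmoid(x - h)) / (2.0 * h);
    assert!((sigmoid_grad(x) - numeric).abs() < 1e-8);
    println!("sigmoid'(0) = {}", sigmoid_grad(0.0)); // 0.25
}
```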
* Apply the cast before the scaling. (#2135) (Laurent Mazare, 2024-04-28, 1 file, -1/+1)
* Use the faster rms-norm kernel for llama. (#2107) (Laurent Mazare, 2024-04-22, 1 file, -4/+13)
    - Use the faster rms-norm kernel for llama.
    - Use the fast variant by default.
* Add a convenient way to rename tensors accessed through a varbuilder. (#2052) (Laurent Mazare, 2024-04-13, 1 file, -0/+93)
* Add the rope THD kernel. (#2014) (Laurent Mazare, 2024-04-05, 2 files, -0/+262)
    - Add the rope THD kernel.
    - Cuda kernel for rope-thd.
    - Add the metal kernels.
    - Add a dedicated test.
* Relax the contiguous check for cuda kernels. (#2000) (Laurent Mazare, 2024-04-03, 1 file, -1/+1)
    - Relax the contiguous check for cuda kernels.
    - Ensure contiguity for RNNs.
    - Unrelated fix for segment anything.
    - Better error message + allow concatenating empty slices.
* Add benchmarks for the candle-nn package (#1995) (Thomas Santerre, 2024-04-03, 5 files, -0/+175)
    - add benchmarks for the candle-nn package
    - uncomment test
    - format
* Add fn 'get_with_hints_dtype' in VarBuilder (#1877) (#1897) (yinqiwen, 2024-04-01, 1 file, -4/+15)
    - quantized models (awq/squeezellm/...) have multiple data type tensors, use 'get_with_hints_dtype' to load tensors with given dtype
* Fix detail in new RoPE implementation (#1935) (Hugo Abonizio, 2024-03-25, 1 file, -1/+1)
* Contiguous variant of the rope kernel. (#1929) (Laurent Mazare, 2024-03-25, 2 files, -2/+282)
    - Contiguous variant of the rope kernel.
    - Add the cuda kernel.
    - Metal kernel.
* Fast kernels for rotary embeddings. (#1928) (Laurent Mazare, 2024-03-24, 4 files, -0/+277)
    - Fast kernels for rotary embeddings.
    - Add a test for the fast CPU kernel.
    - Rope cuda bindings.
    - Cuda kernel.
    - Metal kernel (part 1).
    - Cuda kernels.
    - Finish the metal kernel.
    - Use the new kernels in the quantized example.
    - Fix warning.
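The math these rotary-embedding kernels accelerate: each consecutive pair of a head vector is rotated by a position-dependent angle theta_i = pos * base^(-2i/dim). A plain-Rust sketch of the interleaved-pair layout (the candle kernels work on batched tensors and also support other layouts such as rope-thd; this single-vector function is illustrative):

```rust
// Rotate consecutive pairs (x[2i], x[2i+1]) of one head vector by
// theta_i = pos * base^(-2i / dim), the standard RoPE rotation.
fn rope(x: &mut [f32], pos: usize, base: f32) {
    let dim = x.len();
    for i in 0..dim / 2 {
        let theta = pos as f32 * base.powf(-2.0 * i as f32 / dim as f32);
        let (sin, cos) = theta.sin_cos();
        let (a, b) = (x[2 * i], x[2 * i + 1]);
        x[2 * i] = a * cos - b * sin;
        x[2 * i + 1] = a * sin + b * cos;
    }
}

fn main() {
    // At position 0 every angle is zero, so the vector is unchanged.
    let mut x = vec![1.0f32, 2.0, 3.0, 4.0];
    rope(&mut x, 0, 10_000.0);
    println!("{:?}", x); // [1.0, 2.0, 3.0, 4.0]
}
```

Because each rotation is a 2x2 multiply on adjacent elements, a fused kernel can apply it in place without materializing the sin/cos tables per call.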
* RmsNorm kernel for metal. (#1895) (Laurent Mazare, 2024-03-21, 1 file, -1/+46)
    - RmsNorm kernel for metal.
    - Wrapper for the metal kernel.
    - Get the ops to actually work.
    - Fix, get the tests to pass.
* Custom op for RmsNorm (#1890) (Laurent Mazare, 2024-03-21, 2 files, -8/+197)
    - Trying out a custom RmsNorm cuda kernel.
    - CPU implementation for rms-norm.
    - Cuda wrappers.
    - Add some validation.
    - Add some testing.
    - More testing.
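Unlike layer norm, rms-norm skips the mean subtraction and bias: each row is scaled by the reciprocal root-mean-square, y_i = x_i / sqrt(mean(x^2) + eps) * w_i. A plain-Rust sketch of the per-row computation the custom op fuses (illustrative, not the candle kernel itself):

```rust
// RMS norm over one row: y_i = x_i / sqrt(mean(x^2) + eps) * w_i.
fn rms_norm(x: &[f32], w: &[f32], eps: f32) -> Vec<f32> {
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv = 1.0 / (mean_sq + eps).sqrt();
    x.iter().zip(w).map(|(v, wi)| v * inv * wi).collect()
}

fn main() {
    // mean(x^2) = (1 + 4 + 4) / 3 = 3, so each value is divided by sqrt(3).
    let y = rms_norm(&[1.0, 2.0, 2.0], &[1.0, 1.0, 1.0], 0.0);
    println!("{:?}", y);
}
```

Needing only one reduction per row (sum of squares, no mean pass) is what makes rms-norm slightly cheaper to fuse than layer norm.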
* Optimize the cat operation on contiguous tensors (#1855) (Laurent Mazare, 2024-03-17, 1 file, -0/+19)
    - Add a specialized kernel for copy2d.
    - Move the cat operations.
    - Avoid transpositions in cat.
    - Bugfix.
    - Bugfix for the cuda kernel.
    - Add a benchmark.
    - Add more testing.
    - Test fix.
    - Faster kernel.
    - Add the missing kernel.
    - Tweak the test.
    - Add a metal kernel.
    - Fix for the metal kernel.
    - Get the tests to pass on metal.
    - Also use this opportunity to fix the metal kernel for ELU.
    - Add some bf16 kernels.
    - Clippy fixes.
* add clone to candle dropout (#1814) (Kirpal Grewal, 2024-03-08, 1 file, -1/+1)
* Improve metal buffer usage (#1807) (ivarflakstad, 2024-03-07, 1 file, -1/+2)
    - Improve metal buffer usage
    - Clone cpu storage when loading to reduce wait_until_complete calls
    - Use powers of two for buffer sizes so reuse is more likely.
    - Select best available buffer by size.
    - Add count to MetalStorage -> can use buffer with different size
    - Simplify new buffer creation without blit copy. Revert &[] -> Vec
    - Add documentation on newBufferWithBytes safety / synchronization
    - Drop unused buffers after command buffer is done syncing.
    Co-authored-by: Chris Fleetwood <christopher.fleetwood@huggingface.co>
* Add the StarCoder2 model. (#1779) (Laurent Mazare, 2024-02-28, 1 file, -0/+4)
    - Add the StarCoder2 model.
    - Add the example code and get things to work.
    - And also tweak the readme.
* Encodec model. (#1771) (Laurent Mazare, 2024-02-27, 1 file, -1/+1)
    - Encodec model.
    - Fixes.
    - Add the padding functions.
    - Get the LSTM bit to work.
    - Get the encodec model to generate some tokens (decoder only for now).
    - Minor tweak.
    - Minor tweak.
* Tweak the VarMap set type. (#1758) (Laurent Mazare, 2024-02-25, 2 files, -2/+39)
* Support for attention bias in gemma + refactor things a bit. (#1744) (Laurent Mazare, 2024-02-22, 2 files, -6/+19)
    - Support for attention bias in gemma + refactor things a bit.
    - Fix the cuda tests.
* Bugfix for applying the bias in conv1d-transpose. (#1732) (Laurent Mazare, 2024-02-18, 1 file, -1/+1)
* Support for groups in conv-transpose1d. (#1731) (Laurent Mazare, 2024-02-18, 1 file, -3/+13)
    - Groups support in conv-transpose-1d.
    - Remove dangling file.
* Expose the weights and biases in transposed convolutions. (#1727) (Laurent Mazare, 2024-02-18, 1 file, -0/+16)
* Expose more conv1d functions/structs. (#1726) (Laurent Mazare, 2024-02-17, 2 files, -2/+19)