path: root/candle-metal-kernels
Commit message | Author | Date | Files | Lines (-/+)
* Sync upstream MLX sdpa vector kernels with mask (#2718) [HEAD, main] | Eric Buehler | 2025-01-16 | 2 | -28/+412
    * Sync upstream mlx sdpa vector kernels with mask
    * Dispatch to the 2pass kernel
    * Format
* Bump the caret version to 0.8.2. (#2703) | Laurent Mazare | 2025-01-07 | 1 | -1/+1
* Bump the crate version to 0.8.1. (#2662) | Laurent Mazare | 2024-12-07 | 1 | -1/+1
* add scatter add (#2656) | zachcp | 2024-12-01 | 1 | -0/+1
* add u32 - U32 gather (#2653) | zachcp | 2024-11-30 | 1 | -79/+80
* Lint fixes introduced with Rust 1.83 (#2646) | Anubhab Bandyopadhyay | 2024-11-28 | 2 | -17/+20
    * Fixes for lint errors introduced with Rust 1.83
    * rustfmt
    * Fix more lints.
    ---------
    Co-authored-by: Laurent <laurent.mazare@gmail.com>
* Add some missing index-select metal kernels. (#2613) | Laurent Mazare | 2024-11-12 | 1 | -0/+4
    * Add some missing index-select metal kernels.
    * Make some matrix contiguous pre-matmul.
* Bump the crate version to 0.8.0. (#2612) | Laurent Mazare | 2024-11-12 | 1 | -1/+1
* Add some fast Metal MLX SDPA kernels (#2584) | Eric Buehler | 2024-11-05 | 2 | -1/+1579
    * Add some fast Metal MLX SDPA kernels (#32)
    * Sketch the sdpa kernel
    * Add full sdpa kernel,
    * Add test
    * Add vectorized kernel for decoding
    * Update tests
    * Add some docs
    * Fix sdpa_vector names
    * Add softcapping for vectorized sdpa
    * Add softcapping for full sdpa
    * Add support for head dim 32, 96, 256
    * Add support for head dim 32, 96, 256
    * Update docs
    * Add update notice
    * Clippy and format
    * Conditional compilation for bf16
    * Use it in quantized llama
    * Some review comments
    * Use set_params!
    * Remove unused
    * Remove feature
    * Fix metal sdpa for v stride
    * Remove comma
    * Add the dim method to layout and shape.
    ---------
    Co-authored-by: Laurent <laurent.mazare@gmail.com>
* UG metal integration. (#2580) | Laurent Mazare | 2024-10-27 | 2 | -8/+4
* Tweak some metal tests. (#2528) | Laurent Mazare | 2024-10-02 | 2 | -62/+23
* Efficient implementation of `Tensor::ones()` for `metal` (#2512) | Anubhab Bandyopadhyay | 2024-10-01 | 3 | -0/+132
    * WIP: hopefully better const impl
    * with GPU
    * More tests on
    * Reverting primitive for
    * Incorporating review changes - added check elem count check in kerner, using for call strategy
    * rustfmt ran
* Bump the crate version to 0.7.2. (#2517) | Laurent Mazare | 2024-09-29 | 1 | -1/+1
* Move the candle version to 0.7.1. (#2495) | Laurent Mazare | 2024-09-22 | 1 | -1/+1
* Bump the crate version. (#2491) | Laurent Mazare | 2024-09-21 | 1 | -1/+1
* Bugfix for the metal elu kernel. (#2490) | Laurent Mazare | 2024-09-21 | 1 | -1/+1
    * Bugfix for the metal elu kernel.
    * Add a test.
* Metal commands refactoring (#2489) | Laurent Mazare | 2024-09-21 | 1 | -5/+28
    * Split out the commands part of the metal device.
    * Make most fields private.
    * Move the allocator back.
    * Rework the encoder provider type.
* Fix for metal tanh. (#2475) | Laurent Mazare | 2024-09-13 | 1 | -3/+8
* Add some metal gemm benchark. (#2471) | Laurent Mazare | 2024-09-11 | 2 | -0/+138
    * Add some metal gemm benchark.
    * More benchmarks.
* Integrate the MLX gemm kernels (#2468) | Laurent Mazare | 2024-09-11 | 5 | -55/+1874
    * Include the MLX gemm kernels.
    * Clippy lints.
    * Export the gemm_f32 kernel.
    * Add the f16/bf16 variants.
    * Add the initial dispatch code.
    * More plugging of the mlx kernels.
    * Add a currently broken test.
    * Tweaks.
    * Bugfix + get the tests to pass.
    * Enable the gemm bf16 tests.
    * Add some randomized tests.
    * Update candle-metal-kernels/src/lib.rs
      Co-authored-by: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com>
    * More fixes.
    * More clippy fixes.
    ---------
    Co-authored-by: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com>
* Bump the version to 0.6.1. (#2438) | Laurent Mazare | 2024-08-22 | 1 | -1/+1
* Revert the bf16 gemm metal changes for now. (#2386) | Laurent Mazare | 2024-08-01 | 2 | -19/+21
* Metal bgemm min changes (#2364) | ivarflakstad | 2024-08-01 | 3 | -4/+76
    * Add updated mfa metallib
    * Add bgemm and tests
* Enable the affine kernel for u8/u32. (#2376) | Laurent Mazare | 2024-08-01 | 1 | -0/+2
* Use RAII for terminating the encoding. (#2353) | Laurent Mazare | 2024-07-24 | 2 | -61/+69
* Use a trait for the encoder provider (so that encoder can ultimately be reused). (#2352) | Laurent Mazare | 2024-07-24 | 2 | -120/+143
* Bump the crate version. (#2248) | Laurent Mazare | 2024-06-05 | 1 | -1/+1
* add where_cond f32 for metal (#2236) | Lionel Touati | 2024-06-02 | 1 | -0/+21
* Add a metal kernel for col2im1d. (#2214) | Laurent Mazare | 2024-05-25 | 2 | -1/+97
    * Add a metal kernel for col2im1d.
    * Enable the col2im variant.
    * Bugfix.
    * Revert the quantized tweak.
* Add the layernorm specialized op. (#2212) | Laurent Mazare | 2024-05-24 | 2 | -0/+144
    * Add the layernorm cuda kernels.
    * Dedicated layer norm op.
    * Add the slower variant.
    * Plug the cuda implementation.
    * Add the metal variant.
    * Add a dedicated test.
    * Bugfix.
* Add some missing where-cond kernels for metal. (#2203) | Laurent Mazare | 2024-05-22 | 1 | -14/+17
* Separate quantized phi-3 implementation. (#2157) | Laurent Mazare | 2024-05-04 | 1 | -1/+1
    * Separate quantized phi-3 implementation.
    * Integrate the quantized phi3 model.=
    * Small fixes, get the generation to work properly.
    * Keep the old llama implementation around.
    * Change the default.
* Bump the version number to 0.5.1. (#2155) | Laurent Mazare | 2024-05-03 | 1 | -1/+1
    * Bump the version number to 0.5.1.
    * Fix clippy lints for 1.78.
    * More clippy fixes.
* Fix sigmoid gradient calculation and move sigmoid into a specialized op (#2114) | MilkFather | 2024-04-29 | 2 | -1/+6
    * add sigmoid op
    * small fix
    * add as a method on `Tensor`
    * implement gradient calculation for sigmoid
    * add sigmoid tests
    * we should have a specialized op for this
    * fix clippy
    * fix clippy 2
    * Revert all previous commits in favor of a `CustomOp` based solution
    * use `CustomOp1` implementation
    * fix rustfmt
    * experimental add metal impl
    * add cuda kernel impl
    * fix fmt
    * Add a test + reduce some cuda duplication.
    ---------
    Co-authored-by: laurent <laurent.mazare@gmail.com>
* Add argsort. (#2132) | Laurent Mazare | 2024-04-27 | 3 | -0/+138
    * Add the argsort cuda kernels.
    * CPU version of arg-sort.
    * Hook the cuda kernel + rework the cpu bits.
    * Add some dedicated test.
    * Working cuda kernel.
    * Metal kernel.
    * Metal adjustments.
    * Bugfix.
    * Use the fast rope in qwen.
    * Rework the expert selection in qwen.
* Metal Unary: Add benchmarks and process kernels in a tile based fashion (#2056) | Thomas Santerre | 2024-04-21 | 2 | -37/+97
    * add basic unary bench for sqrt
    * process unary commands in tiles of 4
    * re-enable all benchmarks
    * rename helper to unary
    * modify approach to split up tiled and non-tiled operations
    * undo bench ignore for other tests
    * update tile size to 2
    * only perform the optimization on the contiguous even numbered element case
* Handle multiple dimensions in metal QMM + two fixes. (#2097) | Laurent Mazare | 2024-04-20 | 1 | -7/+8
* Add missing bfloat unary strided kernels and fix typo (#2058) | ivarflakstad | 2024-04-14 | 1 | -1/+1
* Support gather on bf16 for metal. (#2035) | Laurent Mazare | 2024-04-10 | 1 | -0/+3
* Use BufferOffset in metal backend ops. (#2029) | Laurent Mazare | 2024-04-08 | 2 | -128/+78
    * Use BufferOffset in the metal backend.
    * More BufferOffset usage.
    * Use in where-cond.
* Rework the buffer offset logic for metal kernels (#2028) | Laurent Mazare | 2024-04-07 | 3 | -247/+262
    * Move the metal kernels utils in a separate module.
    * Use the BufferOffset for unary ops.
    * Fix clippy lints.
    * Use the new BufferOffset.
    * Adapt the binary ops.
    * Affine.
    * More ops (powf, elu, cast).
* Optimize copy-2d for metal. (#2024) | Laurent Mazare | 2024-04-07 | 2 | -20/+57
    * Optimize copy-2d for metal.
    * Add a hacky stopping rule for moondream.
* Add the rope THD kernel. (#2014) | Laurent Mazare | 2024-04-05 | 2 | -4/+89
    * Add the rope THD kernel.
    * Cuda kernel for rope-thd.
    * Add the metal kernels.
    * Add a dedicated test.
* Add support for "sign" on tensors (#2012) | Thomas Santerre | 2024-04-04 | 2 | -1/+3
    * add the sign unary operator
    * remove uneeded import
    * remove uneeded import
    * undo formatting
    * undo formatting
    * remove unnecessary redefintion
    * allow gradient to flow through for sign and round
    * fix cpu ops to ensure that negzero and positive zero are handled properly
    * clippy fixes
    * Properly avoid gradient tracking.
    * Use a branchless version.
    ---------
    Co-authored-by: laurent <laurent.mazare@gmail.com>
* update dtypes checks for several metal operations (#2010) | Thomas Santerre | 2024-04-04 | 2 | -6/+20
* Bumping the version number to 0.5.0. (#2009) | Laurent Mazare | 2024-04-04 | 1 | -1/+1
* Minor cleanups in reduce.metal. (#2004) | Laurent Mazare | 2024-04-04 | 1 | -23/+1
* refactor to reduce the amount of code wrapped in template syntax (#2002) | Thomas Santerre | 2024-04-04 | 1 | -261/+368
* Fix for the RWKV models. (#1955) | Laurent Mazare | 2024-03-28 | 1 | -4/+4
    * Fix for the RWKV models.
    * More general fix + revert the rwkv hack.
    * Remove the old hack.
* Support i64 in index-select on metal. (#1951) | Laurent Mazare | 2024-03-27 | 1 | -1/+7
    * Support i64 in index-select on metal.
    * Add some testing of index-select for all dtypes.