path: root/candle-pyo3/test_pytorch.py
author Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> 2024-11-05 03:28:00 -0500
committer GitHub <noreply@github.com> 2024-11-05 09:28:00 +0100
commit e2b6b367fa852ed30ac532f8d77cd8479c7ed092 (patch)
tree 41321e646a0ee9abef88122b202bd940240ecae6 /candle-pyo3/test_pytorch.py
parent 6454597943599dd6df787a0d5f2446c5724d850a (diff)
Add some fast Metal MLX SDPA kernels (#2584)
* Add some fast Metal MLX SDPA kernels (#32)
* Sketch the sdpa kernel
* Add full sdpa kernel,
* Add test
* Add vectorized kernel for decoding
* Update tests
* Add some docs
* Fix sdpa_vector names
* Add softcapping for vectorized sdpa
* Add softcapping for full sdpa
* Add support for head dim 32, 96, 256
* Add support for head dim 32, 96, 256
* Update docs
* Add update notice
* Clippy and format
* Conditional compilation for bf16
* Use it in quantized llama
* Some review comments
* Use set_params!
* Remove unused
* Remove feature
* Fix metal sdpa for v stride
* Remove comma
* Add the dim method to layout and shape.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-pyo3/test_pytorch.py')
0 files changed, 0 insertions, 0 deletions