diff options
author | Michael Feil <63565275+michaelfeil@users.noreply.github.com> | 2024-12-31 09:41:23 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-12-31 09:41:23 +0100 |
commit | a594ef669ca5ed82c1f19d2230b4b3dc9cb46f43 (patch) | |
tree | 8647429f4c0ae7fddbae84a1936819f0c0172514 /candle-flash-attn/kernels/flash_api.cu | |
parent | 71cd6d55337b1541f602c1afffa6baf6dd75b09c (diff) | |
download | candle-a594ef669ca5ed82c1f19d2230b4b3dc9cb46f43.tar.gz candle-a594ef669ca5ed82c1f19d2230b4b3dc9cb46f43.tar.bz2 candle-a594ef669ca5ed82c1f19d2230b4b3dc9cb46f43.zip |
Flash-Attn upgrade / SoftCap Candle-FlashAttn [2/n] (#2689)
* update flash-attn v1
* restore: hdim224
* add 224 flash_fwd_template
* remove whitespace
* softcap is working, including test and api.
* make softcap test case better
---------
Co-authored-by: laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-flash-attn/kernels/flash_api.cu')
-rw-r--r-- | candle-flash-attn/kernels/flash_api.cu | 16 |
1 files changed, 13 insertions, 3 deletions
diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu index 4ca41b0a..00933419 100644 --- a/candle-flash-attn/kernels/flash_api.cu +++ b/candle-flash-attn/kernels/flash_api.cu @@ -55,7 +55,9 @@ extern "C" void run_mha( int is_causal, int window_size_left, - int window_size_right + int window_size_right, + + float softcap ) { Flash_fwd_params params; // Reset the parameters @@ -99,8 +101,16 @@ extern "C" void run_mha( params.d_rounded = d_rounded; // Set the different scale values. - params.scale_softmax = softmax_scale; - params.scale_softmax_log2 = softmax_scale * M_LOG2E; + if (softcap > 0.0) { + params.softcap = softmax_scale / softcap; + params.scale_softmax = softcap; + params.scale_softmax_log2 = softcap * M_LOG2E; + } else{ + // Remove potential NaN + params.softcap = 0.0; + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = softmax_scale * M_LOG2E; + } params.p_dropout = 1.; // probability to keep params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); |