summary | refs | log | tree | commit | diff
path: root/candle-flash-attn/kernels/flash_api.cu
diff options
context:
space:
mode:
author: Michael Feil <63565275+michaelfeil@users.noreply.github.com> 2024-12-31 09:41:23 +0100
committer: GitHub <noreply@github.com> 2024-12-31 09:41:23 +0100
commit: a594ef669ca5ed82c1f19d2230b4b3dc9cb46f43 (patch)
tree: 8647429f4c0ae7fddbae84a1936819f0c0172514 /candle-flash-attn/kernels/flash_api.cu
parent: 71cd6d55337b1541f602c1afffa6baf6dd75b09c (diff)
download: candle-a594ef669ca5ed82c1f19d2230b4b3dc9cb46f43.tar.gz
          candle-a594ef669ca5ed82c1f19d2230b4b3dc9cb46f43.tar.bz2
          candle-a594ef669ca5ed82c1f19d2230b4b3dc9cb46f43.zip
Flash-Attn upgrade / SoftCap Candle-FlashAttn [2/n] (#2689)
* update flash-attn v1
* restore: hdim224
* add 224 flash_fwd_template
* remove whitespace
* softcap is working, including test and api.
* make softcap test case better

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-flash-attn/kernels/flash_api.cu')
-rw-r--r-- candle-flash-attn/kernels/flash_api.cu 16
1 files changed, 13 insertions, 3 deletions
diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu
index 4ca41b0a..00933419 100644
--- a/candle-flash-attn/kernels/flash_api.cu
+++ b/candle-flash-attn/kernels/flash_api.cu
@@ -55,7 +55,9 @@ extern "C" void run_mha(
int is_causal,
int window_size_left,
- int window_size_right
+ int window_size_right,
+
+ float softcap
) {
Flash_fwd_params params;
// Reset the parameters
@@ -99,8 +101,16 @@ extern "C" void run_mha(
params.d_rounded = d_rounded;
// Set the different scale values.
- params.scale_softmax = softmax_scale;
- params.scale_softmax_log2 = softmax_scale * M_LOG2E;
+ if (softcap > 0.0) {
+ params.softcap = softmax_scale / softcap;
+ params.scale_softmax = softcap;
+ params.scale_softmax_log2 = softcap * M_LOG2E;
+ } else{
+ // Remove potential NaN
+ params.softcap = 0.0;
+ params.scale_softmax = softmax_scale;
+ params.scale_softmax_log2 = softmax_scale * M_LOG2E;
+ }
params.p_dropout = 1.; // probability to keep
params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));