summaryrefslogtreecommitdiff
path: root/candle-flash-attn/kernels/flash_api.cu
diff options
context:
space:
mode:
Diffstat (limited to 'candle-flash-attn/kernels/flash_api.cu')
-rw-r--r--candle-flash-attn/kernels/flash_api.cu16
1 files changed, 13 insertions, 3 deletions
diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu
index 4ca41b0a..00933419 100644
--- a/candle-flash-attn/kernels/flash_api.cu
+++ b/candle-flash-attn/kernels/flash_api.cu
@@ -55,7 +55,9 @@ extern "C" void run_mha(
int is_causal,
int window_size_left,
- int window_size_right
+ int window_size_right,
+
+ float softcap
) {
Flash_fwd_params params;
// Reset the parameters
@@ -99,8 +101,16 @@ extern "C" void run_mha(
params.d_rounded = d_rounded;
// Set the different scale values.
- params.scale_softmax = softmax_scale;
- params.scale_softmax_log2 = softmax_scale * M_LOG2E;
+ if (softcap > 0.0) {
+ params.softcap = softmax_scale / softcap;
+ params.scale_softmax = softcap;
+ params.scale_softmax_log2 = softcap * M_LOG2E;
+ } else{
+ // Remove potential NaN
+ params.softcap = 0.0;
+ params.scale_softmax = softmax_scale;
+ params.scale_softmax_log2 = softmax_scale * M_LOG2E;
+ }
params.p_dropout = 1.; // probability to keep
params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));