diff options
Diffstat (limited to 'candle-flash-attn/kernels/flash_api.cu')
-rw-r--r-- | candle-flash-attn/kernels/flash_api.cu | 16 |
1 files changed, 13 insertions, 3 deletions
diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu index 4ca41b0a..00933419 100644 --- a/candle-flash-attn/kernels/flash_api.cu +++ b/candle-flash-attn/kernels/flash_api.cu @@ -55,7 +55,9 @@ extern "C" void run_mha( int is_causal, int window_size_left, - int window_size_right + int window_size_right, + + float softcap ) { Flash_fwd_params params; // Reset the parameters @@ -99,8 +101,16 @@ extern "C" void run_mha( params.d_rounded = d_rounded; // Set the different scale values. - params.scale_softmax = softmax_scale; - params.scale_softmax_log2 = softmax_scale * M_LOG2E; + if (softcap > 0.0) { + params.softcap = softmax_scale / softcap; + params.scale_softmax = softcap; + params.scale_softmax_log2 = softcap * M_LOG2E; + } else{ + // Remove potential NaN + params.softcap = 0.0; + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = softmax_scale * M_LOG2E; + } params.p_dropout = 1.; // probability to keep params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); |