Diffstat (limited to 'candle-flash-attn/kernels/flash_fwd_launch_template.h')
 candle-flash-attn/kernels/flash_fwd_launch_template.h | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)
diff --git a/candle-flash-attn/kernels/flash_fwd_launch_template.h b/candle-flash-attn/kernels/flash_fwd_launch_template.h
index 9e5449d7..bb581eb3 100644
--- a/candle-flash-attn/kernels/flash_fwd_launch_template.h
+++ b/candle-flash-attn/kernels/flash_fwd_launch_template.h
@@ -3,11 +3,11 @@
******************************************************************************/
#pragma once
-
-// #include <ATen/cuda/CUDAContext.h>
+// #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
#include "error.h"
#include "static_switch.h"
+#include "hardware_info.h"
#include "flash.h"
#include "flash_fwd_kernel.h"
@@ -74,7 +74,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
// If return_softmax, set IsEvenMNConst to false to reduce number of templates
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// If Is_local, set Is_causal to false
- auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout>;
+ auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
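
The hunk above ties dropout to the absence of softcapping: the kernel is now instantiated with Is_dropout && !Is_softcap (and the softmax-return flag gets the same guard), so a dropout-plus-softcap variant is never compiled. A minimal sketch of that gating pattern, using hypothetical names rather than the real flash_fwd_kernel signature:

#include <cuda_runtime.h>
#include <math.h>

// Toy kernel with two compile-time feature flags. When Softcap is true,
// Dropout is forced false at the launch site, so the Dropout==true &&
// Softcap==true instantiation is never emitted.
template <bool Dropout, bool Softcap>
__global__ void toy_kernel(float *out, const float *in, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
    float v = in[i];
    if constexpr (Softcap) { v = tanhf(v); }   // stand-in for softcapping
    if constexpr (Dropout) { v *= 0.5f; }      // stand-in for dropout scaling
    out[i] = v;
}

template <bool Dropout, bool Softcap>
void launch_toy(float *out, const float *in, int n, cudaStream_t stream) {
    // Mirror of the edited line: dropout is only honoured when softcap is off.
    auto kernel = &toy_kernel<Dropout && !Softcap, Softcap>;
    kernel<<<(n + 255) / 256, 256, 0, stream>>>(out, in, n);
}

The same idea scales to the real kernel: every boolean that can be pinned to false, as the surrounding comments note, halves the number of template instantiations.
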
@@ -205,7 +205,8 @@ inline bool cuda_is_sm8x() {
template<typename T, bool Is_causal>
void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 96;
- bool is_sm8x = cuda_is_sm8x();
+ auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
+ bool is_sm8x = cc_major == 8 && cc_minor > 0;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
if (is_sm8x) {
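
run_mha_fwd_hdim96 now asks the device for its compute capability directly instead of going through cuda_is_sm8x(). The helpers come from the newly included hardware_info.h, which this diff does not show; a plausible sketch, assuming they wrap the standard CUDA attribute queries:

#include <tuple>
#include <cuda_runtime.h>

// Sketch only: the real definitions live in hardware_info.h.
inline int get_current_device() {
    int device = 0;
    cudaGetDevice(&device);  // error checking elided for brevity
    return device;
}

inline std::tuple<int, int> get_compute_capability(int device) {
    int major = 0, minor = 0;
    cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
    cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device);
    return {major, minor};
}

With the pair unpacked, cc_major == 8 && cc_minor > 0 matches sm86 and sm89 but not sm80, the same set the old helper covered.
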
@@ -228,7 +229,8 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T, bool Is_causal>
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
- bool is_sm8x = cuda_is_sm8x();
+ auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
+ bool is_sm8x = cc_major == 8 && cc_minor > 0;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if constexpr(!Is_dropout) {
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
@@ -262,7 +264,8 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T, bool Is_causal>
void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 160;
- bool is_sm8x = cuda_is_sm8x();
+ auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
+ bool is_sm8x = cc_major == 8 && cc_minor > 0;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// For A100, H100, 128 x 32 is the fastest.
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
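
All three head-dimension launchers dispatch through DROPOUT_SWITCH from static_switch.h, which lifts the runtime test params.p_dropout < 1.f into a compile-time Is_dropout constant inside the lambda. A sketch of the usual BOOL_SWITCH-style expansion (the real macro body may differ):

// Sketch of a BOOL_SWITCH-style macro: evaluates COND at runtime and invokes
// the lambda with CONST_NAME bound to a constexpr bool, so code templated on
// CONST_NAME is compiled for both values and selected at runtime.
#define DROPOUT_SWITCH_SKETCH(COND, CONST_NAME, ...)   \
    [&] {                                              \
        if (COND) {                                    \
            constexpr static bool CONST_NAME = true;   \
            return __VA_ARGS__();                      \
        } else {                                       \
            constexpr static bool CONST_NAME = false;  \
            return __VA_ARGS__();                      \
        }                                              \
    }()

Both branches are compiled, so run_flash_fwd exists for Is_dropout == true and false, and the runtime dropout probability merely picks which instantiation to launch.
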