diff options
Diffstat (limited to 'candle-flash-attn/kernels/kernels.h')
-rw-r--r-- | candle-flash-attn/kernels/kernels.h | 58 |
1 files changed, 58 insertions, 0 deletions
diff --git a/candle-flash-attn/kernels/kernels.h b/candle-flash-attn/kernels/kernels.h
new file mode 100644
index 00000000..20d6605f
--- /dev/null
+++ b/candle-flash-attn/kernels/kernels.h
@@ -0,0 +1,58 @@
+#ifndef _GPU_OPS_KERNELS_H_  // include guard: the body below is skipped if this header was already included
+#define _GPU_OPS_KERNELS_H_
+
+#include <cuda_runtime_api.h>  // supplies cudaStream_t, used by the two declarations at the bottom
+
+#include <cstddef>  // std::size_t, the type of opaque_len
+#include <cstdint>
+
+#include<stdlib.h>  // NOTE(review): nothing visible in this header uses stdlib.h; confirm it is still needed
+#include<stdint.h>  // global-namespace uint32_t, which is how MHAParams spells its integer fields
+
+namespace gpu_ops {
+
+struct MHAParams {  // plain record of scalar settings; presumably multi-head-attention parameters (guess from the name, TODO confirm)
+    uint32_t q_batch_stride;  // NOTE(review): field order and widths presumably must match whoever fills this struct; verify before reordering
+    uint32_t k_batch_stride;  // the four *_batch_stride fields: one per q, k, v and o; units are not shown in this header
+    uint32_t v_batch_stride;
+    uint32_t o_batch_stride;
+
+    uint32_t q_row_stride;  // the four *_row_stride fields: one per q, k, v and o
+    uint32_t k_row_stride;
+    uint32_t v_row_stride;
+    uint32_t o_row_stride;
+
+    uint32_t q_head_stride;  // the four *_head_stride fields: one per q, k, v and o
+    uint32_t k_head_stride;
+    uint32_t v_head_stride;
+    uint32_t o_head_stride;
+
+    uint32_t b;  // presumably batch size (TODO confirm against the implementation)
+    uint32_t h;  // presumably number of heads (TODO confirm)
+    uint32_t h_k;  // presumably number of key heads (TODO confirm)
+    uint32_t d;  // presumably head dimension (TODO confirm)
+    uint32_t d_rounded;  // presumably d after some rounding; the rounding rule is not visible here
+    float softmax_scale;
+    float softcap;
+
+    uint32_t seqlen_q;
+    uint32_t seqlen_k;
+    uint32_t seqlen_q_rounded;  // rounded counterparts of seqlen_q and seqlen_k; the rounding rule is not visible here
+    uint32_t seqlen_k_rounded;
+
+    int window_size_left;  // signed, unlike the fields above; whether negative values carry a special meaning is not shown here
+    int window_size_right;
+
+    int is_causal;  // declared int rather than bool; presumably a 0 or 1 flag (confirm)
+    int is_bf16;  // declared int rather than bool; presumably a 0 or 1 flag (confirm)
+};
+
+void run_mha_fwd_j(cudaStream_t stream, void **buffers,  // declaration only; presumably the forward-pass entry point, going by fwd in the name
+                   const char *opaque,  // byte pointer paired with opaque_len; presumably a serialized MHAParams (TODO confirm)
+                   std::size_t opaque_len);
+void run_mha_bwd_j(cudaStream_t stream, void **buffers,  // declaration only; same parameter list as run_mha_fwd_j, presumably the backward pass
+                   const char *opaque,
+                   std::size_t opaque_len);
+}  // closes namespace gpu_ops
+
+#endif  // closes the _GPU_OPS_KERNELS_H_ include guard
|