From 2ce5f12513d0dafb04c7e345da9d4fba566cfa16 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Wed, 26 Jul 2023 14:16:37 +0100
Subject: Again set a few extra params in flash-attn. (#245)

* Again set a few extra params.

* Use the appropriate kernel sizes.

* Add all the kernel sizes.

* Parallel compiling.

* Reduce the amount of parallelism.

* Add the missing kernel.

* Fix a typo.

* Remove bf16 support for now.
---
 candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu | 9 +++++++++
 1 file changed, 9 insertions(+)
 create mode 100644 candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu

diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu
new file mode 100644
index 00000000..982fe7ea
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+template<> void run_mha_fwd_<cutlass::bfloat16_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim224<cutlass::bfloat16_t>(params, stream);
+}
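For context, this pattern gives each (head dimension, dtype) pair its own
translation unit so that nvcc can compile the kernels in parallel, which is
what the "Parallel compiling" / "Reduce the amount of parallelism" commits in
this PR tune. As an illustration, a matching fp16 file would follow the same
shape; the file name and contents below are a sketch inferred from the bf16
file in this diff, not part of the change itself:

// flash_fwd_hdim224_fp16_sm80.cu -- hypothetical sibling of the bf16 file
// above, following the same one-specialization-per-file split.

#include "flash_fwd_launch_template.h"

// Explicit specialization for fp16 (cutlass::half_t) at head dim 224; the
// shared launch template header instantiates the actual forward kernels.
template<> void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim224<cutlass::half_t>(params, stream);
}

Keeping one explicit specialization per .cu file bounds how much work any
single nvcc invocation does, so the build can fan the files out across cores.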