diff options
Diffstat (limited to 'candle-flash-attn/kernels/kernel_helpers.h')
-rw-r--r-- | candle-flash-attn/kernels/kernel_helpers.h | 50 |
1 files changed, 50 insertions, 0 deletions
// This header is not specific to our application and you'll probably want
// something like this for any extension you're building. This includes the
// infrastructure needed to serialize descriptors that are used with the
// "opaque" parameter of the GPU custom call. In our example we'll use this
// parameter to pass the size of our problem.

#ifndef _GPU_OPS_KERNEL_HELPERS_H_
#define _GPU_OPS_KERNEL_HELPERS_H_

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <string>
#include <type_traits>

#define JAX_APEX_WARP_SIZE 32

namespace gpu_ops {

// Reinterprets the object representation of `src` as a value of type `To`.
//
// Pre-C++20 stand-in for std::bit_cast, see
// https://en.cppreference.com/w/cpp/numeric/bit_cast
//
// Only participates in overload resolution when `To` and `From` have the same
// size and are both trivially copyable. `To` must additionally be trivially
// constructible because the result is first default-initialized and then
// overwritten byte by byte.
template <class To, class From>
typename std::enable_if<sizeof(To) == sizeof(From) &&
                            std::is_trivially_copyable<From>::value &&
                            std::is_trivially_copyable<To>::value,
                        To>::type
bit_cast(const From &src) noexcept {
  static_assert(std::is_trivially_constructible<To>::value,
                "This implementation additionally requires destination type to "
                "be trivially constructible");

  To dst;
  // Qualified call: <cstring> only guarantees the std:: declaration.
  std::memcpy(&dst, &src, sizeof(To));
  return dst;
}

// Serializes `descriptor` into a std::string holding exactly sizeof(T) bytes:
// a raw copy of the descriptor's object representation (padding included).
// The string owns its own copy, so `descriptor` need not outlive the result.
template <typename T> std::string PackDescriptorAsString(const T &descriptor) {
  // The pointer itself is bit_cast to `const char *`; the bytes it points at
  // are then copied by the std::string(ptr, count) constructor.
  return std::string(bit_cast<const char *>(&descriptor), sizeof(T));
}

// Views the `opaque` byte buffer as a descriptor of type T.
//
// Throws std::runtime_error when `opaque_len` differs from sizeof(T).
//
// No bytes are copied: the returned pointer has the same address as `opaque`,
// so the buffer must stay alive for as long as the result is used. The pointer
// value is not otherwise validated (a null `opaque` yields a null result).
// NOTE(review): the result is only as aligned as `opaque` is; confirm callers
// hand in storage suitably aligned for T.
template <typename T>
const T *UnpackDescriptor(const char *opaque, std::size_t opaque_len) {
  if (opaque_len != sizeof(T)) {
    throw std::runtime_error("Invalid opaque object size");
  }
  return bit_cast<const T *>(opaque);
}

} // namespace gpu_ops

#endif