summary refs log tree commit diff
path: root/candle-metal-kernels/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-metal-kernels/src/lib.rs')
-rw-r--r--  candle-metal-kernels/src/lib.rs  211
1 file changed, 195 insertions, 16 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 743b9fe2..a595b2bd 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -11,33 +11,35 @@ pub use utils::BufferOffset;
use utils::{get_block_dims, linear_split, EncoderProvider};
const AFFINE: &str = include_str!("affine.metal");
-const INDEXING: &str = include_str!("indexing.metal");
-const UNARY: &str = include_str!("unary.metal");
const BINARY: &str = include_str!("binary.metal");
-const TERNARY: &str = include_str!("ternary.metal");
const CAST: &str = include_str!("cast.metal");
const CONV: &str = include_str!("conv.metal");
-const REDUCE: &str = include_str!("reduce.metal");
-const RANDOM: &str = include_str!("random.metal");
+const INDEXING: &str = include_str!("indexing.metal");
// Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
+const MLX_GEMM: &str = include_str!("mlx_gemm.metal");
const QUANTIZED: &str = include_str!("quantized.metal");
+const RANDOM: &str = include_str!("random.metal");
+const REDUCE: &str = include_str!("reduce.metal");
const SORT: &str = include_str!("sort.metal");
+const TERNARY: &str = include_str!("ternary.metal");
+const UNARY: &str = include_str!("unary.metal");
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
Affine,
- Indexing,
- Unary,
Binary,
- Ternary,
Cast,
- Reduce,
- Mfa,
Conv,
- Random,
+ Gemm,
+ Indexing,
+ Mfa,
Quantized,
+ Random,
+ Reduce,
Sort,
+ Ternary,
+ Unary,
}
pub mod copy2d {
@@ -191,16 +193,17 @@ impl Kernels {
fn get_library_source(&self, source: Source) -> &'static str {
match source {
Source::Affine => AFFINE,
- Source::Unary => UNARY,
Source::Binary => BINARY,
- Source::Ternary => TERNARY,
- Source::Indexing => INDEXING,
Source::Cast => CAST,
- Source::Reduce => REDUCE,
Source::Conv => CONV,
- Source::Random => RANDOM,
+ Source::Gemm => MLX_GEMM,
+ Source::Indexing => INDEXING,
Source::Quantized => QUANTIZED,
+ Source::Random => RANDOM,
+ Source::Reduce => REDUCE,
Source::Sort => SORT,
+ Source::Ternary => TERNARY,
+ Source::Unary => UNARY,
Source::Mfa => panic!("Invalid lib"),
}
}
@@ -2178,5 +2181,181 @@ pub fn call_arg_sort(
Ok(())
}
+#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
+pub enum GemmDType {
+ BF16,
+ F16,
+ F32,
+}
+
+#[allow(clippy::too_many_arguments)]
+pub fn call_mlx_gemm(
+ device: &Device,
+ ep: impl EncoderProvider,
+ kernels: &Kernels,
+ dtype: GemmDType,
+ (b, m, n, k): (usize, usize, usize, usize),
+ lhs_stride: &[usize],
+ lhs_offset: usize,
+ lhs_buffer: &Buffer,
+ rhs_stride: &[usize],
+ rhs_offset: usize,
+ rhs_buffer: &Buffer,
+ output: &Buffer,
+) -> Result<(), MetalKernelError> {
+ #[derive(Debug)]
+ #[repr(C)]
+ struct GemmParams {
+ m: i32,
+ n: i32,
+ k: i32,
+ lda: i32,
+ ldb: i32,
+ ldd: i32,
+ tiles_n: i32,
+ tiles_m: i32,
+ batch_stride_a: isize,
+ batch_stride_b: isize,
+ batch_stride_d: isize,
+ swizzle_log: i32,
+ gemm_k_iterations_aligned: i32,
+ batch_ndim: i32,
+ }
+ assert!(rhs_stride.len() >= 2);
+ assert!(lhs_stride.len() >= 2);
+ let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
+ let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
+ let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
+ let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
+ // lhs has shape b, m, k
+ // We also allow for the case where the stride on the minor dimension is not as expected but
+ // there is a single element.
+ let (lda, a_trans) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
+ (k as i32, false)
+ } else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) {
+ (m as i32, true)
+ } else {
+ return Err(MetalKernelError::MatMulNonContiguous {
+ lhs_stride: lhs_stride.to_vec(),
+ rhs_stride: rhs_stride.to_vec(),
+ mnk: (m, n, k),
+ })?;
+ };
+ // rhs has shape b, k, n
+ let (ldb, b_trans) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
+ (n as i32, false)
+ } else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) {
+ (k as i32, true)
+ } else {
+ return Err(MetalKernelError::MatMulNonContiguous {
+ lhs_stride: lhs_stride.to_vec(),
+ rhs_stride: rhs_stride.to_vec(),
+ mnk: (m, n, k),
+ })?;
+ };
+ let (bm, bn, bk, wn, wm) = (32, 32, 16, 2, 2);
+ // https://github.com/ml-explore/mlx/blob/02efb310cac667bc547d1b96f21596c221f84fe7/mlx/backend/metal/matmul.cpp#L422
+ let constants = Some(ConstantValues::new(vec![
+ (10, Value::Bool(/* has_batch */ b > 1)),
+ (100, Value::Bool(/* use_out_source */ false)),
+ (110, Value::Bool(/* do_axpby */ false)),
+ (200, Value::Bool(/* align_m */ m % bm == 0)),
+ (201, Value::Bool(/* align_n */ n % bn == 0)),
+ (202, Value::Bool(/* align_k */ k % bk == 0)),
+ (300, Value::Bool(/* do_gather */ false)),
+ ]));
+
+ let swizzle_log = 0;
+ let tile = 1 << swizzle_log;
+ let tn = n.div_ceil(bn);
+ let tm = m.div_ceil(bm);
+ let tn = tn * tile;
+ let tm = tm.div_ceil(tile);
+
+ let batch_stride_a = if lhs_stride.len() > 2 {
+ lhs_stride[lhs_stride.len() - 3]
+ } else {
+ m * k
+ };
+ let batch_stride_b = if rhs_stride.len() > 2 {
+ rhs_stride[rhs_stride.len() - 3]
+ } else {
+ n * k
+ };
+
+ let gemm_params = GemmParams {
+ m: m as i32,
+ n: n as i32,
+ k: k as i32,
+ lda,
+ ldb,
+ ldd: n as i32,
+ tiles_n: tn as i32,
+ tiles_m: tm as i32,
+ swizzle_log,
+ batch_stride_a: batch_stride_a as isize,
+ batch_stride_b: batch_stride_b as isize,
+ batch_stride_d: (m * n) as isize,
+ batch_ndim: 1i32,
+ gemm_k_iterations_aligned: (k / bk) as i32,
+ };
+ let batch_strides = [gemm_params.batch_stride_a, gemm_params.batch_stride_b];
+
+ // TODO(laurent): generate the name
+ // template [[host_name("gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]]
+ let name = match (dtype, a_trans, b_trans) {
+ (GemmDType::F32, false, false) => "gemm_nn_f32_f32_32_32_16_2_2",
+ (GemmDType::F32, true, false) => "gemm_tn_f32_f32_32_32_16_2_2",
+ (GemmDType::F32, false, true) => "gemm_nt_f32_f32_32_32_16_2_2",
+ (GemmDType::F32, true, true) => "gemm_tt_f32_f32_32_32_16_2_2",
+ (GemmDType::BF16, false, false) => "gemm_nn_bf16_bf16_32_32_16_2_2",
+ (GemmDType::BF16, true, false) => "gemm_tn_bf16_bf16_32_32_16_2_2",
+ (GemmDType::BF16, false, true) => "gemm_nt_bf16_bf16_32_32_16_2_2",
+ (GemmDType::BF16, true, true) => "gemm_tt_bf16_bf16_32_32_16_2_2",
+ (GemmDType::F16, false, false) => "gemm_nn_f16_f16_32_32_16_2_2",
+ (GemmDType::F16, true, false) => "gemm_tn_f16_f16_32_32_16_2_2",
+ (GemmDType::F16, false, true) => "gemm_nt_f16_f16_32_32_16_2_2",
+ (GemmDType::F16, true, true) => "gemm_tt_f16_f16_32_32_16_2_2",
+ };
+ let pipeline = kernels.load_pipeline_with_constants(device, Source::Gemm, name, constants)?;
+ let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
+ encoder.set_compute_pipeline_state(&pipeline);
+ encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
+ encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
+ encoder.set_buffer(3, Some(output), 0);
+ encoder.set_bytes(
+ 4,
+ std::mem::size_of::<GemmParams>() as u64,
+ &gemm_params as *const GemmParams as *const c_void,
+ );
+ encoder.set_bytes(
+ 6, // batch_shape
+ std::mem::size_of::<i32>() as u64,
+ &(b as i32) as *const i32 as *const c_void,
+ );
+ encoder.set_bytes(
+ 7,
+ (std::mem::size_of::<isize>() * batch_strides.len()) as u64,
+ batch_strides.as_ptr() as *const c_void,
+ );
+
+ let grid_size = MTLSize {
+ width: tn as u64,
+ height: tm as u64,
+ depth: /* batch_size_out */ b as u64,
+ };
+ let group_size = MTLSize {
+ width: 32,
+ height: wn,
+ depth: wm,
+ };
+ encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(grid_size, group_size);
+ Ok(())
+}
+
#[cfg(test)]
mod tests;