author     Laurent Mazare <laurent.mazare@gmail.com>   2024-09-11 15:56:48 +0100
committer  GitHub <noreply@github.com>                 2024-09-11 16:56:48 +0200
commit     5635650d386ff7cfeb0dee84d02e8d19574c6faf (patch)
tree       a22afae6c32746b945ce57063f4057599bd68670 /candle-metal-kernels
parent     13b2a8a4a06ae306819d3d790906435e5f247ae5 (diff)
download   candle-5635650d386ff7cfeb0dee84d02e8d19574c6faf.tar.gz
           candle-5635650d386ff7cfeb0dee84d02e8d19574c6faf.tar.bz2
           candle-5635650d386ff7cfeb0dee84d02e8d19574c6faf.zip
Integrate the MLX gemm kernels (#2468)
* Include the MLX gemm kernels.
* Clippy lints.
* Export the gemm_f32 kernel.
* Add the f16/bf16 variants.
* Add the initial dispatch code.
* More plugging of the mlx kernels.
* Add a currently broken test.
* Tweaks.
* Bugfix + get the tests to pass.
* Enable the gemm bf16 tests.
* Add some randomized tests.
* Update candle-metal-kernels/src/lib.rs
Co-authored-by: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com>
* More fixes.
* More clippy fixes.
---------
Co-authored-by: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com>
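For orientation before the diff: the new host-side entry point is `call_mlx_gemm` in `candle-metal-kernels/src/lib.rs`, and the `run_mlx_gemm` helper added to `tests.rs` shows the full call sequence. Below is a minimal sketch of a dispatch, assuming buffers were created as in the tests; the wrapper function name `mlx_matmul_f32` is hypothetical, everything else follows the signatures in this diff.

```rust
// Sketch only: mirrors the run_mlx_gemm helper added in tests.rs below.
// `Buffer`/`Device` come from the metal crate; offsets are in bytes.
use candle_metal_kernels::{call_mlx_gemm, GemmDType, Kernels};
use metal::{Buffer, Device};

fn mlx_matmul_f32(
    device: &Device,
    kernels: &Kernels,
    (b, m, n, k): (usize, usize, usize, usize),
    lhs: &Buffer,
    rhs: &Buffer,
    output: &Buffer,
) {
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    call_mlx_gemm(
        device,
        command_buffer, // &CommandBufferRef implements EncoderProvider
        kernels,
        GemmDType::F32,
        (b, m, n, k),
        &[m * k, k, 1], // row-major (b, m, k) lhs strides
        0,              // lhs_offset, in bytes
        lhs,
        &[k * n, n, 1], // row-major (b, k, n) rhs strides
        0,              // rhs_offset, in bytes
        rhs,
        output,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();
}
```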
Diffstat (limited to 'candle-metal-kernels')
 candle-metal-kernels/Cargo.toml         |    1
 candle-metal-kernels/src/lib.rs         |  211
 candle-metal-kernels/src/mlx_gemm.metal | 1440
 candle-metal-kernels/src/tests.rs       |  269
 candle-metal-kernels/src/utils.rs       |    8
5 files changed, 1874 insertions(+), 55 deletions(-)
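The heart of the host-side change is how `call_mlx_gemm` (in `src/lib.rs` below) maps input strides onto the kernel's `(leading dimension, transpose)` convention: a `(b, m, k)` lhs is accepted either as row-major contiguous (`lda = k`, no transpose) or as a transposed view (`lda = m`), and anything else is rejected as `MatMulNonContiguous`. A standalone sketch of that rule, with a hypothetical helper name but the logic copied from the diff:

```rust
// Hypothetical extraction of the lhs layout check in call_mlx_gemm. The
// `|| k == 1` / `|| m == 1` clauses accept degenerate strides when a
// dimension holds a single element, exactly as in the diff below.
fn lhs_layout(lhs_stride: &[usize], m: usize, k: usize) -> Option<(i32, bool)> {
    let m1 = lhs_stride[lhs_stride.len() - 1]; // minor-dim stride
    let m2 = lhs_stride[lhs_stride.len() - 2]; // row stride
    if (m1 == 1 || k == 1) && (m2 == k || m == 1) {
        Some((k as i32, false)) // contiguous (m, k): leading dim k
    } else if (m1 == m || k == 1) && (m2 == 1 || m == 1) {
        Some((m as i32, true)) // transposed view: leading dim m
    } else {
        None // call_mlx_gemm returns MatMulNonContiguous here
    }
}
```

The rhs check is symmetric with `n` in place of `m`. The kernel variant is then picked by name, `gemm_{nn,nt,tn,tt}_{dtype}_..._32_32_16_2_2`, matching the tile sizes instantiated at the bottom of `mlx_gemm.metal`.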
diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml
index a93d8729..772452c9 100644
--- a/candle-metal-kernels/Cargo.toml
+++ b/candle-metal-kernels/Cargo.toml
@@ -23,3 +23,4 @@ half = { version = "2.3.1", features = [
     "rand_distr",
 ] }
 rand = "0.8.5"
+rand_distr = "0.4.3"
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 743b9fe2..a595b2bd 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -11,33 +11,35 @@ pub use utils::BufferOffset;
 use utils::{get_block_dims, linear_split, EncoderProvider};
 
 const AFFINE: &str = include_str!("affine.metal");
-const INDEXING: &str = include_str!("indexing.metal");
-const UNARY: &str = include_str!("unary.metal");
 const BINARY: &str = include_str!("binary.metal");
-const TERNARY: &str = include_str!("ternary.metal");
 const CAST: &str = include_str!("cast.metal");
 const CONV: &str = include_str!("conv.metal");
-const REDUCE: &str = include_str!("reduce.metal");
-const RANDOM: &str = include_str!("random.metal");
+const INDEXING: &str = include_str!("indexing.metal");
 // Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle
 const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
+const MLX_GEMM: &str = include_str!("mlx_gemm.metal");
 const QUANTIZED: &str = include_str!("quantized.metal");
+const RANDOM: &str = include_str!("random.metal");
+const REDUCE: &str = include_str!("reduce.metal");
 const SORT: &str = include_str!("sort.metal");
+const TERNARY: &str = include_str!("ternary.metal");
+const UNARY: &str = include_str!("unary.metal");
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 pub enum Source {
     Affine,
-    Indexing,
-    Unary,
     Binary,
-    Ternary,
     Cast,
-    Reduce,
-    Mfa,
     Conv,
-    Random,
+    Gemm,
+    Indexing,
+    Mfa,
     Quantized,
+    Random,
+    Reduce,
     Sort,
+    Ternary,
+    Unary,
 }
 
 pub mod copy2d {
@@ -191,16 +193,17 @@ impl Kernels {
     fn get_library_source(&self, source: Source) -> &'static str {
         match source {
             Source::Affine => AFFINE,
-            Source::Unary => UNARY,
             Source::Binary => BINARY,
-            Source::Ternary => TERNARY,
-            Source::Indexing => INDEXING,
             Source::Cast => CAST,
-            Source::Reduce => REDUCE,
             Source::Conv => CONV,
-            Source::Random => RANDOM,
+            Source::Gemm => MLX_GEMM,
+            Source::Indexing => INDEXING,
             Source::Quantized => QUANTIZED,
+            Source::Random => RANDOM,
+            Source::Reduce => REDUCE,
             Source::Sort => SORT,
+            Source::Ternary => TERNARY,
+            Source::Unary => UNARY,
             Source::Mfa => panic!("Invalid lib"),
         }
     }
@@ -2178,5 +2181,181 @@ pub fn call_arg_sort(
     Ok(())
 }
 
+#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
+pub enum GemmDType {
+    BF16,
+    F16,
+    F32,
+}
+
+#[allow(clippy::too_many_arguments)]
+pub fn call_mlx_gemm(
+    device: &Device,
+    ep: impl EncoderProvider,
+    kernels: &Kernels,
+    dtype: GemmDType,
+    (b, m, n, k): (usize, usize, usize, usize),
+    lhs_stride: &[usize],
+    lhs_offset: usize,
+    lhs_buffer: &Buffer,
+    rhs_stride: &[usize],
+    rhs_offset: usize,
+    rhs_buffer: &Buffer,
+    output: &Buffer,
+) -> Result<(), MetalKernelError> {
+    #[derive(Debug)]
+    #[repr(C)]
+    struct GemmParams {
+        m: i32,
+        n: i32,
+        k: i32,
+        lda: i32,
+        ldb: i32,
+        ldd: i32,
+        tiles_n: i32,
+        tiles_m: i32,
+        batch_stride_a: isize,
+        batch_stride_b: isize,
+        batch_stride_d: isize,
+        swizzle_log: i32,
+        gemm_k_iterations_aligned: i32,
+        batch_ndim: i32,
+    }
+    assert!(rhs_stride.len() >= 2);
+    assert!(lhs_stride.len() >= 2);
+    let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
+    let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
+    let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
+    let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
+    // lhs has shape b, m, k
+    // We also allow for the case where the stride on the minor dimension is not as expected but
+    // there is a single element.
+    let (lda, a_trans) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
+        (k as i32, false)
+    } else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) {
+        (m as i32, true)
+    } else {
+        return Err(MetalKernelError::MatMulNonContiguous {
+            lhs_stride: lhs_stride.to_vec(),
+            rhs_stride: rhs_stride.to_vec(),
+            mnk: (m, n, k),
+        })?;
+    };
+    // rhs has shape b, k, n
+    let (ldb, b_trans) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
+        (n as i32, false)
+    } else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) {
+        (k as i32, true)
+    } else {
+        return Err(MetalKernelError::MatMulNonContiguous {
+            lhs_stride: lhs_stride.to_vec(),
+            rhs_stride: rhs_stride.to_vec(),
+            mnk: (m, n, k),
+        })?;
+    };
+    let (bm, bn, bk, wn, wm) = (32, 32, 16, 2, 2);
+    // https://github.com/ml-explore/mlx/blob/02efb310cac667bc547d1b96f21596c221f84fe7/mlx/backend/metal/matmul.cpp#L422
+    let constants = Some(ConstantValues::new(vec![
+        (10, Value::Bool(/* has_batch */ b > 1)),
+        (100, Value::Bool(/* use_out_source */ false)),
+        (110, Value::Bool(/* do_axpby */ false)),
+        (200, Value::Bool(/* align_m */ m % bm == 0)),
+        (201, Value::Bool(/* align_n */ n % bn == 0)),
+        (202, Value::Bool(/* align_k */ k % bk == 0)),
+        (300, Value::Bool(/* do_gather */ false)),
+    ]));
+
+    let swizzle_log = 0;
+    let tile = 1 << swizzle_log;
+    let tn = n.div_ceil(bn);
+    let tm = m.div_ceil(bm);
+    let tn = tn * tile;
+    let tm = tm.div_ceil(tile);
+
+    let batch_stride_a = if lhs_stride.len() > 2 {
+        lhs_stride[lhs_stride.len() - 3]
+    } else {
+        m * k
+    };
+    let batch_stride_b = if rhs_stride.len() > 2 {
+        rhs_stride[rhs_stride.len() - 3]
+    } else {
+        n * k
+    };
+
+    let gemm_params = GemmParams {
+        m: m as i32,
+        n: n as i32,
+        k: k as i32,
+        lda,
+        ldb,
+        ldd: n as i32,
+        tiles_n: tn as i32,
+        tiles_m: tm as i32,
+        swizzle_log,
+        batch_stride_a: batch_stride_a as isize,
+        batch_stride_b: batch_stride_b as isize,
+        batch_stride_d: (m * n) as isize,
+        batch_ndim: 1i32,
+        gemm_k_iterations_aligned: (k / bk) as i32,
+    };
+    let batch_strides = [gemm_params.batch_stride_a, gemm_params.batch_stride_b];
+
+    // TODO(laurent): generate the name
+    // template [[host_name("gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]]
+    let name = match (dtype, a_trans, b_trans) {
+        (GemmDType::F32, false, false) => "gemm_nn_f32_f32_32_32_16_2_2",
+        (GemmDType::F32, true, false) => "gemm_tn_f32_f32_32_32_16_2_2",
+        (GemmDType::F32, false, true) => "gemm_nt_f32_f32_32_32_16_2_2",
+        (GemmDType::F32, true, true) => "gemm_tt_f32_f32_32_32_16_2_2",
+        (GemmDType::BF16, false, false) => "gemm_nn_bf16_bf16_32_32_16_2_2",
+        (GemmDType::BF16, true, false) => "gemm_tn_bf16_bf16_32_32_16_2_2",
+        (GemmDType::BF16, false, true) => "gemm_nt_bf16_bf16_32_32_16_2_2",
+        (GemmDType::BF16, true, true) => "gemm_tt_bf16_bf16_32_32_16_2_2",
+        (GemmDType::F16, false, false) => "gemm_nn_f16_f16_32_32_16_2_2",
+        (GemmDType::F16, true, false) => "gemm_tn_f16_f16_32_32_16_2_2",
+        (GemmDType::F16, false, true) => "gemm_nt_f16_f16_32_32_16_2_2",
+        (GemmDType::F16, true, true) => "gemm_tt_f16_f16_32_32_16_2_2",
+    };
+    let pipeline = kernels.load_pipeline_with_constants(device, Source::Gemm, name, constants)?;
+    let encoder = ep.encoder();
+    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
+
encoder.set_compute_pipeline_state(&pipeline); + encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger); + encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger); + encoder.set_buffer(3, Some(output), 0); + encoder.set_bytes( + 4, + std::mem::size_of::<GemmParams>() as u64, + &gemm_params as *const GemmParams as *const c_void, + ); + encoder.set_bytes( + 6, // batch_shape + std::mem::size_of::<i32>() as u64, + &(b as i32) as *const i32 as *const c_void, + ); + encoder.set_bytes( + 7, + (std::mem::size_of::<isize>() * batch_strides.len()) as u64, + batch_strides.as_ptr() as *const c_void, + ); + + let grid_size = MTLSize { + width: tn as u64, + height: tm as u64, + depth: /* batch_size_out */ b as u64, + }; + let group_size = MTLSize { + width: 32, + height: wn, + depth: wm, + }; + encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(grid_size, group_size); + Ok(()) +} + #[cfg(test)] mod tests; diff --git a/candle-metal-kernels/src/mlx_gemm.metal b/candle-metal-kernels/src/mlx_gemm.metal new file mode 100644 index 00000000..1b5cad92 --- /dev/null +++ b/candle-metal-kernels/src/mlx_gemm.metal @@ -0,0 +1,1440 @@ +// MLX Kernel extracted from: +// https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/steel/gemm +// Copyright © 2024 Apple Inc. + +#include <metal_simdgroup> +#include <metal_simdgroup_matrix> +#include <metal_stdlib> + +#define STEEL_CONST static constant constexpr const +#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") + +using namespace metal; + +// https://github.com/ml-explore/mlx/blob/02efb310cac667bc547d1b96f21596c221f84fe7/mlx/backend/metal/kernels/steel/gemm/params.h#L1 +/////////////////////////////////////////////////////////////////////////////// +// GEMM param classes +/////////////////////////////////////////////////////////////////////////////// + +struct GEMMParams { + const int M; + const int N; + const int K; + + const int lda; + const int ldb; + const int ldd; + + const int tiles_n; + const int tiles_m; + + const size_t batch_stride_a; + const size_t batch_stride_b; + const size_t batch_stride_d; + + const int swizzle_log; + const int gemm_k_iterations_aligned; + + const int batch_ndim; +}; + +struct GEMMSpiltKParams { + const int M; + const int N; + const int K; + + const int lda; + const int ldb; + const int ldc; + + const int tiles_n; + const int tiles_m; + + const int split_k_partitions; + const int split_k_partition_stride; + const int split_k_partition_size; + + const int gemm_k_iterations_aligned; +}; + +struct GEMMAddMMParams { + const int ldc; + const int fdc; + + const size_t batch_stride_c; + + const float alpha; + const float beta; +}; + +// https://github.com/ml-explore/mlx/blob/02efb310cac667bc547d1b96f21596c221f84fe7/mlx/backend/metal/kernels/steel/gemm/loader.h#L1 +/////////////////////////////////////////////////////////////////////////////// +// Loading helper +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short alignment = 1, + short n_reads = (BCOLS * BROWS) / (tgp_size), + short TCOLS = BCOLS / n_reads, + short TROWS = tgp_size / TCOLS> +struct BlockLoader { + STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; + STEEL_CONST short vec_size = n_reads; + + // Leading 
dimension for src + const int src_ld; + const int tile_stride; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + struct alignas(alignment * sizeof(T)) ReadVector { + uint8_t v[sizeof(T) * vec_size]; + }; + + /* Constructor */ + METAL_FUNC BlockLoader( + const device T* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride(reduction_dim ? BCOLS : BROWS * src_ld), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld + bj) {} + + /* Apply operation to threadgroup without bound checking */ + template <typename UnaryOp> + METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]); + } + } + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + *((threadgroup ReadVector*)(&dst[i * dst_ld])) = + *((const device ReadVector*)(&src[i * src_ld])); + } + } + + /* Load from device memory into threadgroup memory - with bound checking */ + METAL_FUNC void load_safe(short2 src_tile_dim) const { + src_tile_dim = src_tile_dim - short2(bj, bi); + + // Skip loading if thread has no valid reads + if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + return; + } + + // Use fast thread memory for bound checks + bool tmp_idx[vec_size]; + T tmp_val[vec_size]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + // Make sure tmp_idx only contains valid indices + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x); + } + + // Read valid indices into tmp_val + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; + } + + // Zero out uneeded values + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = tmp_idx[j] ? 
tmp_val[j] : T(0); + } + + // Copy values to threadgroup memory + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = tmp_val[j]; + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + src += tile_stride; + } +}; + +// https://github.com/ml-explore/mlx/blob/02efb310cac667bc547d1b96f21596c221f84fe7/mlx/backend/metal/kernels/steel/gemm/transforms.h#L1 +/////////////////////////////////////////////////////////////////////////////// +// Transforms and Epilogues +/////////////////////////////////////////////////////////////////////////////// + +template <typename OutT, typename InT> +struct TransformNone { + static METAL_FUNC OutT apply(InT x) { + return static_cast<OutT>(x); + } + + static METAL_FUNC OutT apply(InT x, OutT) { + return static_cast<OutT>(x); + } +}; + +template <typename OutT, typename InT> +struct TransformAdd { + TransformAdd(const float, const float) {} + + static METAL_FUNC OutT apply(InT x) { + return static_cast<OutT>(x); + } + + static METAL_FUNC OutT apply(InT x, OutT c) { + return static_cast<OutT>(x) + c; + } +}; + +template <typename OutT, typename InT> +struct TransformAxpby { + const float alpha; + const float beta; + + TransformAxpby(const float alpha_, const float beta_) + : alpha(alpha_), beta(beta_) {} + + static METAL_FUNC OutT apply(InT x) { + return static_cast<OutT>(x); + } + + METAL_FUNC OutT apply(InT x, OutT c) const { + return static_cast<OutT>(x * alpha + (beta * c)); + } +}; + +template <typename T> +struct AccumHelper { + typedef float accum_type; +}; + +struct BlockSwizzle { + static METAL_FUNC int2 + swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) { + const int tid_x = (tid.x) >> swizzle_log; + const int tid_y = + ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1)); + return int2(tid_x, tid_y); + } +}; + +// https://github.com/ml-explore/mlx/blob/02efb310cac667bc547d1b96f21596c221f84fe7/mlx/backend/metal/kernels/steel/gemm/mma.h#L1 +/////////////////////////////////////////////////////////////////////////////// +// MMA helper +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + short lda_tgp, + short ldb_tgp, + typename AccumType = float, + typename Epilogue = TransformNone<U, AccumType>> +struct BlockMMA { + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TM_stride = 8 * WM; + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TN_stride = 8 * WN; + + // Warp tile size along M + STEEL_CONST short TM = BM / TM_stride; + // Warp tile size along N + STEEL_CONST short TN = BN / TN_stride; + + // Strides of A, B along reduction axis + STEEL_CONST short simd_stride_a = { + transpose_a ? TM_stride : TM_stride * lda_tgp}; + STEEL_CONST short simd_stride_b = { + transpose_b ? TN_stride * ldb_tgp : TN_stride}; + + // Jump between elements + STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1}; + STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1}; + + STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8}; + STEEL_CONST short tile_stride_b = {transpose_b ? 
8 : 8 * ldb_tgp}; + + // Simdgroup matrices + simdgroup_matrix<AccumType, 8, 8> Asimd[TM]; + simdgroup_matrix<AccumType, 8, 8> Bsimd[TN]; + simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = { + simdgroup_matrix<AccumType, 8, 8>(0)}; + + // Offsets within threadgroup + const short tm; + const short tn; + + short sm; + short sn; + + short As_offset; + short Bs_offset; + + /* Constructor */ + METAL_FUNC BlockMMA( + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) { + // Determine thread position in simdgroup matrix + short qid = simd_lane_id / 4; + sm = (qid & 4) + (simd_lane_id / 2) % 4; + sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + + // Determine thread and simdgroup offset + As_offset = + transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp); + Bs_offset = + transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn)); + } + + /* (BM, BK) X (BK, BN) multiply accumulate function */ + METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { + // Adjust for simdgroup and thread location + As += As_offset; + Bs += Bs_offset; + + // Iterate over BK in blocks of 8 + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < BK; kk += 8) { + simdgroup_barrier(mem_flags::mem_none); + + // Load elements from threadgroup A as simdgroup matrices + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + Asimd[i].thread_elements()[0] = + static_cast<AccumType>(As[i * simd_stride_a + 0]); + Asimd[i].thread_elements()[1] = + static_cast<AccumType>(As[i * simd_stride_a + jump_a]); + } + + simdgroup_barrier(mem_flags::mem_none); + + // Load elements from threadgroup B as simdgroup matrices + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + Bsimd[j].thread_elements()[0] = + static_cast<AccumType>(Bs[j * simd_stride_b + 0]); + Bsimd[j].thread_elements()[1] = + static_cast<AccumType>(Bs[j * simd_stride_b + jump_b]); + } + + simdgroup_barrier(mem_flags::mem_none); + + // Multiply and accumulate into result simdgroup matrices + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + short j_serp = (i % 2) ? 
(TN - 1 - j) : j; + + simdgroup_multiply_accumulate( + results[i * TN + j_serp], + Asimd[i], + Bsimd[j_serp], + results[i * TN + j_serp]); + } + } + + // Progress to next simdgroup tile + As += tile_stride_a; + Bs += tile_stride_b; + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result(device U* D, const int ldd) const { + // Adjust for simdgroup and thread location + D += (sm + tm) * ldd + tn + sn; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue + U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])}; + + // Write out D + D[offset] = outs[0]; + D[offset + 1] = outs[1]; + } + } + } + + METAL_FUNC void + store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) const { + // Adjust for simdgroup and thread location + D += (sm + tm) * ldd + (tn + sn); + dst_tile_dims -= short2(tn + sn, sm + tm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + if (i * TM_stride < dst_tile_dims.y) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue and output C + if (j * TN_stride < dst_tile_dims.x) { + D[offset] = Epilogue::apply(accum[0]); + } + + if (j * TN_stride + 1 < dst_tile_dims.x) { + D[offset + 1] = Epilogue::apply(accum[1]); + } + } + } + } + } + + /* Apply epilogue */ + template <typename UnaryEpilogue> + METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) { + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = results[i * TN + j].thread_elements(); + + // Apply epilogue + accum[0] = epilogue_op.apply(accum[0]); + accum[1] = epilogue_op.apply(accum[1]); + } + } + } + + /* Apply epilogue */ + template <typename BinaryEpilogue> + METAL_FUNC void apply_epilogue( + const device U* C, + const int ldc, + const int fdc, + thread const BinaryEpilogue& epilogue_op) { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn) * fdc; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = results[i * TN + j].thread_elements(); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + + // Apply epilogue + accum[0] = epilogue_op.apply(accum[0], C[offset_c]); + accum[1] = epilogue_op.apply(accum[1], C[offset_c + fdc]); + } + } + } + + /* Apply epilogue */ + template <typename BinaryEpilogue> + METAL_FUNC void apply_epilogue_safe( + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const BinaryEpilogue& epilogue_op) { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn) * fdc; + dst_tile_dims -= short2(tn + sn, sm + tm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + // Loop over all 
simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = results[i * TN + j].thread_elements(); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + + // Read C + U c_elems[2] = {0}; + + if ((j * TN_stride + 1) < dst_tile_dims.x) { + c_elems[0] = C[offset_c]; + c_elems[1] = C[offset_c + fdc]; + } else if ((j * TN_stride) < dst_tile_dims.x) { + c_elems[0] = C[offset_c]; + } + + // Apply epilogue + accum[0] = epilogue_op.apply(accum[0], c_elems[0]); + accum[1] = epilogue_op.apply(accum[1], c_elems[1]); + } + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn) * fdc; + D += (sm + tm) * ldd + tn + sn; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int offset_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue + U outs[2] = { + epilogue_op.apply(accum[0], C[offset_c]), + epilogue_op.apply(accum[1], C[offset_c + fdc])}; + + // Write out D + D[offset_d] = outs[0]; + D[offset_d + 1] = outs[1]; + } + } + } + + METAL_FUNC void store_result_safe( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn) * fdc; + D += (sm + tm) * ldd + tn + sn; + dst_tile_dims -= short2(tn + sn, sm + tm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + if (i * TM_stride < dst_tile_dims.y) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int offset_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue and output C + if (j * TN_stride < dst_tile_dims.x) { + D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]); + } + + if (j * TN_stride + 1 < dst_tile_dims.x) { + D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]); + } + } + } + } + } +}; + +// https://github.com/ml-explore/mlx/blob/02efb310cac667bc547d1b96f21596c221f84fe7/mlx/backend/metal/kernels/steel/gemm/gemm.h#L1 +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernel class +/////////////////////////////////////////////////////////////////////////////// + +template <bool M_aligned, bool N_aligned, bool K_aligned> +struct LoopAlignment {}; + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + bool MN_aligned, + bool K_aligned, + typename AccumType = typename AccumHelper<T>::accum_type, + typename Epilogue = TransformNone<U, AccumType>> +struct GEMMKernel { + STEEL_CONST short tgp_padding_a = 16 / sizeof(T); + STEEL_CONST short tgp_padding_b = 16 
/ sizeof(T); + STEEL_CONST short tgp_mem_size_a = + transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a); + STEEL_CONST short tgp_mem_size_b = + transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b); + STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b; + + STEEL_CONST short tgp_size = WM * WN * 32; + + using loader_a_t = BlockLoader< + T, + transpose_a ? BK : BM, + transpose_a ? BM : BK, + transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a, + !transpose_a, + tgp_size>; + using loader_b_t = BlockLoader< + T, + transpose_b ? BN : BK, + transpose_b ? BK : BN, + transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, + transpose_b, + tgp_size>; + using mma_t = BlockMMA< + T, + U, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a, + transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, + AccumType, + Epilogue>; + + /* Main kernel function */ + template <bool M_aligned, bool N_aligned, bool K_aligned_> + static METAL_FUNC void gemm_loop( + threadgroup T* As [[threadgroup(0)]], + threadgroup T* Bs [[threadgroup(1)]], + const int gemm_k_iterations, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + thread mma_t& mma_op, + thread const short& tgp_bm, + thread const short& tgp_bn, + thread const short& lbk, + LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) { + // Appease the compiler + (void)l; + + short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); + + short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK); + + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + if (M_aligned) { + loader_a.load_unsafe(); + } else { + loader_a.load_safe(tile_dims_A); + } + + if (N_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe(tile_dims_B); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + if (!K_aligned_) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + short2 tile_dims_A_last = + transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm); + short2 tile_dims_B_last = + transpose_b ? 
short2(lbk, tgp_bn) : short2(tgp_bn, lbk); + + loader_a.load_safe(tile_dims_A_last); + loader_b.load_safe(tile_dims_B_last); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + } + + /* Main kernel function */ + static METAL_FUNC void run( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device U* D [[buffer(2)]], + const constant GEMMParams* params [[buffer(3)]], + threadgroup T* As [[threadgroup(0)]], + threadgroup T* Bs [[threadgroup(1)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // Pacifying compiler + (void)lid; + + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + threadgroup_barrier(mem_flags::mem_none); + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + D += c_row_long * params->ldd + c_col_long; + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (MN_aligned) { + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Loop tail + if (!K_aligned) { + int lbk = params->K - params->gemm_k_iterations_aligned * BK; + short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); + short2 tile_dims_B = transpose_b ? 
short2(lbk, BN) : short2(BN, lbk); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + + // Store results to device memory + mma_op.store_result(D, params->ldd); + return; + + } + /////////////////////////////////////////////////////////////////////////////// + // MN unaligned loop + else { // Loop over K - unaligned case + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK; + + if (tgp_bm == BM && tgp_bn == BN) { + gemm_loop<true, true, K_aligned>( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result(D, params->ldd); + return; + + } else if (tgp_bn == BN) { + gemm_loop<false, true, K_aligned>( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + return; + + } else if (tgp_bm == BM) { + gemm_loop<true, false, K_aligned>( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + return; + + } else { + gemm_loop<false, false, K_aligned>( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + return; + } + } + } +}; + +// utils.h +/////////////////////////////////////////////////////////////////////////////// +// Single Array with generic dims + +template <typename stride_t> +METAL_FUNC stride_t elem_to_loc( + uint elem, + device const int* shape, + device const stride_t* strides, + int ndim) { + stride_t loc = 0; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + loc += (elem % shape[i]) * strides[i]; + elem /= shape[i]; + } + return loc; +} + +template <typename stride_t> +METAL_FUNC stride_t elem_to_loc( + uint elem, + constant const int* shape, + constant const stride_t* strides, + int ndim) { + stride_t loc = 0; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + loc += (elem % shape[i]) * strides[i]; + elem /= shape[i]; + } + return loc; +} + +template <typename stride_t> +METAL_FUNC stride_t elem_to_loc( + stride_t elem, + device const int* shape, + device const stride_t* strides, + int ndim) { + stride_t loc = 0; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + loc += (elem % shape[i]) * strides[i]; + elem /= shape[i]; + } + return loc; +} + +template <typename stride_t> +METAL_FUNC stride_t elem_to_loc( + stride_t elem, + constant const int* shape, + constant const stride_t* strides, + int ndim) { + stride_t loc = 0; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + loc += (elem % shape[i]) * strides[i]; + elem /= shape[i]; + } + return loc; +} + +// Non templated version to handle arbitrary dims +template <typename stride_t> +METAL_FUNC stride_t elem_to_loc( + uint3 elem, + constant const int* shape, + constant const stride_t* strides, + int ndim) { + stride_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2]; + for (int d = ndim - 3; d >= 0; --d) { + loc += (elem.z % shape[d]) * strides[d]; + elem.z /= shape[d]; + } + return loc; +} + + +METAL_FUNC ulong2 elem_to_loc_broadcast( + uint elem, + constant const int* shape, + constant const size_t* a_strides, + constant const size_t* b_strides, + int ndim) { + ulong loc_a{0}; + 
ulong loc_b{0}; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + int pos_in_dim = (elem % shape[i]); + elem /= shape[i]; + loc_a += pos_in_dim * a_strides[i]; + loc_b += pos_in_dim * b_strides[i]; + } + return ulong2(loc_a, loc_b); +} + +METAL_FUNC ulong3 elem_to_loc_broadcast( + uint elem, + constant const int* shape, + constant const size_t* a_strides, + constant const size_t* b_strides, + constant const size_t* c_strides, + int ndim) { + ulong loc_a{0}; + ulong loc_b{0}; + ulong loc_c{0}; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + int pos_in_dim = (elem % shape[i]); + elem /= shape[i]; + loc_a += pos_in_dim * a_strides[i]; + loc_b += pos_in_dim * b_strides[i]; + loc_c += pos_in_dim * c_strides[i]; + } + return ulong3(loc_a, loc_b, loc_c); +} + + +// https://github.com/ml-explore/mlx/blob/02efb310cac667bc547d1b96f21596c221f84fe7/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h#L1 +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +constant bool has_batch [[function_constant(10)]]; + +constant bool use_out_source [[function_constant(100)]]; +constant bool do_axpby [[function_constant(110)]]; + +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +constant bool do_gather [[function_constant(300)]]; + +constant bool gather_bias = do_gather && use_out_source; + +// clang-format off +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device T* C [[buffer(2), function_constant(use_out_source)]], + device T* D [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], + const constant int* batch_shape [[buffer(6)]], + const constant size_t* batch_strides [[buffer(7)]], + const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]], + const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]], + const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]], + const constant int* operand_shape [[buffer(13), function_constant(do_gather)]], + const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]], + const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on + // Pacifying compiler + (void)lid; + + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + true, + true, + AccumType>; + + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + + // Find block + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + // Exit early if out of bounds + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + 
} + + // Adjust for batch + + // Handle gather + if (do_gather) { + // Read indices + uint32_t indx_A, indx_B, indx_C; + + if (has_batch) { + const constant size_t* indx_A_bstrides = batch_strides; + const constant size_t* indx_B_bstrides = + batch_strides + params->batch_ndim; + + ulong2 indx_offsets = elem_to_loc_broadcast( + tid.z, + batch_shape, + indx_A_bstrides, + indx_B_bstrides, + params->batch_ndim); + indx_A = lhs_indices[indx_offsets.x]; + indx_B = rhs_indices[indx_offsets.y]; + + if (use_out_source) { + const constant size_t* indx_C_bstrides = + indx_B_bstrides + params->batch_ndim; + auto indx_offset_C = elem_to_loc( + tid.z, batch_shape, indx_C_bstrides, params->batch_ndim); + indx_C = C_indices[indx_offset_C]; + } + } else { + indx_A = lhs_indices[params->batch_stride_a * tid.z]; + indx_B = rhs_indices[params->batch_stride_b * tid.z]; + + if (use_out_source) { + indx_C = C_indices[addmm_params->batch_stride_c * tid.z]; + } + } + + // Translate indices to offsets + int batch_ndim_A = operand_batch_ndim.x; + const constant int* batch_shape_A = operand_shape; + const constant size_t* batch_strides_A = operand_strides; + A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A); + + int batch_ndim_B = operand_batch_ndim.y; + const constant int* batch_shape_B = batch_shape_A + batch_ndim_A; + const constant size_t* batch_strides_B = batch_strides_A + batch_ndim_A; + B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B); + + if (use_out_source) { + int batch_ndim_C = operand_batch_ndim.z; + const constant int* batch_shape_C = batch_shape_B + batch_ndim_B; + const constant size_t* batch_strides_C = batch_strides_B + batch_ndim_B; + C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C); + } + + } + + // Handle regular batch + else { + if (has_batch) { + const constant size_t* A_bstrides = batch_strides; + const constant size_t* B_bstrides = batch_strides + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); + + A += batch_offsets.x; + B += batch_offsets.y; + + if (use_out_source) { + const constant size_t* C_bstrides = B_bstrides + params->batch_ndim; + C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim); + } + } else { + A += params->batch_stride_a * tid.z; + B += params->batch_stride_b * tid.z; + + if (use_out_source) { + C += addmm_params->batch_stride_c * tid.z; + } + } + } + + D += params->batch_stride_d * tid.z; + + // Prepare threadgroup memory + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + threadgroup_barrier(mem_flags::mem_none); + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + D += c_row_long * params->ldd + c_col_long; + + if (use_out_source) { + C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc; + } + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? 
BM : short(min(BM, params->M - c_row)); + const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col)); + + // Prepare iterations + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + // Do unaligned K iterations first + if (!align_K) { + const int k_last = params->gemm_k_iterations_aligned * BK; + const int k_remain = params->K - k_last; + const size_t k_jump_a = + transpose_a ? params->lda * size_t(k_last) : size_t(k_last); + const size_t k_jump_b = + transpose_b ? size_t(k_last) : params->ldb * size_t(k_last); + + // Move loader source ahead to end + loader_a.src += k_jump_a; + loader_b.src += k_jump_b; + + // Load tile + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Do matmul + mma_op.mma(As, Bs); + + // Reset source back to start + loader_a.src -= k_jump_a; + loader_b.src -= k_jump_b; + } + + const TransformAdd<AccumType, AccumType> epilogue_op_add( + addmm_params->alpha, addmm_params->beta); + const TransformAxpby<AccumType, AccumType> epilogue_op_axpby( + addmm_params->alpha, addmm_params->beta); + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (align_M && align_N) { + // Do gemm + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue( + C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby); + } else { + mma_op.apply_epilogue( + C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result(D, params->ldd); + + } + /////////////////////////////////////////////////////////////////////////////// + // MN unaligned loop + else { // Loop over K - unaligned case + const int leftover_bk = 0; + + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + // Do gemm + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment<true, true, true>{}); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue( + C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby); + } else { + mma_op.apply_epilogue( + C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result(D, params->ldd); + + } else if (align_N || tgp_bn == BN) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment<false, true, true>{}); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_axpby); + } else { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_add); + } + } + + // Store 
results to device memory + return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + + } else if (align_M || tgp_bm == BM) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment<true, false, true>{}); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_axpby); + } else { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + + } else { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment<false, false, true>{}); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_axpby); + } else { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + } + } +} + +#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + template [[host_name("gemm_" #tname "_" #iname "_" #oname "_" #bm "_" #bn "_" #bk "_" #wm "_" #wn)]] \ + [[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, float>( \ + const device itype *A [[buffer(0)]], \ + const device itype *B [[buffer(1)]], \ + const device itype *C [[buffer(2), function_constant(use_out_source)]], \ + device itype *D [[buffer(3)]], \ + const constant GEMMParams* params [[buffer(4)]], \ + const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], \ + const constant int* batch_shape [[buffer(6)]], \ + const constant size_t* batch_strides [[buffer(7)]], \ + const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]], \ + const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]], \ + const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]], \ + const constant int* operand_shape [[buffer(13), function_constant(do_gather)]], \ + const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]], \ + const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]]); + +#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) + +instantiate_gemm_transpose_helper(f32, float, f32, float, 32, 32, 16, 2, 2) +instantiate_gemm_transpose_helper(f16, half, f16, half, 32, 32, 16, 2, 2) +#if defined(__HAVE_BFLOAT__) +instantiate_gemm_transpose_helper(bf16, bfloat, bf16, bfloat, 32, 
32, 16, 2, 2) +#endif diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 30c454af..8b1adbde 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -329,7 +329,7 @@ fn run_cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> { #[test] fn cast_f32() { - let v_f64 = vec![1.0f64, 2.0, 3.0]; + let v_f64 = [1.0f64, 2.0, 3.0]; let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect(); let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); @@ -360,7 +360,7 @@ fn cast_f32() { #[test] fn cast_f16() { - let v_f64 = vec![1.0f64, 2.0, 3.0]; + let v_f64 = [1.0f64, 2.0, 3.0]; let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect(); let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); @@ -391,7 +391,7 @@ fn cast_f16() { #[test] fn cast_bf16() { - let v_f64 = vec![1.0f64, 2.0, 3.0]; + let v_f64 = [1.0f64, 2.0, 3.0]; let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect(); let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); @@ -422,7 +422,7 @@ fn cast_bf16() { #[test] fn cast_u32() { - let v_f64 = vec![1.0f64, 2.0, 3.0]; + let v_f64 = [1.0f64, 2.0, 3.0]; let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect(); let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); @@ -453,7 +453,7 @@ fn cast_u32() { #[test] fn cast_u8() { - let v_f64 = vec![1.0f64, 2.0, 3.0]; + let v_f64 = [1.0f64, 2.0, 3.0]; let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect(); let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); @@ -484,7 +484,7 @@ fn cast_u8() { #[test] fn cast_i64() { - let v_f64 = vec![1.0f64, 2.0, 3.0]; + let v_f64 = [1.0f64, 2.0, 3.0]; let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect(); let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); @@ -911,7 +911,7 @@ fn softmax() { vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652] ); - let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0] + let v = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0] .iter() .map(|v| f16::from_f32(*v)) .collect::<Vec<_>>(); @@ -922,7 +922,7 @@ fn softmax() { vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338] ); - let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0] + let v = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0] .iter() .map(|v| bf16::from_f32(*v)) .collect::<Vec<_>>(); @@ -1045,14 +1045,15 @@ fn where_cond_u32_f32() { assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]); } +#[allow(clippy::too_many_arguments)] fn run_gemm<T: Clone>( name: &'static str, (b, m, n, k): (usize, usize, usize, usize), lhs: &[T], - lhs_stride: Vec<usize>, + lhs_stride: &[usize], lhs_offset: usize, rhs: &[T], - rhs_stride: Vec<usize>, + rhs_stride: &[usize], rhs_offset: usize, ) -> Vec<T> { let device = device(); @@ -1079,10 +1080,10 @@ fn run_gemm<T: Clone>( &kernels, name, (b, m, n, k), - &lhs_stride, + lhs_stride, lhs_offset, &lhs, - &rhs_stride, + rhs_stride, rhs_offset, &rhs, &output, @@ 
-1105,10 +1106,10 @@ fn gemm() {
         "sgemm",
         (b, m, n, k),
         &lhs,
-        lhs_stride,
+        &lhs_stride,
         0,
         &rhs,
-        rhs_stride,
+        &rhs_stride,
         0,
     );
     assert_eq!(
@@ -1125,10 +1126,10 @@ fn gemm() {
         "sgemm",
         (b, m, n, k),
         &lhs,
-        lhs_stride,
+        &lhs_stride,
         0,
         &rhs,
-        rhs_stride,
+        &rhs_stride,
         0,
     );
     assert_eq!(
@@ -1150,10 +1151,10 @@ fn gemm() {
         "sgemm",
         (1, m, n, k),
         &lhs,
-        lhs_stride,
+        &lhs_stride,
         0,
         &rhs,
-        rhs_stride,
+        &rhs_stride,
         12 * 4,
     );
     assert_eq!(
@@ -1172,10 +1173,10 @@ fn gemm() {
         "bgemm",
         (b, m, n, k),
         &lhs,
-        lhs_stride,
+        &lhs_stride,
         0,
         &rhs,
-        rhs_stride,
+        &rhs_stride,
         0,
     );
     assert_eq!(
@@ -1194,10 +1195,10 @@ fn gemm() {
         "hgemm",
         (b, m, n, k),
         &lhs,
-        lhs_stride,
+        &lhs_stride,
         0,
         &rhs,
-        rhs_stride,
+        &rhs_stride,
         0,
     );
     assert_eq!(
@@ -1206,6 +1207,204 @@ fn gemm() {
     );
 }
 
+#[allow(clippy::too_many_arguments)]
+fn run_mlx_gemm<T: Clone>(
+    dtype: GemmDType,
+    (b, m, n, k): (usize, usize, usize, usize),
+    lhs: &[T],
+    lhs_stride: &[usize],
+    lhs_offset: usize,
+    rhs: &[T],
+    rhs_stride: &[usize],
+    rhs_offset: usize,
+) -> Vec<T> {
+    let device = device();
+    let kernels = Kernels::new();
+    let command_queue = device.new_command_queue();
+    let command_buffer = command_queue.new_command_buffer();
+    let options = MTLResourceOptions::StorageModeManaged;
+
+    let lhs = device.new_buffer_with_data(
+        lhs.as_ptr() as *const core::ffi::c_void,
+        std::mem::size_of_val(lhs) as u64,
+        options,
+    );
+    let rhs = device.new_buffer_with_data(
+        rhs.as_ptr() as *const core::ffi::c_void,
+        std::mem::size_of_val(rhs) as u64,
+        options,
+    );
+    let length = b * m * n;
+    let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
+    call_mlx_gemm(
+        &device,
+        command_buffer,
+        &kernels,
+        dtype,
+        (b, m, n, k),
+        lhs_stride,
+        lhs_offset,
+        &lhs,
+        rhs_stride,
+        rhs_offset,
+        &rhs,
+        &output,
+    )
+    .unwrap();
+    command_buffer.commit();
+    command_buffer.wait_until_completed();
+
+    read_to_vec(&output, length)
+}
+
+fn mlx_vs_mfa_one(b: usize, m: usize, n: usize, k: usize, dtype: GemmDType) {
+    use rand::SeedableRng;
+    use rand_distr::Distribution;
+
+    let mut rng = rand::rngs::StdRng::seed_from_u64(42424242);
+    let normal = rand_distr::Normal::new(0.0, 1.0).unwrap();
+
+    let lhs: Vec<_> = (0..b * m * k).map(|_| normal.sample(&mut rng)).collect();
+    let rhs: Vec<_> = (0..b * n * k).map(|_| normal.sample(&mut rng)).collect();
+    let v1: Vec<f32> = run_mlx_gemm(
+        dtype,
+        (b, m, n, k),
+        &lhs,
+        &[m * k, k, 1],
+        0,
+        &rhs,
+        &[k * n, n, 1],
+        0,
+    );
+    let v2: Vec<f32> = run_gemm(
+        "sgemm",
+        (b, m, n, k),
+        &lhs,
+        &[m * k, k, 1],
+        0,
+        &rhs,
+        &[k * n, n, 1],
+        0,
+    );
+    for (a, b) in v1.iter().zip(v2.iter()) {
+        let diff = (a - b).abs();
+        assert_eq!((diff * 1e4).round(), 0.)
+    }
+}
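The assertion in `mlx_vs_mfa_one` encodes an absolute tolerance: `(diff * 1e4).round() == 0.` holds exactly when `|a - b| < 5e-5` (a difference of exactly 5e-5 rounds away from zero and fails). A hedged restatement of the same check as a helper, with a hypothetical name, if one preferred an explicit epsilon:

```rust
// Equivalent to `assert_eq!((diff * 1e4).round(), 0.)` above: the MLX and
// MFA results must agree elementwise to within 5e-5 absolute error.
fn assert_close(v1: &[f32], v2: &[f32]) {
    for (i, (a, b)) in v1.iter().zip(v2.iter()).enumerate() {
        let diff = (a - b).abs();
        assert!(diff < 5e-5, "element {i}: {a} vs {b} (diff {diff})");
    }
}
```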
+
+#[test]
+fn mlx_vs_mfa() {
+    mlx_vs_mfa_one(1, 32, 32, 25, GemmDType::F32);
+    mlx_vs_mfa_one(1, 128, 128, 100, GemmDType::F32);
+    mlx_vs_mfa_one(1, 256, 256, 256, GemmDType::F32);
+    mlx_vs_mfa_one(1, 192, 200, 75, GemmDType::F32);
+    mlx_vs_mfa_one(3, 27, 67, 64, GemmDType::F32);
+}
+
+#[test]
+fn mlx_gemm() {
+    let (b, m, n, k) = (1, 2, 4, 3);
+    let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
+    let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
+    let results = run_mlx_gemm(
+        GemmDType::F32,
+        (b, m, n, k),
+        &lhs,
+        &[m * k, k, 1],
+        0,
+        &rhs,
+        &[n * k, n, 1],
+        0,
+    );
+    assert_eq!(
+        approx(results, 4),
+        vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
+    );
+
+    let (b, m, n, k) = (2, 2, 4, 3);
+    let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
+    let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
+    let results = run_mlx_gemm(
+        GemmDType::F32,
+        (b, m, n, k),
+        &lhs,
+        &[m * k, k, 1],
+        0,
+        &rhs,
+        &[n * k, n, 1],
+        0,
+    );
+    assert_eq!(
+        approx(results, 4),
+        vec![
+            20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0, 488.0,
+            518.0, 548.0, 578.0
+        ]
+    );
+
+    // OFFSET
+    let (b, m, n, k) = (2, 2, 4, 3);
+    let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
+    let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
+    // Manually set batch_size=1 and offset 12 elements * 4 the number of bytes for f32
+    let results = run_mlx_gemm(
+        GemmDType::F32,
+        (1, m, n, k),
+        &lhs,
+        &[m * k, k, 1],
+        0,
+        &rhs,
+        &[n * k, n, 1],
+        12 * 4,
+    );
+    assert_eq!(
+        approx(results, 4),
+        vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
+    );
+
+    // bgemm sanity test
+    {
+        let (b, m, n, k) = (1, 2, 4, 3);
+        let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect();
+        let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect();
+        let results = run_mlx_gemm(
+            GemmDType::BF16,
+            (b, m, n, k),
+            &lhs,
+            &[m * k, k, 1],
+            0,
+            &rhs,
+            &[n * k, n, 1],
+            0,
+        );
+        assert_eq!(
+            approx_bf16(results, 4),
+            vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
+        );
+    }
+
+    {
+        // hgemm sanity test
+        let (b, m, n, k) = (1, 2, 4, 3);
+        let lhs: Vec<f16> = (0..b * m * k).map(|f| f16::from_f32(f as f32)).collect();
+        let rhs: Vec<f16> = (0..b * n * k).map(|f| f16::from_f32(f as f32)).collect();
+        let results = run_mlx_gemm(
+            GemmDType::F16,
+            (b, m, n, k),
+            &lhs,
+            &[m * k, k, 1],
+            0,
+            &rhs,
+            &[n * k, n, 1],
+            0,
+        );
+        assert_eq!(
+            approx_f16(results, 4),
+            vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
+        );
+    }
+}
+
 fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
     let device = device();
     let kernels = Kernels::new();
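The expected values in the `mlx_gemm` fixture above are easy to verify by hand. For the first `(b, m, n, k) = (1, 2, 4, 3)` case, `lhs` is the row-major 2x3 matrix `[[0, 1, 2], [3, 4, 5]]` and `rhs` the 3x4 matrix `[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]`, so `D[0][0] = 0*0 + 1*4 + 2*8 = 20` and `D[1][3] = 3*3 + 4*7 + 5*11 = 92`. A naive CPU reference (not part of the diff) that reproduces the full fixture:

```rust
// Naive row-major reference gemm, for checking the fixtures above on CPU.
fn cpu_gemm(m: usize, n: usize, k: usize, lhs: &[f32], rhs: &[f32]) -> Vec<f32> {
    let mut out = vec![0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            out[i * n + j] = (0..k).map(|p| lhs[i * k + p] * rhs[p * n + j]).sum();
        }
    }
    out
}

// cpu_gemm(2, 4, 3, &[0., 1., 2., 3., 4., 5.],
//          &(0..12).map(|f| f as f32).collect::<Vec<_>>())
// => [20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
```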
@@ -1280,7 +1479,7 @@ fn random() {
         variance.sqrt()
     }
 
-    let shape = vec![1024, 10];
+    let shape = [1024, 10];
     let length = shape.iter().product::<usize>();
     let seed = 299792458;
@@ -1636,7 +1835,7 @@ fn max_pool2d_f16() {
         &strides,
         "max_pool2d_f16",
     );
-    let expected = vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]
+    let expected = [5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]
         .iter()
         .map(|v| half::f16::from_f32(*v))
         .collect::<Vec<_>>();
@@ -1656,7 +1855,7 @@ fn max_pool2d_f16() {
         &strides,
         "max_pool2d_f16",
     );
-    let expected = vec![5.0, 7.0, 13.0, 15.0]
+    let expected = [5.0, 7.0, 13.0, 15.0]
         .iter()
         .map(|v| half::f16::from_f32(*v))
         .collect::<Vec<_>>();
@@ -1679,7 +1878,7 @@ fn max_pool2d_bf16() {
         &strides,
         "max_pool2d_bf16",
    );
-    let expected = vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]
+    let expected = [5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]
         .iter()
         .map(|v| half::bf16::from_f32(*v))
         .collect::<Vec<_>>();
@@ -1699,7 +1898,7 @@ fn max_pool2d_bf16() {
         &strides,
         "max_pool2d_bf16",
     );
-    let expected = vec![5.0, 7.0, 13.0, 15.0]
+    let expected = [5.0, 7.0, 13.0, 15.0]
         .iter()
         .map(|v| half::bf16::from_f32(*v))
         .collect::<Vec<_>>();
@@ -1818,7 +2017,7 @@ fn avg_pool2d_f16() {
         &strides,
         "avg_pool2d_f16",
    );
-    let expected = vec![
+    let expected = [
         2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000,
     ]
     .iter()
@@ -1843,7 +2042,7 @@ fn avg_pool2d_bf16() {
         &strides,
         "avg_pool2d_bf16",
    );
-    let expected = vec![
+    let expected = [
         2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000,
     ]
     .iter()
@@ -1981,14 +2180,14 @@ fn conv_transpose1d_f32() {
 
 #[test]
 fn conv_transpose1d_f16() {
-    let input: Vec<f16> = vec![1.0, 2.0, 3.0, 4.0]
+    let input: Vec<f16> = [1.0, 2.0, 3.0, 4.0]
         .iter()
         .map(|v| f16::from_f32(*v))
         .collect();
     let input_shape = &[1, 1, 4];
     let input_stride = &[4, 4, 1];
 
-    let kernel: Vec<f16> = vec![1.0, 2.0, 3.0, 4.0]
+    let kernel: Vec<f16> = [1.0, 2.0, 3.0, 4.0]
         .iter()
         .map(|v| f16::from_f32(*v))
         .collect();
@@ -2009,7 +2208,7 @@ fn conv_transpose1d_f16() {
         "conv_transpose1d_f16",
     );
 
-    let expected = vec![1., 4., 10., 20., 25., 24., 16.]
+    let expected = [1., 4., 10., 20., 25., 24., 16.]
         .iter()
         .map(|v| f16::from_f32(*v))
         .collect::<Vec<_>>();
@@ -2018,14 +2217,14 @@ fn conv_transpose1d_f16() {
 
 #[test]
 fn conv_transpose1d_bf16() {
-    let input: Vec<bf16> = vec![1.0, 2.0, 3.0, 4.0]
+    let input: Vec<bf16> = [1.0, 2.0, 3.0, 4.0]
         .iter()
         .map(|v| bf16::from_f32(*v))
         .collect();
     let input_shape = &[1, 1, 4];
     let input_stride = &[4, 4, 1];
 
-    let kernel: Vec<bf16> = vec![1.0, 2.0, 3.0, 4.0]
+    let kernel: Vec<bf16> = [1.0, 2.0, 3.0, 4.0]
         .iter()
         .map(|v| bf16::from_f32(*v))
         .collect();
@@ -2046,7 +2245,7 @@ fn conv_transpose1d_bf16() {
         "conv_transpose1d_bf16",
     );
 
-    let expected = vec![1., 4., 10., 20., 25., 24., 16.]
+    let expected = [1., 4., 10., 20., 25., 24., 16.]
         .iter()
         .map(|v| bf16::from_f32(*v))
         .collect::<Vec<_>>();
diff --git a/candle-metal-kernels/src/utils.rs b/candle-metal-kernels/src/utils.rs
index b42bcff0..2ddd610b 100644
--- a/candle-metal-kernels/src/utils.rs
+++ b/candle-metal-kernels/src/utils.rs
@@ -165,7 +165,7 @@ pub trait EncoderProvider {
     type Encoder<'a>: AsRef<metal::ComputeCommandEncoderRef>
     where
         Self: 'a;
-    fn encoder<'a>(&'a self) -> Self::Encoder<'a>;
+    fn encoder(&self) -> Self::Encoder<'_>;
 }
 
 pub struct WrappedEncoder<'a>(&'a ComputeCommandEncoderRef);
@@ -178,7 +178,7 @@ impl<'a> Drop for WrappedEncoder<'a> {
 
 impl<'a> AsRef<metal::ComputeCommandEncoderRef> for WrappedEncoder<'a> {
     fn as_ref(&self) -> &metal::ComputeCommandEncoderRef {
-        &self.0
+        self.0
     }
 }
 
@@ -186,7 +186,7 @@ impl EncoderProvider for &metal::CommandBuffer {
     type Encoder<'a> = WrappedEncoder<'a>
     where
         Self: 'a;
-    fn encoder<'a>(&'a self) -> Self::Encoder<'a> {
+    fn encoder(&self) -> Self::Encoder<'_> {
         WrappedEncoder(self.new_compute_command_encoder())
     }
 }
 
@@ -195,7 +195,7 @@ impl EncoderProvider for &metal::CommandBufferRef {
     type Encoder<'a> = WrappedEncoder<'a>
     where
         Self: 'a;
-    fn encoder<'a>(&'a self) -> Self::Encoder<'a> {
+    fn encoder(&self) -> Self::Encoder<'_> {
         WrappedEncoder(self.new_compute_command_encoder())
     }
 }
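A closing note on the `utils.rs` hunks: they are purely cosmetic clippy fixes. `fn encoder(&self) -> Self::Encoder<'_>` is the elided form of `fn encoder<'a>(&'a self) -> Self::Encoder<'a>`; lifetime elision ties the anonymous output lifetime to `&self`, and clippy's `needless_lifetimes` lint flags the explicit spelling. A sketch of the pattern in isolation (trait simplified from the diff, with the associated-type bound swapped for `Debug` so the snippet stands alone):

```rust
// Simplified from utils.rs: a trait whose generic associated type borrows
// from &self. The two signatures below are equivalent; the diff keeps the
// elided one.
trait EncoderProvider {
    type Encoder<'a>: std::fmt::Debug
    where
        Self: 'a;
    // fn encoder<'a>(&'a self) -> Self::Encoder<'a>; // explicit (pre-diff)
    fn encoder(&self) -> Self::Encoder<'_>; // elided (post-diff)
}
```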