author | Anubhab Bandyopadhyay <4890833+AnubhabB@users.noreply.github.com> | 2024-10-01 22:41:59 +0530 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-10-01 19:11:59 +0200 |
commit | a2bcc227df64b22cfbc54b5f96c995bf3a38c7bc (patch) | |
tree | 79d10359ccc57c6ad31f05f4b0cd8a3513af04ef /candle-metal-kernels | |
parent | def4c6cdeef78e437846efcb46a23006f539dee4 (diff) | |
Efficient implementation of `Tensor::ones()` for `metal` (#2512)
* WIP: hopefully better const impl
* with GPU
* More tests on
* Reverting primitive for
* Incorporating review changes - added elem count check in the kernel, using for call strategy
* rustfmt ran
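
For orientation, here is a hedged sketch of how a `Tensor::ones()` call could pick one of the `fill_*` entry points generated by `FILL_OPS` in the diff below. The `Dtype` enum is a local stand-in used purely for illustration (not candle's `DType`), and the candle-core backend wiring itself is not part of this diff.

```rust
// Illustration only: a stand-in dtype enum, not candle's `DType`.
#[derive(Clone, Copy)]
enum Dtype {
    U8,
    U32,
    I64,
    F16,
    BF16,
    F32,
}

// Map a dtype to the kernel name generated by `FILL_OPS(NAME, T)`;
// these are the same names exercised by the `const_fill` test below.
fn fill_kernel_name(dtype: Dtype) -> &'static str {
    match dtype {
        Dtype::U8 => "fill_u8",
        Dtype::U32 => "fill_u32",
        Dtype::I64 => "fill_i64",
        Dtype::F16 => "fill_f16",
        Dtype::BF16 => "fill_bf16",
        Dtype::F32 => "fill_f32",
    }
}
```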
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/fill.metal | 39 |
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 28 |
-rw-r--r-- | candle-metal-kernels/src/tests.rs | 65 |
3 files changed, 132 insertions, 0 deletions
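
The kernel below guards with `if (tid >= numel)` because the 1-D dispatch is rounded up to whole threadgroups, so trailing threads can land past the end of the buffer. A minimal sketch of that sizing logic, assuming `linear_split` behaves like a standard round-up split (its actual implementation is not part of this diff):

```rust
// Hypothetical round-up split: one thread per element, with the last
// threadgroup possibly only partially used. The excess threads are the
// reason for the kernel's `tid >= numel` early return.
fn linear_split_sketch(max_threads_per_threadgroup: u64, length: u64) -> (u64, u64) {
    let threads_per_group = max_threads_per_threadgroup.min(length).max(1);
    let group_count = (length + threads_per_group - 1) / threads_per_group; // ceiling division
    (group_count, threads_per_group)
}
```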
```diff
diff --git a/candle-metal-kernels/src/fill.metal b/candle-metal-kernels/src/fill.metal
new file mode 100644
index 00000000..35c3fe7a
--- /dev/null
+++ b/candle-metal-kernels/src/fill.metal
@@ -0,0 +1,39 @@
+#include <metal_stdlib>
+
+using namespace metal;
+
+template<typename T> METAL_FUNC void fill_with(
+    device T *out,
+    constant float &value,
+    constant size_t &numel,
+    uint tid [[thread_position_in_grid]]
+) {
+    if (tid >= numel) {
+        return;
+    }
+    out[tid] = static_cast<T>(value);
+}
+
+#define FILL_OP(NAME, T) \
+kernel void fill_##NAME( \
+    device T *out, \
+    constant float &value, \
+    constant size_t &numel, \
+    uint tid [[thread_position_in_grid]] \
+) { \
+    fill_with<T>(out, value, numel, tid); \
+} \
+
+
+#define FILL_OPS(NAME, T) \
+FILL_OP(NAME, T) \
+
+FILL_OPS(u8, uchar)
+FILL_OPS(u32, uint)
+FILL_OPS(i64, long)
+FILL_OPS(f16, half)
+FILL_OPS(f32, float)
+
+#if __METAL_VERSION__ >= 310
+FILL_OPS(bf16, bfloat)
+#endif
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index a595b2bd..a270bb28 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -14,6 +14,7 @@ const AFFINE: &str = include_str!("affine.metal");
 const BINARY: &str = include_str!("binary.metal");
 const CAST: &str = include_str!("cast.metal");
 const CONV: &str = include_str!("conv.metal");
+const FILL: &str = include_str!("fill.metal");
 const INDEXING: &str = include_str!("indexing.metal");
 // Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle
 const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
@@ -31,6 +32,7 @@ pub enum Source {
     Binary,
     Cast,
     Conv,
+    Fill,
     Gemm,
     Indexing,
     Mfa,
@@ -196,6 +198,7 @@ impl Kernels {
             Source::Binary => BINARY,
             Source::Cast => CAST,
             Source::Conv => CONV,
+            Source::Fill => FILL,
             Source::Gemm => MLX_GEMM,
             Source::Indexing => INDEXING,
             Source::Quantized => QUANTIZED,
@@ -2357,5 +2360,30 @@ pub fn call_mlx_gemm(
     Ok(())
 }
 
+pub fn call_const_fill(
+    device: &Device,
+    ep: impl EncoderProvider,
+    kernels: &Kernels,
+    name: &'static str,
+    length: usize,
+    output: &Buffer,
+    v: f32,
+) -> Result<(), MetalKernelError> {
+    let pipeline = kernels.load_pipeline(device, Source::Fill, name)?;
+    let encoder = ep.encoder();
+    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
+
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    set_params!(encoder, (output, v, length));
+
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
+
+    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+
+    Ok(())
+}
+
 #[cfg(test)]
 mod tests;
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 8b1adbde..f37ab5bb 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -1,6 +1,7 @@
 use super::*;
 use half::{bf16, f16};
 use metal::MTLResourceOptions;
+use rand::Rng;
 
 fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
     let ptr = buffer.contents() as *const T;
@@ -2307,3 +2308,67 @@ fn conv_transpose1d_u32() {
     let expected = vec![1, 4, 10, 20, 25, 24, 16];
     assert_eq!(results, expected);
 }
+
+fn constant_fill<T: Clone>(name: &'static str, len: usize, value: f32) -> Vec<T> {
+    let dev = device();
+    let kernels = Kernels::new();
+    let command_queue = dev.new_command_queue();
+    let command_buffer = command_queue.new_command_buffer();
+
+    let buffer = dev.new_buffer(
+        (len * std::mem::size_of::<T>()) as u64,
+        MTLResourceOptions::StorageModePrivate,
+    );
+
+    call_const_fill(&dev, command_buffer, &kernels, name, len, &buffer, value).unwrap();
+
+    command_buffer.commit();
+    command_buffer.wait_until_completed();
+
+    read_to_vec::<T>(&buffer, len)
+}
+
+#[test]
+fn const_fill() {
+    let fills = [
+        "fill_u8",
+        "fill_u32",
+        "fill_i64",
+        "fill_f16",
+        "fill_bf16",
+        "fill_f32",
+    ];
+
+    for name in fills {
+        let len = rand::thread_rng().gen_range(2..16) * rand::thread_rng().gen_range(4..16);
+        let value = rand::thread_rng().gen_range(1. ..19.);
+
+        match name {
+            "fill_u8" => {
+                let v = constant_fill::<u8>(name, len, value);
+                assert_eq!(v, vec![value as u8; len])
+            }
+            "fill_u32" => {
+                let v = constant_fill::<u32>(name, len, value);
+                assert_eq!(v, vec![value as u32; len])
+            }
+            "fill_i64" => {
+                let v = constant_fill::<i64>(name, len, value);
+                assert_eq!(v, vec![value as i64; len])
+            }
+            "fill_f16" => {
+                let v = constant_fill::<f16>(name, len, value);
+                assert_eq!(v, vec![f16::from_f32(value); len])
+            }
+            "fill_bf16" => {
+                let v = constant_fill::<bf16>(name, len, value);
+                assert_eq!(v, vec![bf16::from_f32(value); len])
+            }
+            "fill_f32" => {
+                let v = constant_fill::<f32>(name, len, value);
+                assert_eq!(v, vec![value; len])
+            }
+            _ => unimplemented!(),
+        };
+    }
+}
```
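
Putting it together, a hedged usage sketch of the new `call_const_fill` entry point, modeled on the `constant_fill` test helper above. It assumes a default system Metal device is available, that the crate exports `call_const_fill`, `Kernels`, and `MetalKernelError` at the top level, and that a `&CommandBufferRef` satisfies `EncoderProvider` as the test's call pattern suggests. Reading the result back on the CPU would additionally require a shared (CPU-visible) buffer rather than `StorageModePrivate`.

```rust
use candle_metal_kernels::{call_const_fill, Kernels, MetalKernelError};
use metal::MTLResourceOptions;

/// Fill a freshly allocated GPU buffer with `len` copies of 1.0_f32
/// via the `fill_f32` kernel generated by FILL_OPS(f32, float).
fn ones_f32_buffer(len: usize) -> Result<metal::Buffer, MetalKernelError> {
    let device = metal::Device::system_default().expect("no Metal device found");
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();

    // GPU-only storage, matching the test; use StorageModeShared for CPU readback.
    let buffer = device.new_buffer(
        (len * std::mem::size_of::<f32>()) as u64,
        MTLResourceOptions::StorageModePrivate,
    );

    call_const_fill(&device, command_buffer, &kernels, "fill_f32", len, &buffer, 1.0)?;

    command_buffer.commit();
    command_buffer.wait_until_completed();
    Ok(buffer)
}
```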