diff options
Diffstat (limited to 'candle-metal-kernels/src')
-rw-r--r-- | candle-metal-kernels/src/affine.metal | 18 | ||||
-rw-r--r-- | candle-metal-kernels/src/cast.metal | 18 | ||||
-rw-r--r-- | candle-metal-kernels/src/indexing.metal | 9 | ||||
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 303 | ||||
-rw-r--r-- | candle-metal-kernels/src/reduce.metal | 156 | ||||
-rw-r--r-- | candle-metal-kernels/src/ternary.metal | 3 | ||||
-rw-r--r-- | candle-metal-kernels/src/tests.rs | 158 | ||||
-rw-r--r-- | candle-metal-kernels/src/unary.metal | 48 |
8 files changed, 466 insertions, 247 deletions
diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal index e5f0a841..a08bfbc0 100644 --- a/candle-metal-kernels/src/affine.metal +++ b/candle-metal-kernels/src/affine.metal @@ -33,6 +33,24 @@ kernel void FN_NAME( \ const TYPENAME a = TYPENAME(add); \ output[id] = input[id] * m + a; \ } \ +kernel void FN_NAME##_strided( \ + constant size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant float &mul, \ + constant float &add, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint id [[ thread_position_in_grid ]] \ +) { \ + if (id >= dim) { \ + return; \ + } \ + const TYPENAME m = TYPENAME(mul); \ + const TYPENAME a = TYPENAME(add); \ + output[id] = input[get_strided_index(id, num_dims, dims, strides)] * m + a; \ +} \ AFFINE(affine_float, float) AFFINE(affine_half, half) diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal index d1788253..4398e9d4 100644 --- a/candle-metal-kernels/src/cast.metal +++ b/candle-metal-kernels/src/cast.metal @@ -23,12 +23,12 @@ kernel void FN_NAME( \ constant size_t &dim, \ device const LEFT_TYPENAME *input, \ device RIGHT_TYPENAME *output, \ - uint thread_position_in_grid [[ thread_position_in_grid ]] \ + uint tid [[ thread_position_in_grid ]] \ ) { \ - if (thread_position_in_grid >= dim) { \ + if (tid >= dim) { \ return; \ } \ - output[thread_position_in_grid] = RIGHT_TYPENAME(input[thread_position_in_grid]); \ + output[tid] = RIGHT_TYPENAME(input[tid]); \ } \ kernel void FN_NAME_STRIDED( \ constant size_t &dim, \ @@ -37,15 +37,19 @@ kernel void FN_NAME_STRIDED( \ constant size_t *strides, \ device const LEFT_TYPENAME *input, \ device RIGHT_TYPENAME *output, \ - uint i [[ thread_position_in_grid ]] \ + uint tid [[ thread_position_in_grid ]] \ ) { \ - if (i >= dim) { \ + if (tid >= dim) { \ return; \ } \ - output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \ + output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \ } \ -CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float) +CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float) +CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t) +CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t) +CAST(cast_f16_f32, cast_f16_f32_strided, half, float) +CAST(cast_f32_f16, cast_f32_f16_strided, float, half) #if __METAL_VERSION__ >= 310 #endif diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index 444fa322..312b27c7 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -16,16 +16,16 @@ kernel void NAME( \ if (gid >= dst_size) { \ return; \ } \ - const size_t id_i = gid / right_size / left_size; \ + const size_t id_i = (gid / right_size) % ids_size; \ + const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \ const size_t right_rank_i = gid % right_size; \ - const size_t left_rank_i = gid % left_size; \ + const size_t left_rank_i = gid / right_size / ids_size; \ /* \ // Force prevent out of bounds indexing \ // since there doesn't seem to be a good way to force crash \ // No need to check for zero we're only allowing unsized. \ */ \ - const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \ - const size_t src_i = ((input_i * right_size) + right_rank_i) * left_size + left_rank_i; \ + const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; \ output[gid] = input[src_i]; \ } @@ -75,6 +75,7 @@ kernel void FN_NAME( \ INDEX_OP(is_u32_f32, uint, float) +INDEX_OP(is_u32_f16, uint, half) #if __METAL_VERSION__ >= 310 diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 5a6bd41b..a0b852a4 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1,6 +1,6 @@ use metal::{ - Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor, - ComputePipelineState, Device, Function, Library, MTLSize, + Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState, + Device, Function, Library, MTLSize, }; use std::collections::HashMap; use std::ffi::c_void; @@ -59,8 +59,8 @@ impl<T> EncoderParam for &[T] { fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { encoder.set_bytes( position, - (core::mem::size_of::<T>() * data.len()) as u64, - data.as_ptr() as *const T as *const c_void, + core::mem::size_of_val(data) as u64, + data.as_ptr() as *const c_void, ); } } @@ -111,13 +111,7 @@ macro_rules! ops{ ($($name:ident),+) => { pub mod contiguous { - #[derive(Clone, Copy)] - pub struct Kernel(pub(crate) &'static str); - impl std::fmt::Display for Kernel { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } - } + pub struct Kernel(pub &'static str); $( pub mod $name { use super::Kernel; @@ -126,16 +120,18 @@ macro_rules! ops{ pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat")); } )+ + pub mod copy { + use super::Kernel; + pub const FLOAT: Kernel = Kernel("copy_float"); + pub const HALF: Kernel = Kernel("copy_half"); + pub const BFLOAT: Kernel = Kernel("copy_bfloat"); + pub const U32: Kernel = Kernel("copy_u32"); + pub const U8: Kernel = Kernel("copy_u8"); + } } pub mod strided { - #[derive(Clone, Copy)] - pub struct Kernel(pub(crate) &'static str); - impl std::fmt::Display for Kernel { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } - } + pub struct Kernel(pub &'static str); $( pub mod $name { use super::Kernel; @@ -144,12 +140,20 @@ macro_rules! ops{ pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided")); } )+ + pub mod copy { + use super::Kernel; + pub const FLOAT: Kernel = Kernel("copy_float_strided"); + pub const HALF: Kernel = Kernel("copy_half_strided"); + pub const BFLOAT: Kernel = Kernel("copy_bfloat_strided"); + pub const U32: Kernel = Kernel("copy_u32_strided"); + pub const U8: Kernel = Kernel("copy_u8_strided"); + } } }; } pub mod unary { - ops!(cos, sin, exp, sqr, sqrt, neg, copy, log); + ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf); } pub mod binary { ops!(add, sub, mul, div); @@ -161,8 +165,12 @@ pub enum MetalKernelError { LockError(String), #[error("Error while loading library: {0}")] LoadLibraryError(String), - #[error("Error while loading function: {0}")] + #[error("Error while loading function: {0:?}")] LoadFunctionError(String), + #[error("Failed to create compute function")] + FailedToCreateComputeFunction, + #[error("Failed to create pipeline")] + FailedToCreatePipeline(String), } impl<T> From<std::sync::PoisonError<T>> for MetalKernelError { @@ -173,19 +181,22 @@ impl<T> From<std::sync::PoisonError<T>> for MetalKernelError { type KernelMap<T> = HashMap<&'static str, T>; type Libraries = HashMap<Source, Library>; -type Functions = KernelMap<Function>; +type Pipelines = KernelMap<ComputePipelineState>; #[derive(Debug, Default)] pub struct Kernels { libraries: RwLock<Libraries>, - funcs: RwLock<Functions>, + pipelines: RwLock<Pipelines>, } impl Kernels { pub fn new() -> Self { let libraries = RwLock::new(Libraries::new()); - let funcs = RwLock::new(Functions::new()); - Self { libraries, funcs } + let pipelines = RwLock::new(Pipelines::new()); + Self { + libraries, + pipelines, + } } fn get_library_source(&self, source: Source) -> &'static str { @@ -218,22 +229,43 @@ impl Kernels { } } - pub fn load_function( + fn load_function( &self, device: &Device, source: Source, name: &'static str, ) -> Result<Function, MetalKernelError> { - let mut funcs = self.funcs.write()?; - if let Some(func) = funcs.get(name) { - Ok(func.clone()) + let func = self + .load_library(device, source)? + .get_function(name, None) + .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?; + Ok(func) + // let mut funcs = self.funcs.write()?; + // if let Some(func) = funcs.get(name) { + // Ok(func.clone()) + // } else { + // funcs.insert(name, func.clone()); + // Ok(func) + // } + } + + pub fn load_pipeline( + &self, + device: &Device, + source: Source, + name: &'static str, + ) -> Result<ComputePipelineState, MetalKernelError> { + let mut pipelines = self.pipelines.write()?; + if let Some(pipeline) = pipelines.get(name) { + Ok(pipeline.clone()) } else { - let func = self - .load_library(device, source)? - .get_function(name, None) - .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?; - funcs.insert(name, func.clone()); - Ok(func) + let func = self.load_function(device, source, name)?; + let pipeline = device + .new_compute_pipeline_state_with_function(&func) + .map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?; + pipelines.insert(name, pipeline.clone()); + + Ok(pipeline) } } } @@ -246,18 +278,9 @@ pub fn call_unary_contiguous( kernel_name: unary::contiguous::Kernel, length: usize, input: &Buffer, - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Unary, kernel_name.0)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); - + let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); @@ -279,18 +302,10 @@ pub fn call_unary_strided( input: &Buffer, strides: &[usize], offset: usize, - output: &mut Buffer, + output: &Buffer, output_offset: usize, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Unary, name.0)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?; let num_dims: usize = shape.len(); let encoder = command_buffer.new_compute_command_encoder(); @@ -326,17 +341,9 @@ pub fn call_binary_contiguous( length: usize, left: &Buffer, right: &Buffer, - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Binary, kernel_name.0)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); @@ -363,17 +370,9 @@ pub fn call_binary_strided( right_input: &Buffer, right_strides: &[usize], right_offset: usize, - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Binary, name.0)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?; let num_dims: usize = shape.len(); let encoder = command_buffer.new_compute_command_encoder(); @@ -411,22 +410,52 @@ pub fn call_cast_contiguous( kernel_name: &'static str, length: usize, input: &Buffer, - output: &mut Buffer, + input_offset: usize, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Cast, kernel_name)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); + let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (length, (input, input_offset), output)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_cast_strided( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: &'static str, + shape: &[usize], + input: &Buffer, + input_strides: &[usize], + input_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (length, input, output)); + let length: usize = shape.iter().product(); + + set_params!( + encoder, + ( + length, + shape.len(), + shape, + input_strides, + (input, input_offset), + output + ) + ); let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); @@ -435,7 +464,6 @@ pub fn call_cast_contiguous( Ok(()) } -#[allow(clippy::too_many_arguments)] pub fn call_reduce_contiguous( device: &Device, command_buffer: &CommandBufferRef, @@ -444,24 +472,19 @@ pub fn call_reduce_contiguous( length: usize, out_length: usize, input: &Buffer, - output: &mut Buffer, + input_offset: usize, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Reduce, kernel_name)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); - + let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let elements_to_sum = length / out_length; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (length, elements_to_sum, input, output)); + set_params!( + encoder, + (length, elements_to_sum, (input, input_offset), output) + ); let thread_group_count = MTLSize { width: out_length as u64, @@ -495,18 +518,9 @@ pub fn call_last_softmax( length: usize, elements_to_sum: usize, input: &Buffer, - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Reduce, kernel_name)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); - + let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); @@ -542,21 +556,14 @@ pub fn call_affine( device: &Device, command_buffer: &CommandBufferRef, kernels: &Kernels, + name: &'static str, size: usize, input: &Buffer, - output: &mut Buffer, + output: &Buffer, mul: f32, add: f32, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Affine, "affine_float")?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); @@ -570,6 +577,45 @@ pub fn call_affine( } #[allow(clippy::too_many_arguments)] +pub fn call_affine_strided( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + input: &Buffer, + input_stride: &[usize], + input_offset: usize, + output: &Buffer, + mul: f32, + add: f32, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; + let size: usize = shape.iter().product(); + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + size, + shape.len(), + shape, + input_stride, + mul, + add, + (input, input_offset), + output + ) + ); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + pub fn call_where_cond_strided( device: &Device, command_buffer: &CommandBufferRef, @@ -582,17 +628,9 @@ pub fn call_where_cond_strided( (left_stride, left_offset): (&[usize], usize), right: &Buffer, (right_stride, right_offset): (&[usize], usize), - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { - let func = kernels.load_function(device, Source::Ternary, name)?; - let pipeline_state_descriptor = ComputePipelineDescriptor::new(); - pipeline_state_descriptor.set_compute_function(Some(&func)); - - let pipeline = device - .new_compute_pipeline_state_with_function( - pipeline_state_descriptor.compute_function().unwrap(), - ) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?; let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); @@ -634,17 +672,14 @@ pub fn call_index_select( dim: usize, input: &Buffer, ids: &Buffer, - output: &mut Buffer, + output: &Buffer, ) -> Result<(), MetalKernelError> { let left_size: usize = shape[..dim].iter().product(); let right_size: usize = shape[dim + 1..].iter().product(); let src_dim_size = shape[dim]; let dst_el = ids_size * left_size * right_size; - let func = kernels.load_function(device, Source::Indexing, name)?; - let pipeline = device - .new_compute_pipeline_state_with_function(&func) - .unwrap(); + let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; let encoder = command_buffer.new_compute_command_encoder(); diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index c6984474..867877fb 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -1,6 +1,8 @@ #include <metal_stdlib> using namespace metal; +#define MAX(x, y) ((x) > (y) ? (x) : (y)) + METAL_FUNC uint get_strided_index( uint idx, constant size_t &num_dims, @@ -16,18 +18,18 @@ METAL_FUNC uint get_strided_index( return strided_i; } -constant int THREADGROUP_SIZE = 256; +constant int THREADGROUP_SIZE = 1024; -# define REDUCE(FN, NAME, TYPENAME) \ +# define REDUCE(FN, NAME, T) \ kernel void NAME( \ constant size_t &src_numel, \ constant size_t &el_to_sum_per_block, \ - device const TYPENAME *src, \ - device TYPENAME *dst, \ + device const T *src, \ + device T *dst, \ uint id [[ thread_position_in_grid ]], \ uint tid [[ thread_index_in_threadgroup ]], \ uint dst_id [[ threadgroup_position_in_grid ]], \ - uint blockDim [[ threads_per_threadgroup ]] \ + uint block_dim [[ threads_per_threadgroup ]] \ ) { \ \ threadgroup float shared_memory[THREADGROUP_SIZE]; \ @@ -45,10 +47,10 @@ kernel void NAME( \ // TODO: Fast version for the contiguous case. \ // size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \ */ \ - TYPENAME x = shared_memory[tid]; \ - TYPENAME y = src[idx]; \ + T x = shared_memory[tid]; \ + T y = src[idx]; \ shared_memory[tid] = FN; \ - idx += blockDim; \ + idx += block_dim; \ } \ \ threadgroup_barrier(mem_flags::mem_none); \ @@ -56,10 +58,10 @@ kernel void NAME( \ /* \ // reduction in shared memory \ */ \ - for (uint s = blockDim / 2; s > 0; s >>= 1) { \ + for (uint s = block_dim / 2; s > 0; s >>= 1) { \ if (tid < s) { \ - TYPENAME x = shared_memory[tid]; \ - TYPENAME y = shared_memory[tid + s]; \ + T x = shared_memory[tid]; \ + T y = shared_memory[tid + s]; \ shared_memory[tid] = FN; \ } \ threadgroup_barrier(mem_flags::mem_none); \ @@ -68,72 +70,74 @@ kernel void NAME( \ dst[dst_id] = shared_memory[0]; \ } \ -kernel void softmax_float( - constant size_t &src_numel, - constant size_t &el_to_sum_per_block, - device const float *src, - device float *dst, - uint id [[ thread_position_in_grid ]], - uint tid [[ thread_index_in_threadgroup ]], - uint dst_id [[ threadgroup_position_in_grid ]], - uint blockDim [[ threads_per_threadgroup ]] -) { - - threadgroup float shared_memory[THREADGROUP_SIZE]; - - shared_memory[tid] = -INFINITY; - // Elements summed in this block range from dst_id * el_to_sum_per_block - // to (dst_id + 1) * el_to_sum_per_block. - size_t start_idx = dst_id * el_to_sum_per_block; - size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); - size_t idx = start_idx + tid; - - while (idx < stop_idx) { - // TODO: Fast version for the contiguous case. - shared_memory[tid] = max(shared_memory[tid], src[idx]); - idx += blockDim; - } - - threadgroup_barrier(mem_flags::mem_none); - - // reduction in shared memory - for (uint s = blockDim / 2; s > 0; s >>= 1) { - if (tid < s) { - shared_memory[tid] = max(shared_memory[tid], shared_memory[tid + s]); - } - threadgroup_barrier(mem_flags::mem_none); - } - - float max = shared_memory[0]; - - shared_memory[tid] = 0; - - // Restart - idx = start_idx + tid; - while (idx < stop_idx) { - // TODO: Fast version for the contiguous case. - const float val = exp(src[idx] - max); - dst[idx] = val; - shared_memory[tid] += val; - idx += blockDim; - } - // reduction in shared memory - for (uint s = blockDim / 2; s > 0; s >>= 1) { - if (tid < s) { - shared_memory[tid] += shared_memory[tid + s]; - } - threadgroup_barrier(mem_flags::mem_none); - } - - const float inv_acc = 1/shared_memory[0]; - idx = start_idx + tid; - while (idx < stop_idx) { - dst[idx] *= inv_acc; - idx += blockDim; - } -} - REDUCE(x + y, fast_sum_float, float) REDUCE(x * y, fast_mul_float, float) REDUCE(max(x, y), fast_max_float, float) + +#define SOFTMAX(NAME, T) \ +kernel void NAME( \ + constant size_t &src_numel, \ + constant size_t &el_to_sum_per_block, \ + device const T *src, \ + device T *dst, \ + \ + uint id [[ thread_position_in_grid ]], \ + uint tid [[ thread_index_in_threadgroup ]], \ + uint dst_id [[ threadgroup_position_in_grid ]], \ + uint block_dim [[ threads_per_threadgroup ]] \ +) { \ + threadgroup float shared_memory[THREADGROUP_SIZE]; \ + shared_memory[tid] = -INFINITY; \ + size_t start_idx = dst_id * el_to_sum_per_block; \ + size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \ + size_t idx = start_idx + tid; \ + \ + threadgroup_barrier(mem_flags::mem_threadgroup); \ + \ + while (idx < stop_idx) { \ + shared_memory[tid] = MAX(shared_memory[tid], src[idx]); \ + idx += block_dim; \ + } \ + \ + threadgroup_barrier(mem_flags::mem_threadgroup); \ + \ + for (uint s = block_dim / 2; s > 0; s >>= 1) { \ + if (tid < s) { \ + shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \ + } \ + } \ + \ + threadgroup_barrier(mem_flags::mem_threadgroup); \ + \ + float _max = shared_memory[0]; \ + \ + shared_memory[tid] = 0; \ + \ + idx = start_idx + tid; \ + while (idx < stop_idx) { \ + const T val = T(exp(src[idx] - _max)); \ + dst[idx] = val; \ + shared_memory[tid] += val; \ + idx += block_dim; \ + } \ + for (uint s = block_dim / 2; s > 0; s >>= 1) { \ + if (tid < s) { \ + shared_memory[tid] += shared_memory[tid + s]; \ + } \ + threadgroup_barrier(mem_flags::mem_threadgroup); \ + } \ + \ + const T inv_acc = T(1/shared_memory[0]); \ + idx = start_idx + tid; \ + while (idx < stop_idx) { \ + dst[idx] *= inv_acc; \ + idx += block_dim; \ + } \ +} \ + +SOFTMAX(softmax_float, float) +SOFTMAX(softmax_half, half) +#if __METAL_VERSION__ >= 310 +SOFTMAX(softmax_bfloat, bfloat) +#endif diff --git a/candle-metal-kernels/src/ternary.metal b/candle-metal-kernels/src/ternary.metal index 0945b355..1f9cb38a 100644 --- a/candle-metal-kernels/src/ternary.metal +++ b/candle-metal-kernels/src/ternary.metal @@ -32,6 +32,9 @@ kernel void FN_NAME( \ device TYPENAME *out ,\ uint i [[ thread_position_in_grid ]] \ ) { \ + if (i >= numel){ \ + return; \ + } \ uint strided_i = get_strided_index(i, num_dims, dims, strides); \ uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \ uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \ diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 2330d48d..66dc8d01 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -1,5 +1,5 @@ use super::*; -use half::f16; +use half::{bf16, f16}; use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger}; fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer { @@ -23,13 +23,18 @@ fn approx_f16(v: Vec<f16>, digits: i32) -> Vec<f32> { v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect() } +fn approx_bf16(v: Vec<bf16>, digits: i32) -> Vec<f32> { + let b = 10f32.powi(digits); + v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect() +} + fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> { let device = device(); let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); + let output = new_buffer(&device, v); call_unary_contiguous( &device, command_buffer, @@ -37,7 +42,7 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> { name, v.len(), &input, - &mut output, + &output, ) .unwrap(); command_buffer.commit(); @@ -53,7 +58,7 @@ fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> V let options = MTLResourceOptions::StorageModeManaged; let left = new_buffer(&device, x); let right = new_buffer(&device, y); - let mut output = device.new_buffer(std::mem::size_of_val(x) as u64, options); + let output = device.new_buffer(std::mem::size_of_val(x) as u64, options); call_binary_contiguous( &device, command_buffer, @@ -62,7 +67,7 @@ fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> V x.len(), &left, &right, - &mut output, + &output, ) .unwrap(); command_buffer.commit(); @@ -81,7 +86,7 @@ fn run_strided<T: Clone>( let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); + let output = new_buffer(&device, v); let kernels = Kernels::new(); call_unary_strided( &device, @@ -92,7 +97,7 @@ fn run_strided<T: Clone>( &input, strides, offset, - &mut output, + &output, 0, ) .unwrap(); @@ -220,7 +225,9 @@ fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> { let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); + let options = MTLResourceOptions::StorageModeManaged; + let size = (v.len() * std::mem::size_of::<U>()) as u64; + let output = device.new_buffer(size, options); call_cast_contiguous( &device, @@ -229,7 +236,8 @@ fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> { name, v.len(), &input, - &mut output, + 0, + &output, ) .unwrap(); command_buffer.commit(); @@ -245,11 +253,17 @@ fn cast_u32_f32() { assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]); assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]); + let v = vec![1.0f32, 2.0, 3.0]; + let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect(); + let results: Vec<f32> = cast(&input, "cast_f16_f32"); + assert_eq!(results, vec![1.0f32, 2.0, 3.0]); + let v = vec![1.0f32; 10_000]; - let results = run(&v, unary::contiguous::cos::FLOAT); - let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); - assert_eq!(approx(results, 4), vec![0.5403; 10_000]); - assert_eq!(approx(expected, 4), vec![0.5403; 10_000]); + let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect(); + let results: Vec<f32> = cast(&input, "cast_f16_f32"); + assert_eq!(results.len(), 10_000); + assert_eq!(&results[..10], vec![1.0f32; 10]); + assert_eq!(results, vec![1.0f32; 10_000]); } fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> { @@ -259,7 +273,7 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> { let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); + let output = new_buffer(&device, v); let size = v.len(); @@ -267,9 +281,45 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> { &device, command_buffer, &kernels, + "affine_float", size, &input, - &mut output, + &output, + mul as f32, + add as f32, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + output.read_to_vec::<T>(v.len()) +} + +fn _run_affine_strided<T: Clone>( + v: &[T], + shape: &[usize], + strides: &[usize], + mul: f64, + add: f64, +) -> Vec<T> { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + + let input = new_buffer(&device, v); + let output = new_buffer(&device, v); + + call_affine_strided( + &device, + command_buffer, + &kernels, + "affine_float", + shape, + &input, + strides, + 0, + &output, mul as f32, add as f32, ) @@ -295,6 +345,16 @@ fn affine() { assert_eq!(result, vec![2.6; 40_000]); } +// #[test] +// fn affine_strided() { +// let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; +// let mul = 1.5; +// let add = 1.1; +// let result = run_affine_(&input, mul, add); +// assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]); + +// } + #[test] fn index_select() { let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; @@ -313,7 +373,26 @@ fn index_select() { result, vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0] ); +} + +#[test] +fn index_select_f16() { + let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + .into_iter() + .map(|x| f16::from_f32(x)) + .collect(); + let shape = [5, 2]; + let ids = [0u32, 4, 2]; + let dim = 0; + let result = run_index_select(&embedding, &shape, &ids, dim); + assert_eq!( + approx_f16(result, 4), + vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0] + ); +} +#[test] +fn index_select_dim1() { let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; let shape = [5, 2]; let ids = [0u32, 1, 0]; @@ -321,7 +400,7 @@ fn index_select() { let result = run_index_select(&embedding, &shape, &ids, dim); assert_eq!( result, - vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0] + vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0] ); } @@ -341,20 +420,26 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>( let left_size: usize = shape[..dim].iter().product(); let right_size: usize = shape[dim + 1..].iter().product(); let dst_el = ids.len() * left_size * right_size; - let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]); + let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]); + + let name = match core::mem::size_of::<T>() { + 4 => "is_u32_f32", + 2 => "is_u32_f16", + _ => unimplemented!(), + }; let kernels = Kernels::new(); call_index_select( &device, &command_buffer, &kernels, - "is_u32_f32", + name, shape, ids.len(), dim, &embeddings_buffer, &ids_buffer, - &mut dst_buffer, + &dst_buffer, ) .unwrap(); @@ -451,7 +536,7 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T let input = new_buffer(&device, v); let options = MTLResourceOptions::StorageModeManaged; - let mut output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options); + let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options); call_reduce_contiguous( &device, command_buffer, @@ -460,7 +545,8 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T v.len(), out_length, &input, - &mut output, + 0, + &output, ) .unwrap(); command_buffer.commit(); @@ -475,7 +561,7 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'sta let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); + let output = new_buffer(&device, v); call_last_softmax( &device, command_buffer, @@ -484,7 +570,7 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'sta v.len(), last_dim, &input, - &mut output, + &output, ) .unwrap(); command_buffer.commit(); @@ -536,6 +622,28 @@ fn softmax() { approx(results, 4), vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652] ); + + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0] + .iter() + .map(|v| f16::from_f32(*v)) + .collect::<Vec<_>>(); + let last_dim = 6; + let results = run_softmax(&v, last_dim, "softmax_half"); + assert_eq!( + approx_f16(results, 4), + vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338] + ); + + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0] + .iter() + .map(|v| bf16::from_f32(*v)) + .collect::<Vec<_>>(); + let last_dim = 6; + let results = run_softmax(&v, last_dim, "softmax_bfloat"); + assert_eq!( + approx_bf16(results, 4), + vec![0.0043, 0.0116, 0.0315, 0.0859, 0.2324, 0.6328] + ); } fn run_where_cond<I: Clone, T: Clone>( @@ -571,7 +679,7 @@ fn run_where_cond<I: Clone, T: Clone>( options, ); - let mut output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options); + let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options); call_where_cond_strided( &device, command_buffer, @@ -584,7 +692,7 @@ fn run_where_cond<I: Clone, T: Clone>( (&left_stride, left_offset), &right, (&cond_stride, cond_offset), - &mut output, + &output, ) .unwrap(); command_buffer.commit(); diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index eb6424e8..88139af9 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -1,4 +1,7 @@ #include <metal_stdlib> +#include <metal_math> +# +using namespace metal; METAL_FUNC uint get_strided_index( uint idx, @@ -17,10 +20,39 @@ METAL_FUNC uint get_strided_index( template <typename T> METAL_FUNC T sqr(T in){ return in * in; } template <typename T> METAL_FUNC T neg(T in){ return -in; } +template <typename T> METAL_FUNC T erf(T in){ + float x = (float) in; + // constants + float a1 = 0.254829592; + float a2 = -0.284496736; + float a3 = 1.421413741; + float a4 = -1.453152027; + float a5 = 1.061405429; + float p = 0.3275911; + + // Save the sign of x + int sign = 1; + if (x < 0) + sign = -1; + x = fabs(x); + + // A&S formula 7.1.26 + float t = 1.0/(1.0 + p*x); + float y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x); + + return T(sign*y); +} template <typename T> METAL_FUNC T id(T in){ return in; } +template <typename T> METAL_FUNC T gelu_erf(T x){ return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2); } +template <typename T> METAL_FUNC T gelu(T x){ + T x_sq = x * x; + T x_cube = x_sq * x; + T alpha = x + static_cast<T>(0.044715) * x_cube; + T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha); + return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta))); +} -using namespace metal; #define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \ kernel void FN_NAME( \ @@ -64,8 +96,16 @@ UNARY_OP(sqrt) UNARY_OP(neg) UNARY_OP(exp) UNARY_OP(log) +UNARY_OP(gelu) +UNARY_OP(ceil) +UNARY_OP(floor) +UNARY_OP(round) +UNARY_OP(gelu_erf) +UNARY_OP(erf) UNARY(id, float, copy_float, copy_float_strided) UNARY(id, half, copy_half, copy_half_strided) +UNARY(id, uint8_t, copy_u8, copy_u8_strided) +UNARY(id, uint32_t, copy_u32, copy_u32_strided) #if __METAL_VERSION__ >= 310 BFLOAT_UNARY_OP(cos) @@ -75,6 +115,12 @@ BFLOAT_UNARY_OP(sqrt) BFLOAT_UNARY_OP(neg) BFLOAT_UNARY_OP(exp) BFLOAT_UNARY_OP(log) +BFLOAT_UNARY_OP(gelu) +BFLOAT_UNARY_OP(ceil) +BFLOAT_UNARY_OP(floor) +BFLOAT_UNARY_OP(round) +BFLOAT_UNARY_OP(gelu_erf) +BFLOAT_UNARY_OP(erf) UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided) #endif |