diff options
Diffstat (limited to 'candle-metal-kernels/src')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 42 |
1 files changed, 11 insertions, 31 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 7288216a..6c2e5f2b 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1,4 +1,3 @@ -#![allow(clippy::too_many_arguments)] use metal::{ Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor, ComputePipelineState, Device, Function, Library, MTLSize, @@ -156,14 +155,6 @@ pub mod binary { ops!(add, sub, mul, div); } -// static LIBRARY_SOURCES: Lazy<HashMap<&'static str, &'static str>> = Lazy::new(|| { -// let mut l = HashMap::new(); -// l.insert("affine", AFFINE); -// l.insert("indexing", INDEXING); -// l.insert("unary", UNARY); -// l -// }); -// #[derive(thiserror::Error, Debug)] pub enum MetalKernelError { #[error("Could not lock kernel map: {0}")] @@ -197,21 +188,7 @@ impl Kernels { Self { libraries, funcs } } - // pub fn init(device: &Device) -> Result<Self, MetalKernelError> { - // let kernels = Self::new(); - // kernels.load_libraries(device)?; - // Ok(kernels) - // } - - // fn load_libraries(&self, device: &Device) -> Result<(), MetalKernelError> { - // for name in LIBRARY_SOURCES.keys() { - // self.load_library(device, name)?; - // } - // Ok(()) - // } - fn get_library_source(&self, source: Source) -> &'static str { - // LIBRARY_SOURCES.get(name).cloned() match source { Source::Affine => AFFINE, Source::Unary => UNARY, @@ -261,6 +238,7 @@ impl Kernels { } } +#[allow(clippy::too_many_arguments)] pub fn call_unary_contiguous( device: &Device, command_buffer: &CommandBufferRef, @@ -270,8 +248,6 @@ pub fn call_unary_contiguous( input: &Buffer, output: &mut Buffer, ) -> Result<(), MetalKernelError> { - // println!("Kernel {:?}", kernel_name.0); - // assert_eq!(input.length(), output.length()); let func = kernels.load_function(device, Source::Unary, kernel_name.0)?; let pipeline_state_descriptor = ComputePipelineDescriptor::new(); pipeline_state_descriptor.set_compute_function(Some(&func)); @@ -292,6 +268,8 @@ pub fn call_unary_contiguous( encoder.end_encoding(); Ok(()) } + +#[allow(clippy::too_many_arguments)] pub fn call_unary_strided( device: &Device, command_buffer: &CommandBufferRef, @@ -339,6 +317,7 @@ pub fn call_unary_strided( Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_binary_contiguous( device: &Device, command_buffer: &CommandBufferRef, @@ -349,8 +328,6 @@ pub fn call_binary_contiguous( right: &Buffer, output: &mut Buffer, ) -> Result<(), MetalKernelError> { - // println!("Kernel {:?}", kernel_name.0); - // assert_eq!(input.length(), output.length()); let func = kernels.load_function(device, Source::Binary, kernel_name.0)?; let pipeline_state_descriptor = ComputePipelineDescriptor::new(); pipeline_state_descriptor.set_compute_function(Some(&func)); @@ -373,6 +350,7 @@ pub fn call_binary_contiguous( Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_binary_strided( device: &Device, command_buffer: &CommandBufferRef, @@ -425,6 +403,7 @@ pub fn call_binary_strided( Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_cast_contiguous( device: &Device, command_buffer: &CommandBufferRef, @@ -434,8 +413,6 @@ pub fn call_cast_contiguous( input: &Buffer, output: &mut Buffer, ) -> Result<(), MetalKernelError> { - // println!("Kernel {:?}", kernel_name.0); - // assert_eq!(input.length(), output.length()); let func = kernels.load_function(device, Source::Cast, kernel_name)?; let pipeline_state_descriptor = ComputePipelineDescriptor::new(); pipeline_state_descriptor.set_compute_function(Some(&func)); @@ -458,6 +435,7 @@ pub fn call_cast_contiguous( Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_reduce_contiguous( device: &Device, command_buffer: &CommandBufferRef, @@ -508,6 +486,7 @@ pub fn call_reduce_contiguous( Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_last_softmax( device: &Device, command_buffer: &CommandBufferRef, @@ -543,7 +522,6 @@ pub fn call_last_softmax( let width = std::cmp::min( pipeline.max_total_threads_per_threadgroup(), - // (elements_to_sum as u64 + 2 - 1) / 2, elements_to_sum as u64, ) .next_power_of_two(); @@ -559,6 +537,7 @@ pub fn call_last_softmax( Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_affine( device: &Device, command_buffer: &CommandBufferRef, @@ -590,6 +569,7 @@ pub fn call_affine( Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_where_cond_strided( device: &Device, command_buffer: &CommandBufferRef, @@ -643,6 +623,7 @@ pub fn call_where_cond_strided( Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_index_select( device: &Device, command_buffer: &CommandBufferRef, @@ -813,7 +794,6 @@ mod tests { #[test] fn cos_f32_strided() { let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; - // Shape = [6], strides = [1]; let shape = vec![6]; let strides = vec![1]; let offset = 0; |