summary | refs | log | tree | commit | diff
path: root/candle-metal-kernels/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-metal-kernels/src/lib.rs')
-rw-r--r-- candle-metal-kernels/src/lib.rs 42
1 files changed, 11 insertions, 31 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 7288216a..6c2e5f2b 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1,4 +1,3 @@
-#![allow(clippy::too_many_arguments)]
use metal::{
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor,
ComputePipelineState, Device, Function, Library, MTLSize,
@@ -156,14 +155,6 @@ pub mod binary {
ops!(add, sub, mul, div);
}
-// static LIBRARY_SOURCES: Lazy<HashMap<&'static str, &'static str>> = Lazy::new(|| {
-// let mut l = HashMap::new();
-// l.insert("affine", AFFINE);
-// l.insert("indexing", INDEXING);
-// l.insert("unary", UNARY);
-// l
-// });
-//
#[derive(thiserror::Error, Debug)]
pub enum MetalKernelError {
#[error("Could not lock kernel map: {0}")]
@@ -197,21 +188,7 @@ impl Kernels {
Self { libraries, funcs }
}
- // pub fn init(device: &Device) -> Result<Self, MetalKernelError> {
- // let kernels = Self::new();
- // kernels.load_libraries(device)?;
- // Ok(kernels)
- // }
-
- // fn load_libraries(&self, device: &Device) -> Result<(), MetalKernelError> {
- // for name in LIBRARY_SOURCES.keys() {
- // self.load_library(device, name)?;
- // }
- // Ok(())
- // }
-
fn get_library_source(&self, source: Source) -> &'static str {
- // LIBRARY_SOURCES.get(name).cloned()
match source {
Source::Affine => AFFINE,
Source::Unary => UNARY,
@@ -261,6 +238,7 @@ impl Kernels {
}
}
+#[allow(clippy::too_many_arguments)]
pub fn call_unary_contiguous(
device: &Device,
command_buffer: &CommandBufferRef,
@@ -270,8 +248,6 @@ pub fn call_unary_contiguous(
input: &Buffer,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
- // println!("Kernel {:?}", kernel_name.0);
- // assert_eq!(input.length(), output.length());
let func = kernels.load_function(device, Source::Unary, kernel_name.0)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
@@ -292,6 +268,8 @@ pub fn call_unary_contiguous(
encoder.end_encoding();
Ok(())
}
+
+#[allow(clippy::too_many_arguments)]
pub fn call_unary_strided(
device: &Device,
command_buffer: &CommandBufferRef,
@@ -339,6 +317,7 @@ pub fn call_unary_strided(
Ok(())
}
+#[allow(clippy::too_many_arguments)]
pub fn call_binary_contiguous(
device: &Device,
command_buffer: &CommandBufferRef,
@@ -349,8 +328,6 @@ pub fn call_binary_contiguous(
right: &Buffer,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
- // println!("Kernel {:?}", kernel_name.0);
- // assert_eq!(input.length(), output.length());
let func = kernels.load_function(device, Source::Binary, kernel_name.0)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
@@ -373,6 +350,7 @@ pub fn call_binary_contiguous(
Ok(())
}
+#[allow(clippy::too_many_arguments)]
pub fn call_binary_strided(
device: &Device,
command_buffer: &CommandBufferRef,
@@ -425,6 +403,7 @@ pub fn call_binary_strided(
Ok(())
}
+#[allow(clippy::too_many_arguments)]
pub fn call_cast_contiguous(
device: &Device,
command_buffer: &CommandBufferRef,
@@ -434,8 +413,6 @@ pub fn call_cast_contiguous(
input: &Buffer,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
- // println!("Kernel {:?}", kernel_name.0);
- // assert_eq!(input.length(), output.length());
let func = kernels.load_function(device, Source::Cast, kernel_name)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
@@ -458,6 +435,7 @@ pub fn call_cast_contiguous(
Ok(())
}
+#[allow(clippy::too_many_arguments)]
pub fn call_reduce_contiguous(
device: &Device,
command_buffer: &CommandBufferRef,
@@ -508,6 +486,7 @@ pub fn call_reduce_contiguous(
Ok(())
}
+#[allow(clippy::too_many_arguments)]
pub fn call_last_softmax(
device: &Device,
command_buffer: &CommandBufferRef,
@@ -543,7 +522,6 @@ pub fn call_last_softmax(
let width = std::cmp::min(
pipeline.max_total_threads_per_threadgroup(),
- // (elements_to_sum as u64 + 2 - 1) / 2,
elements_to_sum as u64,
)
.next_power_of_two();
@@ -559,6 +537,7 @@ pub fn call_last_softmax(
Ok(())
}
+#[allow(clippy::too_many_arguments)]
pub fn call_affine(
device: &Device,
command_buffer: &CommandBufferRef,
@@ -590,6 +569,7 @@ pub fn call_affine(
Ok(())
}
+#[allow(clippy::too_many_arguments)]
pub fn call_where_cond_strided(
device: &Device,
command_buffer: &CommandBufferRef,
@@ -643,6 +623,7 @@ pub fn call_where_cond_strided(
Ok(())
}
+#[allow(clippy::too_many_arguments)]
pub fn call_index_select(
device: &Device,
command_buffer: &CommandBufferRef,
@@ -813,7 +794,6 @@ mod tests {
#[test]
fn cos_f32_strided() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
- // Shape = [6], strides = [1];
let shape = vec![6];
let strides = vec![1];
let offset = 0;