diff options
Diffstat (limited to 'candle-metal-kernels/src/lib.rs')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index a595b2bd..a270bb28 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -14,6 +14,7 @@ const AFFINE: &str = include_str!("affine.metal"); const BINARY: &str = include_str!("binary.metal"); const CAST: &str = include_str!("cast.metal"); const CONV: &str = include_str!("conv.metal"); +const FILL: &str = include_str!("fill.metal"); const INDEXING: &str = include_str!("indexing.metal"); // Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib"); @@ -31,6 +32,7 @@ pub enum Source { Binary, Cast, Conv, + Fill, Gemm, Indexing, Mfa, @@ -196,6 +198,7 @@ impl Kernels { Source::Binary => BINARY, Source::Cast => CAST, Source::Conv => CONV, + Source::Fill => FILL, Source::Gemm => MLX_GEMM, Source::Indexing => INDEXING, Source::Quantized => QUANTIZED, @@ -2357,5 +2360,30 @@ pub fn call_mlx_gemm( Ok(()) } +pub fn call_const_fill( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + name: &'static str, + length: usize, + output: &Buffer, + v: f32, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Fill, name)?; + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + + encoder.set_compute_pipeline_state(&pipeline); + + set_params!(encoder, (output, v, length)); + + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + + Ok(()) +} + #[cfg(test)] mod tests; |