summary | refs | log | tree | commit | diff
path: root/candle-metal-kernels/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-metal-kernels/src/lib.rs')
-rw-r--r-- candle-metal-kernels/src/lib.rs | 28
1 files changed, 28 insertions, 0 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index a595b2bd..a270bb28 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -14,6 +14,7 @@ const AFFINE: &str = include_str!("affine.metal");
const BINARY: &str = include_str!("binary.metal");
const CAST: &str = include_str!("cast.metal");
const CONV: &str = include_str!("conv.metal");
+const FILL: &str = include_str!("fill.metal");
const INDEXING: &str = include_str!("indexing.metal");
// Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
@@ -31,6 +32,7 @@ pub enum Source {
Binary,
Cast,
Conv,
+ Fill,
Gemm,
Indexing,
Mfa,
@@ -196,6 +198,7 @@ impl Kernels {
Source::Binary => BINARY,
Source::Cast => CAST,
Source::Conv => CONV,
+ Source::Fill => FILL,
Source::Gemm => MLX_GEMM,
Source::Indexing => INDEXING,
Source::Quantized => QUANTIZED,
@@ -2357,5 +2360,30 @@ pub fn call_mlx_gemm(
Ok(())
}
+/// Encodes a constant-fill compute dispatch that targets `output`.
+///
+/// Loads the pipeline called `name` from the `Source::Fill` library, binds
+/// `(output, v, length)` as the kernel parameters in that order, and
+/// dispatches thread groups sized by `linear_split` for `length`.
+/// No commit or wait is issued in this function.
+///
+/// # Arguments
+/// * `device` - device handed to `Kernels::load_pipeline`.
+/// * `ep` - provider of the compute command encoder that receives the dispatch.
+/// * `kernels` - kernel cache used to look up the pipeline.
+/// * `name` - name of the kernel function inside the fill library
+///   (presumably one per element type; confirm against `fill.metal`).
+/// * `length` - value passed to the kernel and used to size the dispatch
+///   (presumably the number of elements to fill; confirm against `fill.metal`).
+/// * `output` - buffer bound as the first kernel parameter and marked as written.
+/// * `v` - the `f32` constant passed to the kernel.
+///
+/// # Errors
+/// Propagates the `MetalKernelError` returned by `Kernels::load_pipeline`.
+pub fn call_const_fill(
+ device: &Device,
+ ep: impl EncoderProvider,
+ kernels: &Kernels,
+ name: &'static str,
+ length: usize,
+ output: &Buffer,
+ v: f32,
+) -> Result<(), MetalKernelError> {
+ // Resolve the pipeline first so a lookup failure returns before any encoder is requested.
+ let pipeline = kernels.load_pipeline(device, Source::Fill, name)?;
+ // The value returned by `ep.encoder()` stays alive under the shadowed name;
+ // the second binding is only a borrowed `ComputeCommandEncoderRef` view of it.
+ let encoder = ep.encoder();
+ let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
+
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ // NOTE(review): the parameter order here presumably has to match the kernel
+ // signature in `fill.metal` -- confirm before reordering.
+ set_params!(encoder, (output, v, length));
+
+ // NOTE(review): behaviour for `length == 0` depends on `linear_split`, which is
+ // not visible here -- confirm that an empty fill is handled.
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
+
+ // Declare `output` as a written resource before issuing the dispatch.
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+
+ Ok(())
+}
+
#[cfg(test)]
mod tests;