summaryrefslogtreecommitdiff
path: root/candle-metal-kernels/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-metal-kernels/src/lib.rs')
-rw-r--r--candle-metal-kernels/src/lib.rs6
1 files changed, 5 insertions, 1 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index a23aa47c..f2db171e 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -597,6 +597,7 @@ pub fn call_last_softmax(
length: usize,
elements_to_sum: usize,
input: &Buffer,
+ input_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
@@ -604,7 +605,10 @@ pub fn call_last_softmax(
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
- set_params!(encoder, (length, elements_to_sum, input, output));
+ set_params!(
+ encoder,
+ (length, elements_to_sum, (input, input_offset), output)
+ );
let out_length = length / elements_to_sum;