diff options
Diffstat (limited to 'candle-metal-kernels/src/lib.rs')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index a23aa47c..f2db171e 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -597,6 +597,7 @@ pub fn call_last_softmax( length: usize, elements_to_sum: usize, input: &Buffer, + input_offset: usize, output: &Buffer, ) -> Result<(), MetalKernelError> { let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; @@ -604,7 +605,10 @@ pub fn call_last_softmax( encoder.wait_for_fence(&kernels.fence); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (length, elements_to_sum, input, output)); + set_params!( + encoder, + (length, elements_to_sum, (input, input_offset), output) + ); let out_length = length / elements_to_sum; |