Diffstat (limited to 'candle-nn/src/ops.rs')
-rw-r--r--  candle-nn/src/ops.rs  40
1 file changed, 40 insertions(+), 0 deletions(-)
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index a0269e59..350bc663 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -201,6 +201,46 @@ impl candle::CustomOp1 for SoftmaxLastDim {
         };
         Ok((dst, layout.shape().clone()))
     }
+
+    #[cfg(feature = "metal")]
+    fn metal_fwd(
+        &self,
+        storage: &candle::MetalStorage,
+        layout: &Layout,
+    ) -> Result<(candle::MetalStorage, Shape)> {
+        use candle::{backend::BackendStorage, DType};
+        let device = storage.device();
+        let command_buffer = device.command_buffer();
+        let kernels = device.kernels();
+        let name = match storage.dtype() {
+            DType::F32 => "softmax_float",
+            DType::F16 => "softmax_half",
+            DType::BF16 => "softmax_bfloat",
+            dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"),
+        };
+
+        let n = layout.stride().len();
+        if !(layout.stride()[n - 1] == 1 && layout.start_offset() == 0) {
+            candle::bail!("Non contiguous softmax-last-dim is not implemented");
+        }
+
+        let last_dim = layout.dims()[layout.shape().rank() - 1];
+        let elem_count = layout.shape().elem_count();
+        let mut output = device.new_buffer(elem_count, storage.dtype());
+        candle_metal_kernels::call_last_softmax(
+            device.metal_device(),
+            &command_buffer,
+            &kernels,
+            name,
+            elem_count,
+            last_dim,
+            storage.buffer(),
+            &mut output,
+        )
+        .unwrap();
+        let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
+        Ok((newstorage, layout.shape().clone()))
+    }
 }
 
 pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
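
For context, a minimal usage sketch (not part of the commit): the public softmax_last_dim helper, whose signature appears in the trailing context above, is the entry point that dispatches to the new metal_fwd implementation when the input tensor lives on a Metal device and the "metal" feature is enabled. The shapes, the random input, and the Device::Cpu fallback here are illustrative assumptions, not taken from the diff.

// Hypothetical caller of softmax_last_dim; within the candle workspace the
// core crate is referenced as `candle`, matching the diff above.
use candle::{Device, Tensor};
use candle_nn::ops::softmax_last_dim;

fn main() -> candle::Result<()> {
    // Device::Cpu keeps the sketch portable; on macOS with the "metal"
    // feature, Device::new_metal(0)? would route through metal_fwd instead.
    let device = Device::Cpu;
    // A contiguous 2x4 tensor; metal_fwd bails on non-contiguous layouts.
    let xs = Tensor::rand(0f32, 1f32, (2, 4), &device)?;
    // Each row of ys sums to 1 along the last dimension.
    let ys = softmax_last_dim(&xs)?;
    println!("{ys}");
    Ok(())
}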