Diffstat (limited to 'candle-nn/src/ops.rs')

 candle-nn/src/ops.rs | 40 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 40 insertions(+), 0 deletions(-)
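This commit adds a `metal_fwd` implementation to the `SoftmaxLastDim` custom op, so `softmax_last_dim` can run on the Metal backend. The new method selects a dtype-specific kernel (`softmax_float`, `softmax_half`, or `softmax_bfloat`), rejects non-contiguous input, and launches the kernel through `candle_metal_kernels::call_last_softmax`. A short usage sketch follows the diff.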
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index a0269e59..350bc663 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -201,6 +201,46 @@ impl candle::CustomOp1 for SoftmaxLastDim {
};
Ok((dst, layout.shape().clone()))
}
+
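+ // Metal implementation of the softmax-last-dim forward pass.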
+ #[cfg(feature = "metal")]
+ fn metal_fwd(
+ &self,
+ storage: &candle::MetalStorage,
+ layout: &Layout,
+ ) -> Result<(candle::MetalStorage, Shape)> {
+ use candle::{backend::BackendStorage, DType};
+ let device = storage.device();
+ let command_buffer = device.command_buffer();
+ let kernels = device.kernels();
+ let name = match storage.dtype() {
+ DType::F32 => "softmax_float",
+ DType::F16 => "softmax_half",
+ DType::BF16 => "softmax_bfloat",
+ dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"),
+ };
+
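+ // The kernel assumes the last dimension is contiguous and the storage starts at offset 0.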
+ let n = layout.stride().len();
+ if !(layout.stride()[n - 1] == 1 && layout.start_offset() == 0) {
+ candle::bail!("Non contiguous softmax-last-dim is not implemented");
+ }
+
+ let last_dim = layout.dims()[layout.shape().rank() - 1];
+ let elem_count = layout.shape().elem_count();
+ let mut output = device.new_buffer(elem_count, storage.dtype());
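+ // Launch the kernel: elem_count total elements, softmax taken over rows of length last_dim.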
+ candle_metal_kernels::call_last_softmax(
+ device.metal_device(),
+ &command_buffer,
+ &kernels,
+ name,
+ elem_count,
+ last_dim,
+ storage.buffer(),
+ &mut output,
+ )
+ .unwrap();
+ let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
+ Ok((newstorage, layout.shape().clone()))
+ }
}
pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
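For context, here is a minimal sketch of exercising the new path from user code. It assumes a Metal-capable machine and candle built with the `metal` feature; the device ordinal and tensor values are illustrative only.

use candle::{Device, Result, Tensor};
use candle_nn::ops::softmax_last_dim;

fn main() -> Result<()> {
    // Assumes a Metal device is available and the "metal" feature is enabled.
    let device = Device::new_metal(0)?;
    // A contiguous 2x4 f32 tensor; the new kernel normalizes along the last dim.
    let xs = Tensor::new(&[[1f32, 2., 3., 4.], [4., 3., 2., 1.]], &device)?;
    let ys = softmax_last_dim(&xs)?;
    // Each row of `ys` now sums to 1.
    println!("{ys}");
    Ok(())
}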