summaryrefslogtreecommitdiff
path: root/candle-nn
diff options
context:
space:
mode:
Diffstat (limited to 'candle-nn')
-rw-r--r--candle-nn/Cargo.toml2
-rw-r--r--candle-nn/src/ops.rs40
2 files changed, 42 insertions, 0 deletions
diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml
index d3f43c73..45298907 100644
--- a/candle-nn/Cargo.toml
+++ b/candle-nn/Cargo.toml
@@ -19,6 +19,7 @@ num-traits = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
+candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
[dev-dependencies]
anyhow = { workspace = true }
@@ -29,3 +30,4 @@ default = []
accelerate = ["dep:accelerate-src", "candle/accelerate"]
cuda = ["candle/cuda"]
mkl = ["dep:intel-mkl-src", "candle/mkl"]
+metal = ["candle/metal", "dep:candle-metal-kernels"]
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index a0269e59..350bc663 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -201,6 +201,46 @@ impl candle::CustomOp1 for SoftmaxLastDim {
};
Ok((dst, layout.shape().clone()))
}
+
+ #[cfg(feature = "metal")]
+ fn metal_fwd(
+ &self,
+ storage: &candle::MetalStorage,
+ layout: &Layout,
+ ) -> Result<(candle::MetalStorage, Shape)> {
+ use candle::{backend::BackendStorage, DType};
+ let device = storage.device();
+ let command_buffer = device.command_buffer();
+ let kernels = device.kernels();
+ let name = match storage.dtype() {
+ DType::F32 => "softmax_float",
+ DType::F16 => "softmax_half",
+ DType::BF16 => "softmax_bfloat",
+ dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"),
+ };
+
+ let n = layout.stride().len();
+ if !(layout.stride()[n - 1] == 1 && layout.start_offset() == 0) {
+ candle::bail!("Non contiguous softmax-last-dim is not implemented");
+ }
+
+ let last_dim = layout.dims()[layout.shape().rank() - 1];
+ let elem_count = layout.shape().elem_count();
+ let mut output = device.new_buffer(elem_count, storage.dtype());
+ candle_metal_kernels::call_last_softmax(
+ device.metal_device(),
+ &command_buffer,
+ &kernels,
+ name,
+ elem_count,
+ last_dim,
+ storage.buffer(),
+ &mut output,
+ )
+ .unwrap();
+ let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
+ Ok((newstorage, layout.shape().clone()))
+ }
}
pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {