summary refs log tree commit diff
path: root/candle-nn
diff options
context:
space:
mode:
author    Nicolas Patry <patry.nicolas@protonmail.com>    2023-12-20 15:37:31 +0100
committer GitHub <noreply@github.com>    2023-12-20 15:37:31 +0100
commit  9fc210fae8175a180dba8c28aa8e5975868a237c (patch)
tree    5c009b11e1c11f20c99d1546849a00a063e068c0 /candle-nn
parent  96f1a28e390fceeaa12b3272c8ac5dcccc8eb5fa (diff)
parent  9b5e4843a63180a2803b1e836b4ca90f14281d03 (diff)
download  candle-9fc210fae8175a180dba8c28aa8e5975868a237c.tar.gz
          candle-9fc210fae8175a180dba8c28aa8e5975868a237c.tar.bz2
          candle-9fc210fae8175a180dba8c28aa8e5975868a237c.zip
Merge pull request #1318 from huggingface/metal4
Starting to fix some tests.
Diffstat (limited to 'candle-nn')
-rw-r--r--  candle-nn/Cargo.toml  3
-rw-r--r--  candle-nn/src/ops.rs  41
2 files changed, 44 insertions, 0 deletions
diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml
index ffbe0ca1..e0daabef 100644
--- a/candle-nn/Cargo.toml
+++ b/candle-nn/Cargo.toml
@@ -19,6 +19,8 @@ num-traits = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
+metal = { workspace = true, optional = true }
+candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
[dev-dependencies]
anyhow = { workspace = true }
@@ -29,3 +31,4 @@ default = []
accelerate = ["dep:accelerate-src", "candle/accelerate"]
cuda = ["candle/cuda"]
mkl = ["dep:intel-mkl-src", "candle/mkl"]
+metal = ["candle/metal", "dep:candle-metal-kernels", "dep:metal"]
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index a0269e59..abe33350 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -201,6 +201,47 @@ impl candle::CustomOp1 for SoftmaxLastDim {
};
Ok((dst, layout.shape().clone()))
}
+
+ #[cfg(feature = "metal")]
+ fn metal_fwd(
+ &self,
+ storage: &candle::MetalStorage,
+ layout: &Layout,
+ ) -> Result<(candle::MetalStorage, Shape)> {
+ use candle::{backend::BackendStorage, DType};
+ let device = storage.device();
+ let command_buffer = device.command_buffer()?;
+ let kernels = device.kernels();
+ let name = match storage.dtype() {
+ DType::F32 => "softmax_f32",
+ DType::F16 => "softmax_f16",
+ DType::BF16 => "softmax_bf16",
+ dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"),
+ };
+
+ let n = layout.stride().len();
+ if !(layout.is_contiguous() && layout.stride()[n - 1] == 1) {
+ candle::bail!("Non contiguous softmax-last-dim is not implemented");
+ }
+
+ let last_dim = layout.dims()[layout.shape().rank() - 1];
+ let elem_count = layout.shape().elem_count();
+ let output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
+ candle_metal_kernels::call_last_softmax(
+ device.metal_device(),
+ &command_buffer,
+ kernels,
+ name,
+ elem_count,
+ last_dim,
+ storage.buffer(),
+ layout.start_offset() * storage.dtype().size_in_bytes(),
+ &output,
+ )
+ .unwrap();
+ let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
+ Ok((newstorage, layout.shape().clone()))
+ }
}
pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {