diff options
author | Nicolas Patry <patry.nicolas@protonmail.com> | 2023-12-20 15:37:31 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-20 15:37:31 +0100 |
commit | 9fc210fae8175a180dba8c28aa8e5975868a237c (patch) | |
tree | 5c009b11e1c11f20c99d1546849a00a063e068c0 /candle-nn | |
parent | 96f1a28e390fceeaa12b3272c8ac5dcccc8eb5fa (diff) | |
parent | 9b5e4843a63180a2803b1e836b4ca90f14281d03 (diff) | |
download | candle-9fc210fae8175a180dba8c28aa8e5975868a237c.tar.gz candle-9fc210fae8175a180dba8c28aa8e5975868a237c.tar.bz2 candle-9fc210fae8175a180dba8c28aa8e5975868a237c.zip |
Merge pull request #1318 from huggingface/metal4
Starting to fix some tests.
Diffstat (limited to 'candle-nn')
-rw-r--r-- | candle-nn/Cargo.toml | 3 | ||||
-rw-r--r-- | candle-nn/src/ops.rs | 41 |
2 files changed, 44 insertions, 0 deletions
```diff
diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml
index ffbe0ca1..e0daabef 100644
--- a/candle-nn/Cargo.toml
+++ b/candle-nn/Cargo.toml
@@ -19,6 +19,8 @@ num-traits = { workspace = true }
 rayon = { workspace = true }
 safetensors = { workspace = true }
 serde = { workspace = true }
+metal = { workspace = true, optional = true }
+candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
 
 [dev-dependencies]
 anyhow = { workspace = true }
@@ -29,3 +31,4 @@ default = []
 accelerate = ["dep:accelerate-src", "candle/accelerate"]
 cuda = ["candle/cuda"]
 mkl = ["dep:intel-mkl-src", "candle/mkl"]
+metal = ["candle/metal", "dep:candle-metal-kernels", "dep:metal"]
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index a0269e59..abe33350 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -201,6 +201,47 @@ impl candle::CustomOp1 for SoftmaxLastDim {
         };
         Ok((dst, layout.shape().clone()))
     }
+
+    #[cfg(feature = "metal")]
+    fn metal_fwd(
+        &self,
+        storage: &candle::MetalStorage,
+        layout: &Layout,
+    ) -> Result<(candle::MetalStorage, Shape)> {
+        use candle::{backend::BackendStorage, DType};
+        let device = storage.device();
+        let command_buffer = device.command_buffer()?;
+        let kernels = device.kernels();
+        let name = match storage.dtype() {
+            DType::F32 => "softmax_f32",
+            DType::F16 => "softmax_f16",
+            DType::BF16 => "softmax_bf16",
+            dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"),
+        };
+
+        let n = layout.stride().len();
+        if !(layout.is_contiguous() && layout.stride()[n - 1] == 1) {
+            candle::bail!("Non contiguous softmax-last-dim is not implemented");
+        }
+
+        let last_dim = layout.dims()[layout.shape().rank() - 1];
+        let elem_count = layout.shape().elem_count();
+        let output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
+        candle_metal_kernels::call_last_softmax(
+            device.metal_device(),
+            &command_buffer,
+            kernels,
+            name,
+            elem_count,
+            last_dim,
+            storage.buffer(),
+            layout.start_offset() * storage.dtype().size_in_bytes(),
+            &output,
+        )
+        .unwrap();
+        let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
+        Ok((newstorage, layout.shape().clone()))
+    }
 }
 
 pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
```