diff options
author | nicolas <nicolas@nicolass-MacBook-Pro.local> | 2023-12-12 17:41:56 +0100 |
---|---|---|
committer | nicolas <nicolas@nicolass-MacBook-Pro.local> | 2023-12-12 17:41:56 +0100 |
commit | 87dc559817db11f8d8c409cda959528e57e1db31 (patch) | |
tree | 3f7ec04a0facab3378158ae3ba84416d56fd37a7 /candle-nn | |
parent | da0af3cb3e58d38476a20f4465744093a3b75dd4 (diff) | |
download | candle-87dc559817db11f8d8c409cda959528e57e1db31.tar.gz candle-87dc559817db11f8d8c409cda959528e57e1db31.tar.bz2 candle-87dc559817db11f8d8c409cda959528e57e1db31.zip |
Lots of updates including some stack of command buffers.
Diffstat (limited to 'candle-nn')
-rw-r--r-- | candle-nn/Cargo.toml | 3 | ||||
-rw-r--r-- | candle-nn/src/ops.rs | 4 |
2 files changed, 5 insertions, 2 deletions
diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml index 45298907..03622752 100644 --- a/candle-nn/Cargo.toml +++ b/candle-nn/Cargo.toml @@ -19,6 +19,7 @@ num-traits = { workspace = true } rayon = { workspace = true } safetensors = { workspace = true } serde = { workspace = true } +metal = { workspace = true, optional = true } candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true } [dev-dependencies] @@ -30,4 +31,4 @@ default = [] accelerate = ["dep:accelerate-src", "candle/accelerate"] cuda = ["candle/cuda"] mkl = ["dep:intel-mkl-src", "candle/mkl"] -metal = ["candle/metal", "dep:candle-metal-kernels"] +metal = ["candle/metal", "dep:candle-metal-kernels", "dep:metal"] diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 350bc663..14dd10de 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -226,7 +226,7 @@ impl candle::CustomOp1 for SoftmaxLastDim { let last_dim = layout.dims()[layout.shape().rank() - 1]; let elem_count = layout.shape().elem_count(); - let mut output = device.new_buffer(elem_count, storage.dtype()); + let mut output = device.new_buffer(elem_count, storage.dtype(), "softmax"); candle_metal_kernels::call_last_softmax( device.metal_device(), &command_buffer, @@ -238,6 +238,8 @@ impl candle::CustomOp1 for SoftmaxLastDim { &mut output, ) .unwrap(); + command_buffer.commit(); + output.did_modify_range(metal::NSRange::new(0, output.length())); let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype()); Ok((newstorage, layout.shape().clone())) } |