summaryrefslogtreecommitdiff
path: root/candle-nn
diff options
context:
space:
mode:
authornicolas <nicolas@nicolass-MacBook-Pro.local>2023-12-12 17:41:56 +0100
committernicolas <nicolas@nicolass-MacBook-Pro.local>2023-12-12 17:41:56 +0100
commit87dc559817db11f8d8c409cda959528e57e1db31 (patch)
tree3f7ec04a0facab3378158ae3ba84416d56fd37a7 /candle-nn
parentda0af3cb3e58d38476a20f4465744093a3b75dd4 (diff)
downloadcandle-87dc559817db11f8d8c409cda959528e57e1db31.tar.gz
candle-87dc559817db11f8d8c409cda959528e57e1db31.tar.bz2
candle-87dc559817db11f8d8c409cda959528e57e1db31.zip
Lots of updates including some stack of command buffers.
Diffstat (limited to 'candle-nn')
-rw-r--r--candle-nn/Cargo.toml3
-rw-r--r--candle-nn/src/ops.rs4
2 files changed, 5 insertions, 2 deletions
diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml
index 45298907..03622752 100644
--- a/candle-nn/Cargo.toml
+++ b/candle-nn/Cargo.toml
@@ -19,6 +19,7 @@ num-traits = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
+metal = { workspace = true, optional = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
[dev-dependencies]
@@ -30,4 +31,4 @@ default = []
accelerate = ["dep:accelerate-src", "candle/accelerate"]
cuda = ["candle/cuda"]
mkl = ["dep:intel-mkl-src", "candle/mkl"]
-metal = ["candle/metal", "dep:candle-metal-kernels"]
+metal = ["candle/metal", "dep:candle-metal-kernels", "dep:metal"]
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index 350bc663..14dd10de 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -226,7 +226,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
let last_dim = layout.dims()[layout.shape().rank() - 1];
let elem_count = layout.shape().elem_count();
- let mut output = device.new_buffer(elem_count, storage.dtype());
+ let mut output = device.new_buffer(elem_count, storage.dtype(), "softmax");
candle_metal_kernels::call_last_softmax(
device.metal_device(),
&command_buffer,
@@ -238,6 +238,8 @@ impl candle::CustomOp1 for SoftmaxLastDim {
&mut output,
)
.unwrap();
+ command_buffer.commit();
+ output.did_modify_range(metal::NSRange::new(0, output.length()));
let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
Ok((newstorage, layout.shape().clone()))
}