summaryrefslogtreecommitdiff
path: root/candle-nn/src
diff options
context:
space:
mode:
authorNicolas Patry <patry.nicolas@protonmail.com>2023-11-11 01:02:15 +0100
committerNicolas Patry <patry.nicolas@protonmail.com>2023-11-30 11:30:31 +0100
commit4349ff1fc29a1a25b2ccdf56fbf68a98f5364c0a (patch)
tree78a6b3533670a33f7bc2f75851fac24307a46fed /candle-nn/src
parent7c3cfd1086ecdc08a0b350f30f1fbedf2f00c269 (diff)
downloadcandle-4349ff1fc29a1a25b2ccdf56fbf68a98f5364c0a.tar.gz
candle-4349ff1fc29a1a25b2ccdf56fbf68a98f5364c0a.tar.bz2
candle-4349ff1fc29a1a25b2ccdf56fbf68a98f5364c0a.zip
Starting to fix some tests.
Few fixes. Going back on remote metal-rs. Reusing a single buffer (for now) to speed things up. Adding some half kernels. All tests are panicking instead of random failure. Putting back f16 index select. Add erf. Working version for llama2-c. Fixes + cache compute_pipeline_state. BF16 metal fix. Remove some prints. new_owned -> new()..to_owned(). Better batched matmul. Metal operational. Reuse buffers on our own reference counts. Tmp gemm. Revert "Tmp gemm." This reverts commit c65f68e98814b65daa596696bda076a73303dd82. Interleave committing. Speeding up copies using blit. Fmt. Fmt. Remove the assert! Fmt all. Fixes after big rebase. Add softmax for half and bfloat + tests Fixing Llama example + accumulate softmax in float.
Diffstat (limited to 'candle-nn/src')
-rw-r--r--candle-nn/src/ops.rs40
1 file changed, 40 insertions, 0 deletions
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index a0269e59..350bc663 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -201,6 +201,46 @@ impl candle::CustomOp1 for SoftmaxLastDim {
};
Ok((dst, layout.shape().clone()))
}
+
+    /// Metal forward pass for softmax over the last dimension.
+    ///
+    /// Dispatches a dtype-specific Metal kernel on the device's command
+    /// buffer and returns a new storage holding the result together with
+    /// the (unchanged) output shape.
+    ///
+    /// # Errors
+    /// - Unsupported dtype (only F32/F16/BF16 have kernels).
+    /// - Non-contiguous input: the kernel requires a unit stride on the
+    ///   last dimension and a zero start offset.
+    /// - Any failure reported by the kernel dispatch itself.
+    #[cfg(feature = "metal")]
+    fn metal_fwd(
+        &self,
+        storage: &candle::MetalStorage,
+        layout: &Layout,
+    ) -> Result<(candle::MetalStorage, Shape)> {
+        use candle::{backend::BackendStorage, DType};
+        let device = storage.device();
+        let command_buffer = device.command_buffer();
+        let kernels = device.kernels();
+        // Select the kernel variant matching the input dtype.
+        let name = match storage.dtype() {
+            DType::F32 => "softmax_float",
+            DType::F16 => "softmax_half",
+            DType::BF16 => "softmax_bfloat",
+            dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"),
+        };
+
+        // The kernel assumes the last dim is contiguous and the data starts
+        // at offset 0; bail out early otherwise rather than compute garbage.
+        let n = layout.stride().len();
+        if !(layout.stride()[n - 1] == 1 && layout.start_offset() == 0) {
+            candle::bail!("Non contiguous softmax-last-dim is not implemented");
+        }
+
+        let last_dim = layout.dims()[layout.shape().rank() - 1];
+        let elem_count = layout.shape().elem_count();
+        let mut output = device.new_buffer(elem_count, storage.dtype());
+        // Propagate kernel-dispatch failures as a candle error instead of
+        // panicking: this is a library code path and the caller already
+        // receives a Result.
+        candle_metal_kernels::call_last_softmax(
+            device.metal_device(),
+            &command_buffer,
+            &kernels,
+            name,
+            elem_count,
+            last_dim,
+            storage.buffer(),
+            &mut output,
+        )
+        .map_err(candle::Error::wrap)?;
+        let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
+        Ok((newstorage, layout.shape().clone()))
+    }
}
pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {