diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-03-21 09:48:56 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-21 09:48:56 +0100 |
commit | 0fddec762e3c17c56be5b6356478b9565dd628bb (patch) | |
tree | 49a1e09d3b397f97187f60739e80f528ae4b083a /candle-nn | |
parent | 74b7f59261c72010e329fd8eb467c088673671f5 (diff) | |
download | candle-0fddec762e3c17c56be5b6356478b9565dd628bb.tar.gz candle-0fddec762e3c17c56be5b6356478b9565dd628bb.tar.bz2 candle-0fddec762e3c17c56be5b6356478b9565dd628bb.zip |
RmsNorm kernel for metal. (#1895)
* RmsNorm kernel for metal.
* Wrapper for the metal kernel.
* Get the ops to actually work.
* Fix, get the tests to pass.
Diffstat (limited to 'candle-nn')
-rw-r--r-- | candle-nn/src/ops.rs | 47 |
1 file changed, 46 insertions, 1 deletion
```diff
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index d725bdc2..1dac8c3b 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -236,7 +236,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
             layout.start_offset() * storage.dtype().size_in_bytes(),
             &output,
         )
-        .unwrap();
+        .map_err(candle::Error::wrap)?;
         let newstorage =
             candle::MetalStorage::new(output, device.clone(), elem_count, storage.dtype());
         Ok((newstorage, layout.shape().clone()))
@@ -383,6 +383,51 @@ impl candle::CustomOp2 for RmsNorm {
         };
         Ok((dst, l1.shape().clone()))
     }
+
+    #[cfg(feature = "metal")]
+    fn metal_fwd(
+        &self,
+        s1: &candle::MetalStorage,
+        l1: &Layout,
+        s2: &candle::MetalStorage,
+        l2: &Layout,
+    ) -> Result<(candle::MetalStorage, Shape)> {
+        use candle::backend::BackendStorage;
+        let device = s1.device();
+        let command_buffer = device.command_buffer()?;
+        let kernels = device.kernels();
+        let name = match (s1.dtype(), s2.dtype()) {
+            (DType::F32, DType::F32) => "rmsnorm_f32",
+            (DType::F16, DType::F16) => "rmsnorm_f16",
+            (DType::BF16, DType::BF16) => "rmsnorm_bf16",
+            (dt1, dt2) => candle::bail!("rmsnorm is not implemented for {dt1:?} {dt2:?}"),
+        };
+
+        if !(l1.is_contiguous() && l2.is_contiguous()) {
+            candle::bail!("Non contiguous rmsnorm is not implemented");
+        }
+
+        let last_dim = l1.dims()[l1.shape().rank() - 1];
+        let elem_count = l1.shape().elem_count();
+        let output = device.new_buffer(elem_count, s1.dtype(), "rmsnorm")?;
+        candle_metal_kernels::call_rms_norm(
+            device.metal_device(),
+            &command_buffer,
+            kernels,
+            name,
+            elem_count,
+            last_dim,
+            self.eps,
+            s1.buffer(),
+            l1.start_offset() * s1.dtype().size_in_bytes(),
+            s2.buffer(),
+            l2.start_offset() * s2.dtype().size_in_bytes(),
+            &output,
+        )
+        .map_err(candle::Error::wrap)?;
+        let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
+        Ok((newstorage, l1.shape().clone()))
+    }
 }
 
 pub fn rms_norm_slow(x: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
```