summary | refs | log | tree | commit | diff
path: root/candle-nn
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2024-03-21 09:48:56 +0100
committer: GitHub <noreply@github.com> 2024-03-21 09:48:56 +0100
commit: 0fddec762e3c17c56be5b6356478b9565dd628bb (patch)
tree: 49a1e09d3b397f97187f60739e80f528ae4b083a /candle-nn
parent: 74b7f59261c72010e329fd8eb467c088673671f5 (diff)
download: candle-0fddec762e3c17c56be5b6356478b9565dd628bb.tar.gz
candle-0fddec762e3c17c56be5b6356478b9565dd628bb.tar.bz2
candle-0fddec762e3c17c56be5b6356478b9565dd628bb.zip
RmsNorm kernel for metal. (#1895)
* RmsNorm kernel for metal.
* Wrapper for the metal kernel.
* Get the ops to actually work.
* Fix, get the tests to pass.
Diffstat (limited to 'candle-nn')
-rw-r--r-- candle-nn/src/ops.rs 47
1 files changed, 46 insertions, 1 deletions
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index d725bdc2..1dac8c3b 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -236,7 +236,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
layout.start_offset() * storage.dtype().size_in_bytes(),
&output,
)
- .unwrap();
+ .map_err(candle::Error::wrap)?;
let newstorage =
candle::MetalStorage::new(output, device.clone(), elem_count, storage.dtype());
Ok((newstorage, layout.shape().clone()))
@@ -383,6 +383,51 @@ impl candle::CustomOp2 for RmsNorm {
};
Ok((dst, l1.shape().clone()))
}
+
+ #[cfg(feature = "metal")] // Metal forward pass for RmsNorm; only compiled when the `metal` feature is enabled.
+ fn metal_fwd(
+ &self,
+ s1: &candle::MetalStorage, // first operand; its device and dtype are reused for the output
+ l1: &Layout, // layout of s1; supplies the output shape, element count and last-dimension size
+ s2: &candle::MetalStorage, // second operand (presumably the per-feature scale/alpha — confirm against caller)
+ l2: &Layout, // layout of s2; only its contiguity and start offset are read here
+ ) -> Result<(candle::MetalStorage, Shape)> { // returns the freshly allocated output storage and a clone of l1's shape
+ use candle::backend::BackendStorage;
+ let device = s1.device(); // all work is issued on s1's device
+ let command_buffer = device.command_buffer()?;
+ let kernels = device.kernels();
+ let name = match (s1.dtype(), s2.dtype()) { // pick the kernel by dtype; both operands must share the same dtype
+ (DType::F32, DType::F32) => "rmsnorm_f32",
+ (DType::F16, DType::F16) => "rmsnorm_f16",
+ (DType::BF16, DType::BF16) => "rmsnorm_bf16",
+ (dt1, dt2) => candle::bail!("rmsnorm is not implemented for {dt1:?} {dt2:?}"), // any other or mixed dtype pair is an error
+ };
+
+ if !(l1.is_contiguous() && l2.is_contiguous()) { // both inputs must be contiguous; otherwise bail instead of copying
+ candle::bail!("Non contiguous rmsnorm is not implemented");
+ }
+
+ let last_dim = l1.dims()[l1.shape().rank() - 1]; // size of l1's last dimension; NOTE(review): `rank() - 1` underflows for a rank-0 layout — presumably callers guarantee rank >= 1, confirm
+ let elem_count = l1.shape().elem_count();
+ let output = device.new_buffer(elem_count, s1.dtype(), "rmsnorm")?; // output buffer: elem_count elements of s1's dtype
+ candle_metal_kernels::call_rms_norm(
+ device.metal_device(),
+ &command_buffer,
+ kernels,
+ name,
+ elem_count,
+ last_dim,
+ self.eps, // epsilon stored on the op, forwarded to the kernel
+ s1.buffer(),
+ l1.start_offset() * s1.dtype().size_in_bytes(), // start offset converted from elements to bytes
+ s2.buffer(),
+ l2.start_offset() * s2.dtype().size_in_bytes(), // start offset converted from elements to bytes
+ &output,
+ )
+ .map_err(candle::Error::wrap)?; // wrap the kernel-call error into candle::Error and propagate it
+ let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, s1.dtype()); // wrap the raw buffer as storage with s1's dtype
+ Ok((newstorage, l1.shape().clone()))
+ }
}
pub fn rms_norm_slow(x: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {