From 1df2bddccfbb4ab511a8cc3a87476d1fa72416bc Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Fri, 24 May 2024 15:58:01 +0200
Subject: Add the layernorm specialized op. (#2212)

* Add the layernorm cuda kernels.

* Dedicated layer norm op.

* Add the slower variant.

* Plug the cuda implementation.

* Add the metal variant.

* Add a dedicated test.

* Bugfix.
---
 candle-core/src/cuda_backend/utils.rs | 38 +++++++++++++++++++++++++++++++++++
 1 file changed, 38 insertions(+)

(limited to 'candle-core/src/cuda_backend/utils.rs')

diff --git a/candle-core/src/cuda_backend/utils.rs b/candle-core/src/cuda_backend/utils.rs
index 8dd5be77..c1210727 100644
--- a/candle-core/src/cuda_backend/utils.rs
+++ b/candle-core/src/cuda_backend/utils.rs
@@ -54,6 +54,44 @@ pub trait Map2 {
     }
 }
 
+pub trait Map3 {
+    #[allow(clippy::too_many_arguments)]
+    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+        &self,
+        src1: &CudaSlice<T>,
+        layout1: &Layout,
+        src2: &CudaSlice<T>,
+        layout2: &Layout,
+        src3: &CudaSlice<T>,
+        layout3: &Layout,
+        dev: &CudaDevice,
+    ) -> Result<CudaSlice<T>>;
+
+    #[allow(clippy::too_many_arguments)]
+    fn map(
+        &self,
+        s1: &S,
+        l1: &Layout,
+        s2: &S,
+        l2: &Layout,
+        s3: &S,
+        l3: &Layout,
+        d: &CudaDevice,
+    ) -> Result<S> {
+        let out = match (s1, s2, s3) {
+            (S::U8(s1), S::U8(s2), S::U8(s3)) => S::U8(self.f(s1, l1, s2, l2, s3, l3, d)?),
+            (S::U32(s1), S::U32(s2), S::U32(s3)) => S::U32(self.f(s1, l1, s2, l2, s3, l3, d)?),
+            (S::I64(s1), S::I64(s2), S::I64(s3)) => S::I64(self.f(s1, l1, s2, l2, s3, l3, d)?),
+            (S::BF16(s1), S::BF16(s2), S::BF16(s3)) => S::BF16(self.f(s1, l1, s2, l2, s3, l3, d)?),
+            (S::F16(s1), S::F16(s2), S::F16(s3)) => S::F16(self.f(s1, l1, s2, l2, s3, l3, d)?),
+            (S::F32(s1), S::F32(s2), S::F32(s3)) => S::F32(self.f(s1, l1, s2, l2, s3, l3, d)?),
+            (S::F64(s1), S::F64(s2), S::F64(s3)) => S::F64(self.f(s1, l1, s2, l2, s3, l3, d)?),
+            _ => Err(CudaError::InternalError("dtype mismatch in ternary op"))?,
+        };
+        Ok(out)
+    }
+}
+
 pub trait Map2InPlace {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
--
cgit v1.2.3
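
For readers unfamiliar with the dispatch pattern in the new trait, the following self-contained Rust sketch mirrors how `Map3::map` routes three same-dtype storages to the generic worker `f`, and how any mixed-dtype triple falls through to the error arm. It is illustrative only: `S` here is a two-variant stand-in for the CUDA storage enum, `DTypeMismatch` stands in for `CudaError::InternalError`, and `Mul3` is a hypothetical ternary op (the patch uses this machinery for the layer-norm kernel); none of these names come from the patch itself.

// Self-contained sketch of the Map3 dispatch pattern (names are hypothetical,
// not from the patch).
#[derive(Debug)]
enum S {
    F32(Vec<f32>),
    F64(Vec<f64>),
}

#[derive(Debug)]
struct DTypeMismatch;

trait Map3 {
    // Generic worker, monomorphized per dtype; the real trait also threads
    // the layouts and the CUDA device through to the kernel launch.
    fn f<T: Copy + std::ops::Mul<Output = T>>(&self, s1: &[T], s2: &[T], s3: &[T]) -> Vec<T>;

    // Dispatcher: only triples with matching dtypes reach `f`; everything
    // else falls through to the error arm, as in the patch.
    fn map(&self, s1: &S, s2: &S, s3: &S) -> Result<S, DTypeMismatch> {
        let out = match (s1, s2, s3) {
            (S::F32(a), S::F32(b), S::F32(c)) => S::F32(self.f(a, b, c)),
            (S::F64(a), S::F64(b), S::F64(c)) => S::F64(self.f(a, b, c)),
            _ => return Err(DTypeMismatch),
        };
        Ok(out)
    }
}

struct Mul3;

impl Map3 for Mul3 {
    fn f<T: Copy + std::ops::Mul<Output = T>>(&self, s1: &[T], s2: &[T], s3: &[T]) -> Vec<T> {
        // Element-wise product of the three buffers; a real op would launch
        // a CUDA kernel here instead of computing on the host.
        s1.iter()
            .zip(s2)
            .zip(s3)
            .map(|((&a, &b), &c)| a * b * c)
            .collect()
    }
}

fn main() {
    let x = S::F32(vec![1.0, 2.0]);
    let w = S::F32(vec![3.0, 4.0]);
    let b = S::F32(vec![5.0, 6.0]);
    // Matching dtypes: dispatch picks the f32 instantiation of `f`.
    println!("{:?}", Mul3.map(&x, &w, &b));
    // Mixed dtypes: the catch-all arm reports the mismatch.
    println!("{:?}", Mul3.map(&x, &w, &S::F64(vec![0.5])));
}

The point of the enum match is that dtype-erased storage can reach a monomorphized worker: each arm fixes `T` for that call, so implementors such as the layer-norm op only write the generic `f` once and get every supported dtype for free.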