diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-05-24 15:58:01 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-24 15:58:01 +0200 |
commit | 1df2bddccfbb4ab511a8cc3a87476d1fa72416bc (patch) | |
tree | 3633bc51e3bac3d542d9dfe06d509db20f5374e9 /candle-core/src/cuda_backend | |
parent | 6f0b807ffd553fed27325a2a118b0e30bb6d9cbd (diff) | |
download | candle-1df2bddccfbb4ab511a8cc3a87476d1fa72416bc.tar.gz candle-1df2bddccfbb4ab511a8cc3a87476d1fa72416bc.tar.bz2 candle-1df2bddccfbb4ab511a8cc3a87476d1fa72416bc.zip |
Add the layernorm specialized op. (#2212)
* Add the layernorm cuda kernels.
* Dedicated layer norm op.
* Add the slower variant.
* Plug the cuda implementation.
* Add the metal variant.
* Add a dedicated test.
* Bugfix.
Diffstat (limited to 'candle-core/src/cuda_backend')
-rw-r--r-- | candle-core/src/cuda_backend/mod.rs | 2 | ||||
-rw-r--r-- | candle-core/src/cuda_backend/utils.rs | 38 |
2 files changed, 39 insertions, 1 deletions
diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index 88f325f4..9e72dcc8 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -16,7 +16,7 @@ mod error; mod utils; pub use device::{CudaDevice, DeviceId}; pub use error::{CudaError, WrapErr}; -pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, S}; +pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, Map3, S}; pub enum SlicePtrOrNull<T> { Ptr(CudaSlice<T>), diff --git a/candle-core/src/cuda_backend/utils.rs b/candle-core/src/cuda_backend/utils.rs index 8dd5be77..c1210727 100644 --- a/candle-core/src/cuda_backend/utils.rs +++ b/candle-core/src/cuda_backend/utils.rs @@ -54,6 +54,44 @@ pub trait Map2 { } } +pub trait Map3 { + #[allow(clippy::too_many_arguments)] + fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( + &self, + src1: &CudaSlice<T>, + layout1: &Layout, + src2: &CudaSlice<T>, + layout2: &Layout, + src3: &CudaSlice<T>, + layout3: &Layout, + dev: &CudaDevice, + ) -> Result<CudaSlice<T>>; + + #[allow(clippy::too_many_arguments)] + fn map( + &self, + s1: &S, + l1: &Layout, + s2: &S, + l2: &Layout, + s3: &S, + l3: &Layout, + d: &CudaDevice, + ) -> Result<S> { + let out = match (s1, s2, s3) { + (S::U8(s1), S::U8(s2), S::U8(s3)) => S::U8(self.f(s1, l1, s2, l2, s3, l3, d)?), + (S::U32(s1), S::U32(s2), S::U32(s3)) => S::U32(self.f(s1, l1, s2, l2, s3, l3, d)?), + (S::I64(s1), S::I64(s2), S::I64(s3)) => S::I64(self.f(s1, l1, s2, l2, s3, l3, d)?), + (S::BF16(s1), S::BF16(s2), S::BF16(s3)) => S::BF16(self.f(s1, l1, s2, l2, s3, l3, d)?), + (S::F16(s1), S::F16(s2), S::F16(s3)) => S::F16(self.f(s1, l1, s2, l2, s3, l3, d)?), + (S::F32(s1), S::F32(s2), S::F32(s3)) => S::F32(self.f(s1, l1, s2, l2, s3, l3, d)?), + (S::F64(s1), S::F64(s2), S::F64(s3)) => S::F64(self.f(s1, l1, s2, l2, s3, l3, d)?), + _ => Err(CudaError::InternalError("dtype mismatch in ternary op"))?, + }; + Ok(out) + } +} + pub trait Map2InPlace { fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( &self, |