summary | refs | log | tree | commit | diff
path: root/candle-core/src/cuda_backend
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2024-05-24 15:58:01 +0200
committer: GitHub <noreply@github.com> 2024-05-24 15:58:01 +0200
commit: 1df2bddccfbb4ab511a8cc3a87476d1fa72416bc (patch)
tree: 3633bc51e3bac3d542d9dfe06d509db20f5374e9 /candle-core/src/cuda_backend
parent: 6f0b807ffd553fed27325a2a118b0e30bb6d9cbd (diff)
download: candle-1df2bddccfbb4ab511a8cc3a87476d1fa72416bc.tar.gz
candle-1df2bddccfbb4ab511a8cc3a87476d1fa72416bc.tar.bz2
candle-1df2bddccfbb4ab511a8cc3a87476d1fa72416bc.zip
Add the layernorm specialized op. (#2212)
* Add the layernorm cuda kernels.
* Dedicated layer norm op.
* Add the slower variant.
* Plug the cuda implementation.
* Add the metal variant.
* Add a dedicated test.
* Bugfix.
Diffstat (limited to 'candle-core/src/cuda_backend')
-rw-r--r-- candle-core/src/cuda_backend/mod.rs 2
-rw-r--r-- candle-core/src/cuda_backend/utils.rs 38
2 files changed, 39 insertions, 1 deletions
diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs
index 88f325f4..9e72dcc8 100644
--- a/candle-core/src/cuda_backend/mod.rs
+++ b/candle-core/src/cuda_backend/mod.rs
@@ -16,7 +16,7 @@ mod error;
mod utils;
pub use device::{CudaDevice, DeviceId};
pub use error::{CudaError, WrapErr};
-pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, S};
+pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, Map3, S};
pub enum SlicePtrOrNull<T> {
Ptr(CudaSlice<T>),
diff --git a/candle-core/src/cuda_backend/utils.rs b/candle-core/src/cuda_backend/utils.rs
index 8dd5be77..c1210727 100644
--- a/candle-core/src/cuda_backend/utils.rs
+++ b/candle-core/src/cuda_backend/utils.rs
@@ -54,6 +54,44 @@ pub trait Map2 {
}
}
+pub trait Map3 { // Three-input mapping trait: implementors write the typed operation `f`; `map` dispatches on the storage variant.
+ #[allow(clippy::too_many_arguments)] // silences clippy's argument-count lint for this wide signature
+ fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( // required method: the operation itself, generic over the element type `T`
+ &self,
+ src1: &CudaSlice<T>, // first input buffer
+ layout1: &Layout, // layout paired with `src1`
+ src2: &CudaSlice<T>, // second input buffer, same element type as `src1`
+ layout2: &Layout, // layout paired with `src2`
+ src3: &CudaSlice<T>, // third input buffer, same element type as `src1`
+ layout3: &Layout, // layout paired with `src3`
+ dev: &CudaDevice, // device handle forwarded by `map`
+ ) -> Result<CudaSlice<T>>; // on success, a slice with the same element type as the inputs
+
+ #[allow(clippy::too_many_arguments)] // same lint exemption as for `f`
+ fn map( // provided method: unwraps three `S` values of one variant, calls `f`, and re-wraps the result
+ &self,
+ s1: &S, // first storage; its variant must match `s2` and `s3`
+ l1: &Layout, // layout forwarded to `f` as `layout1`
+ s2: &S, // second storage
+ l2: &Layout, // layout forwarded to `f` as `layout2`
+ s3: &S, // third storage
+ l3: &Layout, // layout forwarded to `f` as `layout3`
+ d: &CudaDevice, // device forwarded to `f` as `dev`
+ ) -> Result<S> { // the output is wrapped in the same `S` variant as the inputs
+ let out = match (s1, s2, s3) { // only tuples whose three storages share a variant reach `f`
+ (S::U8(s1), S::U8(s2), S::U8(s3)) => S::U8(self.f(s1, l1, s2, l2, s3, l3, d)?),
+ (S::U32(s1), S::U32(s2), S::U32(s3)) => S::U32(self.f(s1, l1, s2, l2, s3, l3, d)?),
+ (S::I64(s1), S::I64(s2), S::I64(s3)) => S::I64(self.f(s1, l1, s2, l2, s3, l3, d)?),
+ (S::BF16(s1), S::BF16(s2), S::BF16(s3)) => S::BF16(self.f(s1, l1, s2, l2, s3, l3, d)?),
+ (S::F16(s1), S::F16(s2), S::F16(s3)) => S::F16(self.f(s1, l1, s2, l2, s3, l3, d)?),
+ (S::F32(s1), S::F32(s2), S::F32(s3)) => S::F32(self.f(s1, l1, s2, l2, s3, l3, d)?),
+ (S::F64(s1), S::F64(s2), S::F64(s3)) => S::F64(self.f(s1, l1, s2, l2, s3, l3, d)?),
+ _ => Err(CudaError::InternalError("dtype mismatch in ternary op"))?, // any other combination is rejected; `?` returns early with the error converted to this function's error type
+ };
+ Ok(out)
+ }
+}
+
pub trait Map2InPlace {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,