diff options
Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/cuda_backend/mod.rs | 2 | ||||
-rw-r--r-- | candle-core/src/cuda_backend/utils.rs | 38 |
2 files changed, 39 insertions, 1 deletions
diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index 88f325f4..9e72dcc8 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -16,7 +16,7 @@ mod error; mod utils; pub use device::{CudaDevice, DeviceId}; pub use error::{CudaError, WrapErr}; -pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, S}; +pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, Map3, S}; pub enum SlicePtrOrNull<T> { Ptr(CudaSlice<T>), diff --git a/candle-core/src/cuda_backend/utils.rs b/candle-core/src/cuda_backend/utils.rs index 8dd5be77..c1210727 100644 --- a/candle-core/src/cuda_backend/utils.rs +++ b/candle-core/src/cuda_backend/utils.rs @@ -54,6 +54,44 @@ pub trait Map2 { } } +pub trait Map3 { + #[allow(clippy::too_many_arguments)] + fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( + &self, + src1: &CudaSlice<T>, + layout1: &Layout, + src2: &CudaSlice<T>, + layout2: &Layout, + src3: &CudaSlice<T>, + layout3: &Layout, + dev: &CudaDevice, + ) -> Result<CudaSlice<T>>; + + #[allow(clippy::too_many_arguments)] + fn map( + &self, + s1: &S, + l1: &Layout, + s2: &S, + l2: &Layout, + s3: &S, + l3: &Layout, + d: &CudaDevice, + ) -> Result<S> { + let out = match (s1, s2, s3) { + (S::U8(s1), S::U8(s2), S::U8(s3)) => S::U8(self.f(s1, l1, s2, l2, s3, l3, d)?), + (S::U32(s1), S::U32(s2), S::U32(s3)) => S::U32(self.f(s1, l1, s2, l2, s3, l3, d)?), + (S::I64(s1), S::I64(s2), S::I64(s3)) => S::I64(self.f(s1, l1, s2, l2, s3, l3, d)?), + (S::BF16(s1), S::BF16(s2), S::BF16(s3)) => S::BF16(self.f(s1, l1, s2, l2, s3, l3, d)?), + (S::F16(s1), S::F16(s2), S::F16(s3)) => S::F16(self.f(s1, l1, s2, l2, s3, l3, d)?), + (S::F32(s1), S::F32(s2), S::F32(s3)) => S::F32(self.f(s1, l1, s2, l2, s3, l3, d)?), + (S::F64(s1), S::F64(s2), S::F64(s3)) => S::F64(self.f(s1, l1, s2, l2, s3, l3, d)?), + _ => Err(CudaError::InternalError("dtype mismatch in ternary op"))?, + }; + Ok(out) + } +} + pub trait Map2InPlace { fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( &self, |