summary | refs | log | tree | commit | diff
path: root/candle-core/src/cuda_backend/utils.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/cuda_backend/utils.rs')
-rw-r--r-- candle-core/src/cuda_backend/utils.rs | 38
1 files changed, 38 insertions, 0 deletions
diff --git a/candle-core/src/cuda_backend/utils.rs b/candle-core/src/cuda_backend/utils.rs
index 8dd5be77..c1210727 100644
--- a/candle-core/src/cuda_backend/utils.rs
+++ b/candle-core/src/cuda_backend/utils.rs
@@ -54,6 +54,44 @@ pub trait Map2 {
}
}
+/// Three-input op on CUDA slices: `f` is the typed body, `map` dispatches to it by dtype.
+pub trait Map3 {
+    #[allow(clippy::too_many_arguments)]
+    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+        &self,
+        src1: &CudaSlice<T>,
+        layout1: &Layout,
+        src2: &CudaSlice<T>,
+        layout2: &Layout,
+        src3: &CudaSlice<T>,
+        layout3: &Layout,
+        dev: &CudaDevice,
+    ) -> Result<CudaSlice<T>>;
+    /// Calls `f` when all three storages share one dtype variant; otherwise returns an error.
+    #[allow(clippy::too_many_arguments)]
+    fn map(
+        &self,
+        s1: &S,
+        l1: &Layout,
+        s2: &S,
+        l2: &Layout,
+        s3: &S,
+        l3: &Layout,
+        d: &CudaDevice,
+    ) -> Result<S> {
+        Ok(match (s1, s2, s3) {
+            (S::U8(a), S::U8(b), S::U8(c)) => S::U8(self.f(a, l1, b, l2, c, l3, d)?),
+            (S::U32(a), S::U32(b), S::U32(c)) => S::U32(self.f(a, l1, b, l2, c, l3, d)?),
+            (S::I64(a), S::I64(b), S::I64(c)) => S::I64(self.f(a, l1, b, l2, c, l3, d)?),
+            (S::BF16(a), S::BF16(b), S::BF16(c)) => S::BF16(self.f(a, l1, b, l2, c, l3, d)?),
+            (S::F16(a), S::F16(b), S::F16(c)) => S::F16(self.f(a, l1, b, l2, c, l3, d)?),
+            (S::F32(a), S::F32(b), S::F32(c)) => S::F32(self.f(a, l1, b, l2, c, l3, d)?),
+            (S::F64(a), S::F64(b), S::F64(c)) => S::F64(self.f(a, l1, b, l2, c, l3, d)?),
+            _ => return Err(CudaError::InternalError("dtype mismatch in ternary op").into()),
+        })
+    }
+}
+
pub trait Map2InPlace {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,