summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/cuda_backend.rs156
1 files changed, 72 insertions, 84 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 7d06dd72..0e9c11c8 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -244,6 +244,7 @@ enum CudaStorageSlice {
F32(CudaSlice<f32>),
F64(CudaSlice<f64>),
}
+type S = CudaStorageSlice;
trait Map1 {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@@ -253,13 +254,36 @@ trait Map1 {
layout: &Layout,
) -> Result<CudaSlice<T>>;
- fn map(&self, s: &CudaStorageSlice, d: &CudaDevice, l: &Layout) -> Result<CudaStorageSlice> {
+ fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
let out = match s {
- CudaStorageSlice::U32(s) => CudaStorageSlice::U32(self.f(s, d, l)?),
- CudaStorageSlice::BF16(s) => CudaStorageSlice::BF16(self.f(s, d, l)?),
- CudaStorageSlice::F16(s) => CudaStorageSlice::F16(self.f(s, d, l)?),
- CudaStorageSlice::F32(s) => CudaStorageSlice::F32(self.f(s, d, l)?),
- CudaStorageSlice::F64(s) => CudaStorageSlice::F64(self.f(s, d, l)?),
+ S::U32(s) => S::U32(self.f(s, d, l)?),
+ S::BF16(s) => S::BF16(self.f(s, d, l)?),
+ S::F16(s) => S::F16(self.f(s, d, l)?),
+ S::F32(s) => S::F32(self.f(s, d, l)?),
+ S::F64(s) => S::F64(self.f(s, d, l)?),
+ };
+ Ok(out)
+ }
+}
+
+trait Map2 {
+ fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+ &self,
+ src1: &CudaSlice<T>,
+ layout1: &Layout,
+ src2: &CudaSlice<T>,
+ layout2: &Layout,
+ dev: &CudaDevice,
+ ) -> Result<CudaSlice<T>>;
+
+ fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
+ let out = match (s1, s2) {
+ (S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
+ (S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
+ (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
+ (S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
+ (S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
+ _ => return Err(CudaError::InternalError("dtype mismatch in binary op")),
};
Ok(out)
}
@@ -411,6 +435,44 @@ impl<'a> Map1 for Embedding<'a> {
}
}
+struct WhereCond<'a>(&'a CudaStorage, &'a Layout);
+impl<'a> Map2 for WhereCond<'a> {
+ fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+ &self,
+ t: &CudaSlice<T>,
+ layout_t: &Layout,
+ f: &CudaSlice<T>,
+ layout_f: &Layout,
+ dev: &CudaDevice,
+ ) -> Result<CudaSlice<T>> {
+ let ids_l = &self.1;
+ let ids = match &self.0.slice {
+ CudaStorageSlice::U32(slice) => slice.slice(ids_l.start_offset()..),
+ _ => Err(CudaError::UnexpectedDType {
+ msg: "where conditions should be u32",
+ expected: DType::U32,
+ got: self.0.dtype(),
+ })?,
+ };
+ let ids = &ids;
+ let shape = ids_l.shape();
+ let dims = shape.dims();
+ let el = shape.elem_count();
+ let cfg = LaunchConfig::for_num_elems(el as u32);
+ let ds =
+ dev.htod_copy([dims, ids_l.stride(), layout_t.stride(), layout_f.stride()].concat())?;
+ let t = &t.slice(layout_t.start_offset()..);
+ let f = &f.slice(layout_f.start_offset()..);
+ let func = dev.get_or_load_func(&kernel_name::<T>("where"), kernels::TERNARY)?;
+ // SAFETY: Set later by running the kernel.
+ let out = unsafe { dev.alloc::<T>(el) }?;
+ let params = (el, dims.len(), &ds, ids, t, f, &out);
+ // SAFETY: ffi
+ unsafe { func.launch(cfg, params) }?;
+ Ok(out)
+ }
+}
+
fn slice_src_and_dst<'a, T>(
src: &'a CudaSlice<T>,
src_l: &Layout,
@@ -714,86 +776,12 @@ impl CudaStorage {
&self,
layout: &Layout,
t: &Self,
- layout_t: &Layout,
+ t_l: &Layout,
f: &Self,
- layout_f: &Layout,
+ f_l: &Layout,
) -> Result<Self> {
- let ids = match &self.slice {
- CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..),
- _ => Err(CudaError::UnexpectedDType {
- msg: "where conditions should be u32",
- expected: DType::U32,
- got: self.dtype(),
- })?,
- };
- let ids = &ids;
- let shape = layout.shape();
- let dims = shape.dims();
- let el = shape.elem_count();
- let cfg = LaunchConfig::for_num_elems(el as u32);
- let dev = self.device();
- let ds =
- dev.htod_copy([dims, layout.stride(), layout_t.stride(), layout_f.stride()].concat())?;
- let slice = match (&t.slice, &f.slice) {
- (CudaStorageSlice::BF16(t), CudaStorageSlice::BF16(f)) => {
- let t = &t.slice(layout_t.start_offset()..);
- let f = &f.slice(layout_f.start_offset()..);
- let func = dev.get_or_load_func("where_bf16", kernels::TERNARY)?;
- // SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<bf16>(el) }?;
- let params = (el, dims.len(), &ds, ids, t, f, &out);
- // SAFETY: ffi
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::BF16(out)
- }
- (CudaStorageSlice::F16(t), CudaStorageSlice::F16(f)) => {
- let t = &t.slice(layout_t.start_offset()..);
- let f = &f.slice(layout_f.start_offset()..);
- let func = dev.get_or_load_func("where_f16", kernels::TERNARY)?;
- // SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<f16>(el) }?;
- let params = (el, dims.len(), &ds, ids, t, f, &out);
- // SAFETY: ffi
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::F16(out)
- }
- (CudaStorageSlice::F32(t), CudaStorageSlice::F32(f)) => {
- let t = &t.slice(layout_t.start_offset()..);
- let f = &f.slice(layout_f.start_offset()..);
- let func = dev.get_or_load_func("where_f32", kernels::TERNARY)?;
- // SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<f32>(el) }?;
- let params = (el, dims.len(), &ds, ids, t, f, &out);
- // SAFETY: ffi
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::F32(out)
- }
- (CudaStorageSlice::F64(t), CudaStorageSlice::F64(f)) => {
- let t = &t.slice(layout_t.start_offset()..);
- let f = &f.slice(layout_f.start_offset()..);
- // SAFETY: Set later by running the kernel.
- let func = dev.get_or_load_func("where_f64", kernels::TERNARY)?;
- let out = unsafe { dev.alloc::<f64>(el) }?;
- let params = (el, dims.len(), &ds, ids, t, f, &out);
- // SAFETY: ffi
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::F64(out)
- }
- (CudaStorageSlice::U32(t), CudaStorageSlice::U32(f)) => {
- let t = &t.slice(layout_t.start_offset()..);
- let f = &f.slice(layout_f.start_offset()..);
- // SAFETY: Set later by running the kernel.
- let func = dev.get_or_load_func("where_u32", kernels::TERNARY)?;
- let out = unsafe { dev.alloc::<u32>(el) }?;
- let params = (el, dims.len(), &ds, ids, t, f, &out);
- // SAFETY: ffi
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::U32(out)
- }
- // The dtypes should have been checked at this point so this is an internal error.
- _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
- };
- let device = dev.clone();
+ let device = self.device().clone();
+ let slice = WhereCond(self, layout).map(&t.slice, t_l, &f.slice, f_l, &device)?;
Ok(Self { slice, device })
}