summaryrefslogtreecommitdiff
path: root/candle-core/src/cuda_backend.rs
diff options
context:
space:
mode:
authorNicolas Patry <patry.nicolas@protonmail.com>2023-07-29 16:38:35 +0200
committerGitHub <noreply@github.com>2023-07-29 16:38:35 +0200
commit40c80bfbb2d43f1177eed83821fa1795d3ebea19 (patch)
tree1bc2c925c16d23498807dde1b75e5ca7c2b56cfe /candle-core/src/cuda_backend.rs
parent97d8712ba507dbdb06c639b0c6b8857e454bb269 (diff)
parent07eb899729cfcc8f2548103eed779c0e4c5b034c (diff)
downloadcandle-40c80bfbb2d43f1177eed83821fa1795d3ebea19.tar.gz
candle-40c80bfbb2d43f1177eed83821fa1795d3ebea19.tar.bz2
candle-40c80bfbb2d43f1177eed83821fa1795d3ebea19.zip
Merge branch 'main' into update_multiprocess
Diffstat (limited to 'candle-core/src/cuda_backend.rs')
-rw-r--r--candle-core/src/cuda_backend.rs16
1 files changed, 11 insertions, 5 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 4050b595..a88d62c7 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -940,16 +940,22 @@ impl<'a> Map2 for WhereCond<'a> {
dev: &CudaDevice,
) -> Result<CudaSlice<T>> {
let ids_l = &self.1;
- let ids = match &self.0.slice {
- CudaStorageSlice::U32(slice) => slice.slice(ids_l.start_offset()..),
+ let (ids, name) = match &self.0.slice {
+ CudaStorageSlice::U8(slice) => {
+ let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
+ (ptr, "where_u8")
+ }
+ CudaStorageSlice::U32(slice) => {
+ let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
+ (ptr, "where_u32")
+ }
_ => Err(CudaError::UnexpectedDType {
- msg: "where conditions should be u32",
+ msg: "where conditions should be u8 or u32",
expected: DType::U32,
got: self.0.dtype(),
})
.w()?,
};
- let ids = &ids;
let shape = ids_l.shape();
let dims = shape.dims();
let el = shape.elem_count();
@@ -959,7 +965,7 @@ impl<'a> Map2 for WhereCond<'a> {
.w()?;
let t = &t.slice(layout_t.start_offset()..);
let f = &f.slice(layout_f.start_offset()..);
- let func = dev.get_or_load_func(&kernel_name::<T>("where"), kernels::TERNARY)?;
+ let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::TERNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (el, dims.len(), &ds, ids, t, f, &out);