author    Laurent Mazare <laurent.mazare@gmail.com>  2023-07-29 19:48:04 +0100
committer GitHub <noreply@github.com>  2023-07-29 19:48:04 +0100
commit    c950a5c6b1bd207f6f9cba18780e0baa44e3254a (patch)
tree      dda49e92ece97b85f9260eba714963d94f694525 /candle-core/src/cuda_backend.rs
parent    16c33383eb2beda515962b219728209b9edb2946 (diff)
Cuda support for the mnist training. (#277)
* Cuda support for the mnist training.
* min/max fix + testing.
* Add the argmin/argmax tests.
* More cuda support for argmin/argmax.
* Cuda kernels for argmin and argmax.
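
For context, a minimal usage sketch of what this commit enables, assuming candle's public API around this revision (Device::new_cuda, Tensor::new, and Tensor::argmax are taken from the crate; the tensor values are arbitrary). The index reductions dispatch to the new fast_argmin/fast_argmax CUDA kernels below and return u32 indices regardless of the input dtype:

    use candle_core::{Device, Tensor};

    fn main() -> candle_core::Result<()> {
        // Placeholder ordinal; assumes a CUDA-capable device is present.
        let dev = Device::new_cuda(0)?;
        let t = Tensor::new(&[3f32, 1., 4., 1., 5.], &dev)?;
        // Runs the fast_argmax kernel on the GPU; the result is a u32 tensor.
        let idx = t.argmax(0)?;
        println!("{idx}"); // expected index: 4
        Ok(())
    }
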
Diffstat (limited to 'candle-core/src/cuda_backend.rs')
-rw-r--r--  candle-core/src/cuda_backend.rs | 65
1 file changed, 50 insertions(+), 15 deletions(-)
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index a88d62c7..6c98cd0a 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -438,6 +438,28 @@ trait Map2InPlace {
}
}
+trait Map1Any {
+ fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
+ &self,
+ src: &CudaSlice<T>,
+ dev: &CudaDevice,
+ layout: &Layout,
+ wrap: W,
+ ) -> Result<S>;
+
+ fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
+ let out = match s {
+ S::U8(s) => self.f(s, d, l, S::U8)?,
+ S::U32(s) => self.f(s, d, l, S::U32)?,
+ S::BF16(s) => self.f(s, d, l, S::BF16)?,
+ S::F16(s) => self.f(s, d, l, S::F16)?,
+ S::F32(s) => self.f(s, d, l, S::F32)?,
+ S::F64(s) => self.f(s, d, l, S::F64)?,
+ };
+ Ok(out)
+ }
+}
+
trait Map2Any {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
@@ -574,13 +596,14 @@ impl<'a> Map1 for Sum<'a> {
}
struct FastReduce<'a>(&'a [usize], ReduceOp);
-impl<'a> Map1 for FastReduce<'a> {
- fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+impl<'a> Map1Any for FastReduce<'a> {
+ fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
- ) -> Result<CudaSlice<T>> {
+ wrap: W,
+ ) -> Result<S> {
let src_stride = layout.stride();
let src_dims = layout.shape().dims();
let src_el: usize = src_dims.iter().product();
@@ -615,20 +638,32 @@ impl<'a> Map1 for FastReduce<'a> {
.htod_copy([dims.as_slice(), stride.as_slice()].concat())
.w()?;
let src = &src.slice(layout.start_offset()..);
- let name = match self.1 {
- ReduceOp::Sum => "fast_sum",
- ReduceOp::Min => "fast_min",
- ReduceOp::Max => "fast_max",
- ReduceOp::ArgMin => "fast_argmin",
- ReduceOp::ArgMax => "fast_argmax",
+ let (name, check_empty, return_index) = match self.1 {
+ ReduceOp::Sum => ("fast_sum", false, false),
+ ReduceOp::Min => ("fast_min", true, false),
+ ReduceOp::Max => ("fast_max", true, false),
+ ReduceOp::ArgMin => ("fast_argmin", true, true),
+ ReduceOp::ArgMax => ("fast_argmax", true, true),
};
+ if check_empty && layout.shape().elem_count() == 0 {
+ Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
+ }
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::REDUCE)?;
- // SAFETY: filled in by the follow up kernel.
- let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
- let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
- // SAFETY: ffi.
- unsafe { func.launch(cfg, params) }.w()?;
- Ok(out)
+ if return_index {
+ // SAFETY: filled in by the follow up kernel.
+ let out = unsafe { dev.alloc::<u32>(dst_el) }.w()?;
+ let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
+ // SAFETY: ffi.
+ unsafe { func.launch(cfg, params) }.w()?;
+ Ok(S::U32(out))
+ } else {
+ // SAFETY: filled in by the follow up kernel.
+ let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+ let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
+ // SAFETY: ffi.
+ unsafe { func.launch(cfg, params) }.w()?;
+ Ok(wrap(out))
+ }
}
}
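
The interesting piece above is the wrap closure threaded through Map1Any: an implementation can either hand its output buffer back through wrap, preserving the input dtype as sum/min/max do, or ignore it and return a fixed variant, as argmin/argmax do with S::U32. A minimal CPU-only sketch of that dispatch pattern, with Vec<T> standing in for CudaSlice<T> and a toy two-variant enum standing in for the full storage enum in the diff:

    // Toy stand-in for the CUDA storage enum `S` used in the diff.
    enum S {
        U32(Vec<u32>),
        F32(Vec<f32>),
    }

    // Analogue of `Map1Any`: `f` receives a `wrap` closure that re-wraps a
    // buffer of the input element type, but may ignore it and pick its own
    // output variant instead.
    trait Map1Any {
        fn f<T: Copy + PartialOrd, W: Fn(Vec<T>) -> S>(&self, src: &[T], wrap: W) -> S;

        fn map(&self, s: &S) -> S {
            match s {
                S::U32(v) => self.f(v, S::U32),
                S::F32(v) => self.f(v, S::F32),
            }
        }
    }

    struct ArgMax;

    impl Map1Any for ArgMax {
        fn f<T: Copy + PartialOrd, W: Fn(Vec<T>) -> S>(&self, src: &[T], _wrap: W) -> S {
            // Mirror the diff's `check_empty` guard: reducing nothing is an error.
            assert!(!src.is_empty(), "reduce over an empty tensor");
            let mut best = 0usize;
            for (i, v) in src.iter().enumerate() {
                if *v > src[best] {
                    best = i;
                }
            }
            // Index reductions always produce u32, so `wrap` goes unused.
            S::U32(vec![best as u32])
        }
    }

    fn main() {
        let s = S::F32(vec![3.0, 1.0, 4.0, 1.0, 5.0]);
        match ArgMax.map(&s) {
            S::U32(idx) => println!("argmax index: {}", idx[0]), // prints 4
            _ => unreachable!("argmax always returns u32 indices"),
        }
    }
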