diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-29 19:48:04 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-29 19:48:04 +0100 |
commit | c950a5c6b1bd207f6f9cba18780e0baa44e3254a (patch) | |
tree | dda49e92ece97b85f9260eba714963d94f694525 /candle-core/src/cuda_backend.rs | |
parent | 16c33383eb2beda515962b219728209b9edb2946 (diff) | |
download | candle-c950a5c6b1bd207f6f9cba18780e0baa44e3254a.tar.gz candle-c950a5c6b1bd207f6f9cba18780e0baa44e3254a.tar.bz2 candle-c950a5c6b1bd207f6f9cba18780e0baa44e3254a.zip |
Cuda support for the mnist training. (#277)
* Cuda support for the mnist training.
* min/max fix + testing.
* Add the argmin/argmax tests.
* More cuda support for argmin/argmax.
* Cuda kernels for argmin and argmax.
Diffstat (limited to 'candle-core/src/cuda_backend.rs')
-rw-r--r-- | candle-core/src/cuda_backend.rs | 65 |
1 file changed, 50 insertions, 15 deletions
```diff
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index a88d62c7..6c98cd0a 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -438,6 +438,28 @@ trait Map2InPlace {
     }
 }
 
+trait Map1Any {
+    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
+        &self,
+        src: &CudaSlice<T>,
+        dev: &CudaDevice,
+        layout: &Layout,
+        wrap: W,
+    ) -> Result<S>;
+
+    fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
+        let out = match s {
+            S::U8(s) => self.f(s, d, l, S::U8)?,
+            S::U32(s) => self.f(s, d, l, S::U32)?,
+            S::BF16(s) => self.f(s, d, l, S::BF16)?,
+            S::F16(s) => self.f(s, d, l, S::F16)?,
+            S::F32(s) => self.f(s, d, l, S::F32)?,
+            S::F64(s) => self.f(s, d, l, S::F64)?,
+        };
+        Ok(out)
+    }
+}
+
 trait Map2Any {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
@@ -574,13 +596,14 @@ impl<'a> Map1 for Sum<'a> {
 }
 
 struct FastReduce<'a>(&'a [usize], ReduceOp);
-impl<'a> Map1 for FastReduce<'a> {
-    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+impl<'a> Map1Any for FastReduce<'a> {
+    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
         &self,
         src: &CudaSlice<T>,
         dev: &CudaDevice,
         layout: &Layout,
-    ) -> Result<CudaSlice<T>> {
+        wrap: W,
+    ) -> Result<S> {
         let src_stride = layout.stride();
         let src_dims = layout.shape().dims();
         let src_el: usize = src_dims.iter().product();
@@ -615,20 +638,32 @@ impl<'a> Map1 for FastReduce<'a> {
             .htod_copy([dims.as_slice(), stride.as_slice()].concat())
             .w()?;
         let src = &src.slice(layout.start_offset()..);
-        let name = match self.1 {
-            ReduceOp::Sum => "fast_sum",
-            ReduceOp::Min => "fast_min",
-            ReduceOp::Max => "fast_max",
-            ReduceOp::ArgMin => "fast_argmin",
-            ReduceOp::ArgMax => "fast_argmax",
+        let (name, check_empty, return_index) = match self.1 {
+            ReduceOp::Sum => ("fast_sum", false, false),
+            ReduceOp::Min => ("fast_min", true, false),
+            ReduceOp::Max => ("fast_max", true, false),
+            ReduceOp::ArgMin => ("fast_argmin", true, true),
+            ReduceOp::ArgMax => ("fast_argmax", true, true),
         };
+        if check_empty && layout.shape().elem_count() == 0 {
+            Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
+        }
         let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::REDUCE)?;
-        // SAFETY: filled in by the follow up kernel.
-        let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
-        let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
-        // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }.w()?;
-        Ok(out)
+        if return_index {
+            // SAFETY: filled in by the follow up kernel.
+            let out = unsafe { dev.alloc::<u32>(dst_el) }.w()?;
+            let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
+            // SAFETY: ffi.
+            unsafe { func.launch(cfg, params) }.w()?;
+            Ok(S::U32(out))
+        } else {
+            // SAFETY: filled in by the follow up kernel.
+            let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+            let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
+            // SAFETY: ffi.
+            unsafe { func.launch(cfg, params) }.w()?;
+            Ok(wrap(out))
+        }
     }
 }
 
```