diff options
Diffstat (limited to 'candle-core/src/cuda_backend.rs')
-rw-r--r-- | candle-core/src/cuda_backend.rs | 84 |
1 files changed, 80 insertions, 4 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 6129e100..90d3ee6d 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -960,6 +960,64 @@ impl<'a> Map2 for Conv2D<'a> { } } +enum PoolOp { + Max, + Avg, +} + +struct Pool2D { + w_k: usize, + h_k: usize, + w_stride: usize, + h_stride: usize, + op: PoolOp, +} + +impl Map1 for Pool2D { + fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( + &self, + inp: &CudaSlice<T>, + dev: &CudaDevice, + inp_l: &Layout, + ) -> Result<CudaSlice<T>> { + // Kernel shape: (c_out, c_in_k, w_k, h_k) + let inp = &inp.slice(inp_l.start_offset()..); + let shape = inp_l.shape(); + let dims = shape.dims(); + let ds = if dims.len() == 4 { + [dims, inp_l.stride()].concat() + } else { + panic!("unexpected input shape for conv1d {dims:?}") + }; + let el = shape.elem_count(); + let out_w = (dims[2] - self.w_k) / self.w_stride + 1; + let out_h = (dims[3] - self.h_k) / self.h_stride + 1; + let dst_el = out_w * out_h * dims[0] * dims[1]; + let cfg = LaunchConfig::for_num_elems(dst_el as u32); + let kname = match self.op { + PoolOp::Max => "max_pool2d", + PoolOp::Avg => "avg_pool2d", + }; + let func = dev.get_or_load_func(&kernel_name::<T>(kname), kernels::CONV)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::<T>(dst_el) }.w()?; + let ds = dev.htod_copy(ds).w()?; + let params = ( + el, + self.w_k, + self.h_k, + self.w_stride, + self.h_stride, + &ds, + inp, + &out, + ); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(out) + } +} + struct WhereCond<'a>(&'a CudaStorage, &'a Layout); impl<'a> Map2 for WhereCond<'a> { fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( @@ -1429,12 +1487,30 @@ impl BackendStorage for CudaStorage { Ok(Self { slice, device }) } - fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> { - todo!() + fn avg_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize)) -> Result<Self> { + let device = self.device().clone(); + let slice = Pool2D { + w_k: k.0, + h_k: k.1, + w_stride: stride.0, + h_stride: stride.1, + op: PoolOp::Avg, + } + .map(&self.slice, &device, l)?; + Ok(Self { slice, device }) } - fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> { - todo!() + fn max_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize)) -> Result<Self> { + let device = self.device().clone(); + let slice = Pool2D { + w_k: k.0, + h_k: k.1, + w_stride: stride.0, + h_stride: stride.1, + op: PoolOp::Max, + } + .map(&self.slice, &device, l)?; + Ok(Self { slice, device }) } fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> { |