summaryrefslogtreecommitdiff
path: root/candle-core/src/cuda_backend.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/cuda_backend.rs')
-rw-r--r--candle-core/src/cuda_backend.rs41
1 files changed, 38 insertions, 3 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 90d3ee6d..0a73c023 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -980,7 +980,7 @@ impl Map1 for Pool2D {
dev: &CudaDevice,
inp_l: &Layout,
) -> Result<CudaSlice<T>> {
- // Kernel shape: (c_out, c_in_k, w_k, h_k)
+ // Input shape: (b_size, c, h, w)
let inp = &inp.slice(inp_l.start_offset()..);
let shape = inp_l.shape();
let dims = shape.dims();
@@ -1018,6 +1018,39 @@ impl Map1 for Pool2D {
}
}
+struct UpsampleNearest2D(usize, usize);
+impl Map1 for UpsampleNearest2D {
+ fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+ &self,
+ inp: &CudaSlice<T>,
+ dev: &CudaDevice,
+ inp_l: &Layout,
+ ) -> Result<CudaSlice<T>> {
+ // Input shape: (b_size, c, h, w)
+ let inp = &inp.slice(inp_l.start_offset()..);
+ let shape = inp_l.shape();
+ let dims = shape.dims();
+ let ds = if dims.len() == 4 {
+ [dims, inp_l.stride()].concat()
+ } else {
+ panic!("unexpected input shape for conv1d {dims:?}")
+ };
+ let (out_w, out_h) = (self.0, self.1);
+ let dst_el = out_w * out_h * dims[0] * dims[1];
+ let cfg = LaunchConfig::for_num_elems(dst_el as u32);
+ let func = dev.get_or_load_func(&kernel_name::<T>("upsample_nearest2d"), kernels::CONV)?;
+ // SAFETY: Set later by running the kernel.
+ let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+ let ds = dev.htod_copy(ds).w()?;
+ let scale_w = dims[2] as f64 / out_w as f64;
+ let scale_h = dims[3] as f64 / out_h as f64;
+ let params = (out_w, out_h, scale_w, scale_h, &ds, inp, &out);
+ // SAFETY: ffi.
+ unsafe { func.launch(cfg, params) }.w()?;
+ Ok(out)
+ }
+}
+
struct WhereCond<'a>(&'a CudaStorage, &'a Layout);
impl<'a> Map2 for WhereCond<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@@ -1513,8 +1546,10 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}
- fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
- todo!()
+ fn upsample_nearest2d(&self, l: &Layout, out_w: usize, out_h: usize) -> Result<Self> {
+ let device = self.device().clone();
+ let slice = UpsampleNearest2D(out_w, out_h).map(&self.slice, &device, l)?;
+ Ok(Self { slice, device })
}
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {