summaryrefslogtreecommitdiff
path: root/candle-core/src/cpu_backend.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/cpu_backend.rs')
-rw-r--r--candle-core/src/cpu_backend.rs41
1 files changed, 41 insertions, 0 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 401a2c0e..a04ed9a0 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -672,6 +672,43 @@ impl Map1 for AvgPool2D {
}
}
+struct UpsampleNearest2D(usize, usize);
+
+impl Map1 for UpsampleNearest2D {
+ fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
+ // TODO: Specialized implementation for the case 2*h, 2*w?
+ let (dst_h, dst_w) = (self.0, self.1);
+ let (b_sz, c, src_h, src_w) = layout.shape().dims4()?;
+ let stride = layout.stride();
+ let (stride_h, stride_w) = (stride[2], stride[3]);
+ let src_index = layout.start_offset();
+ let scale_h = src_h as f64 / dst_h as f64;
+ let scale_w = src_w as f64 / dst_w as f64;
+ let mut dst = vec![T::zero(); b_sz * c * dst_h * dst_w];
+ let src_h_idxs = (0..src_h)
+ .map(|h_idx| usize::min(src_h - 1, (h_idx as f64 * scale_h) as usize))
+ .collect::<Vec<_>>();
+ let src_w_idxs = (0..src_w)
+ .map(|w_idx| usize::min(src_w - 1, (w_idx as f64 * scale_w) as usize))
+ .collect::<Vec<_>>();
+ for b_idx in 0..b_sz {
+ let dst = &mut dst[b_idx * c * dst_h * dst_w..];
+ let src_index = src_index + b_idx * stride[0];
+ for c_idx in 0..c {
+ let dst = &mut dst[c_idx * dst_h * dst_w..];
+ let src_index = src_index + c_idx * stride[1];
+ for (h_idx, src_h_idx) in src_h_idxs.iter().enumerate() {
+ for (w_idx, src_w_idx) in src_w_idxs.iter().enumerate() {
+ let src_index = src_index + src_h_idx * stride_h + src_w_idx * stride_w;
+ dst[h_idx * dst_w + w_idx] = src[src_index]
+ }
+ }
+ }
+ }
+ Ok(dst)
+ }
+}
+
struct Gather<'a, I: IntDType> {
ids: &'a [I],
ids_l: &'a Layout,
@@ -1577,6 +1614,10 @@ impl BackendStorage for CpuStorage {
AvgPool2D(kernel_size, stride).map(self, layout)
}
+ fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
+ UpsampleNearest2D(h, w).map(self, layout)
+ }
+
fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
// TODO: Have some generic map for functions that apply on num_traits::Float elements.
match self {