summary | refs | log | tree | commit | diff
path: root/candle-core/src
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src')
-rw-r--r-- candle-core/src/backend.rs | 2
-rw-r--r-- candle-core/src/cpu_backend.rs | 48
-rw-r--r-- candle-core/src/cuda_backend.rs | 4
-rw-r--r-- candle-core/src/dummy_cuda_backend.rs | 4
-rw-r--r-- candle-core/src/storage.rs | 17
5 files changed, 71 insertions, 4 deletions
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index 345db0e5..307b56dc 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -37,6 +37,8 @@ pub trait BackendStorage: Sized {
_params: &crate::conv::ParamsConv1D,
) -> Result<Self>;
+ /// 2D average pooling. The arguments are, in order, the layout of `self`,
+ /// a `(usize, usize)` pair and a second `(usize, usize)` pair; the CPU
+ /// implementation in this change names these `kernel_size` and `stride`.
+ fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
+
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
fn scatter_add(
&self,
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 4aa2f880..401a2c0e 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -633,6 +633,45 @@ impl Map1 for Affine {
}
}
+/// 2D average pooling applied through the `Map1` machinery.
+///
+/// Field `0` is the kernel size `(k_h, k_w)` and field `1` is the pooling
+/// stride `(s_h, s_w)`.
+struct AvgPool2D((usize, usize), (usize, usize));
+
+impl Map1 for AvgPool2D {
+    /// Averages every `k_h x k_w` window of a 4D `(b_sz, c, h, w)` input.
+    ///
+    /// `src` is read through `layout` (start offset and per-dimension strides),
+    /// so the input does not have to be contiguous. The returned buffer is laid
+    /// out densely as `(b_sz, c, h_out, w_out)` with
+    /// `h_out = (h - k_h) / s_h + 1` and `w_out = (w - k_w) / s_w + 1`; no
+    /// padding is applied.
+    ///
+    /// Returns an error when `layout` does not describe a 4D shape.
+    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
+        // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html
+        let (k_h, k_w) = self.0;
+        let (s_h, s_w) = self.1;
+        let (b_sz, c, h, w) = layout.shape().dims4()?;
+        let stride = layout.stride();
+        let (stride_h, stride_w) = (stride[2], stride[3]);
+        // NOTE(review): `h - k_h` / `w - k_w` underflow when the kernel is larger
+        // than the input, and a zero pooling stride divides by zero. Callers are
+        // presumably expected to validate these beforehand - confirm.
+        let h_out = (h - k_h) / s_h + 1;
+        let w_out = (w - k_w) / s_w + 1;
+        let src_index = layout.start_offset();
+        let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
+        // NOTE(review): for integer dtypes the conversion of `1 / (k_h * k_w)`
+        // presumably loses the fractional part - verify `T::from_f64`.
+        let scale = 1f64 / (k_h * k_w) as f64;
+        let scale = T::from_f64(scale);
+        for b_idx in 0..b_sz {
+            let dst = &mut dst[b_idx * c * h_out * w_out..];
+            let src_index = src_index + b_idx * stride[0];
+            for c_idx in 0..c {
+                let dst = &mut dst[c_idx * h_out * w_out..];
+                let src_index = src_index + c_idx * stride[1];
+                for h_idx in 0..h_out {
+                    // First input row covered by the windows of this output row.
+                    let row_start = src_index + h_idx * s_h * stride_h;
+                    for w_idx in 0..w_out {
+                        // Top-left element of the window for output (h_idx, w_idx).
+                        // Without this offset every output element would average
+                        // the same top-left window of the channel.
+                        let window = row_start + w_idx * s_w * stride_w;
+                        let mut sum = T::zero();
+                        for m in 0..k_h {
+                            for n in 0..k_w {
+                                sum += src[window + m * stride_h + n * stride_w]
+                            }
+                        }
+                        dst[h_idx * w_out + w_idx] = sum * scale;
+                    }
+                }
+            }
+        }
+        Ok(dst)
+    }
+}
+
struct Gather<'a, I: IntDType> {
ids: &'a [I],
ids_l: &'a Layout,
@@ -1529,6 +1568,15 @@ impl BackendStorage for CpuStorage {
Affine(mul, add).map(self, layout)
}
+    /// CPU implementation of 2D average pooling: builds the `AvgPool2D` op from
+    /// `kernel_size` and `stride` and maps it over this storage using `layout`.
+    fn avg_pool2d(
+        &self,
+        layout: &Layout,
+        kernel_size: (usize, usize),
+        stride: (usize, usize),
+    ) -> Result<Self> {
+        let pool = AvgPool2D(kernel_size, stride);
+        pool.map(self, layout)
+    }
+
fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
// TODO: Have some generic map for functions that apply on num_traits::Float elements.
match self {
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 7b4b358d..e71ecfce 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1381,6 +1381,10 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}
+ // Placeholder: average pooling is not implemented for this backend yet.
+ // The body is `todo!()`, so calling this panics instead of returning an `Err`.
+ fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
+ todo!()
+ }
+
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
let device = self.device().clone();
let slice = IndexSelect(ids, ids_l, dim).map(&self.slice, &device, l)?;
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index 17d4a22e..2d5f955c 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -119,6 +119,10 @@ impl crate::backend::BackendStorage for CudaStorage {
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
+
+ // Ignores all arguments and always returns `Error::NotCompiledWithCudaSupport`,
+ // like the `copy_strided_src` method just above it.
+ fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
+ Err(Error::NotCompiledWithCudaSupport)
+ }
}
impl crate::backend::BackendDevice for CudaDevice {
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index cbca4fc4..47df689c 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -268,11 +268,20 @@ impl Storage {
+    /// Dispatches 2D average pooling to the backend that owns this storage and
+    /// wraps the result back into the same `Storage` variant.
pub(crate) fn avg_pool2d(
&self,
- _layout: &Layout,
- _kernel_size: (usize, usize),
- _stride: (usize, usize),
+        layout: &Layout,
+        kernel_size: (usize, usize),
+        stride: (usize, usize),
) -> Result<Self> {
- todo!()
+        // Backend errors are propagated unchanged by `?`.
+        Ok(match self {
+            Self::Cpu(cpu) => Self::Cpu(cpu.avg_pool2d(layout, kernel_size, stride)?),
+            Self::Cuda(cuda) => Self::Cuda(cuda.avg_pool2d(layout, kernel_size, stride)?),
+        })
}
pub(crate) fn upsample_nearest2d(