Diffstat (limited to 'candle-core/src')
-rw-r--r--  candle-core/src/backend.rs             1
-rw-r--r--  candle-core/src/backprop.rs            2
-rw-r--r--  candle-core/src/cpu_backend.rs        93
-rw-r--r--  candle-core/src/cuda_backend.rs        4
-rw-r--r--  candle-core/src/dummy_cuda_backend.rs  4
-rw-r--r--  candle-core/src/op.rs                  7
-rw-r--r--  candle-core/src/shape.rs              27
-rw-r--r--  candle-core/src/storage.rs            18
-rw-r--r--  candle-core/src/tensor.rs             16
9 files changed, 136 insertions(+), 36 deletions(-)
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index a8e5ac52..4c31ca6f 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -46,6 +46,7 @@ pub trait BackendStorage: Sized {
) -> Result<Self>;
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
+ fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 0eab508e..2a60fe30 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -88,6 +88,7 @@ impl Tensor {
Op::Reshape(node)
| Op::UpsampleNearest2D(node)
| Op::AvgPool2D { arg: node, .. }
+ | Op::MaxPool2D { arg: node, .. }
| Op::Copy(node)
| Op::Broadcast(node)
| Op::Cmp(node, _)
@@ -172,6 +173,7 @@ impl Tensor {
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
Op::Conv2D { .. } => Err(Error::BackwardNotSupported { op: "conv2d" })?,
Op::AvgPool2D { .. } => Err(Error::BackwardNotSupported { op: "avg-pool2d" })?,
+ Op::MaxPool2D { .. } => Err(Error::BackwardNotSupported { op: "max-pool2d" })?,
Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
op: "upsample-nearest2d",
})?,
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 10c6cc4a..54f3f65b 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -660,8 +660,8 @@ impl Map1 for AvgPool2D {
let mut sum = T::zero();
for m in 0..k_h {
for n in 0..k_w {
- let m = k_h * h_idx + m;
- let n = k_w * w_idx + n;
+ let m = s_h * h_idx + m;
+ let n = s_w * w_idx + n;
sum += src[src_index + m * stride_h + n * stride_w]
}
}
@@ -674,6 +674,48 @@ impl Map1 for AvgPool2D {
}
}
+struct MaxPool2D((usize, usize), (usize, usize));
+
+impl Map1 for MaxPool2D {
+ fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
+ // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html
+ let (k_h, k_w) = self.0;
+ let (s_h, s_w) = self.1;
+ let (b_sz, c, h, w) = layout.shape().dims4()?;
+ let stride = layout.stride();
+ let (stride_h, stride_w) = (stride[2], stride[3]);
+ let h_out = (h - k_h) / s_h + 1;
+ let w_out = (w - k_w) / s_w + 1;
+ let src_index = layout.start_offset();
+ let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
+ for b_idx in 0..b_sz {
+ let dst = &mut dst[b_idx * c * h_out * w_out..];
+ let src_index = src_index + b_idx * stride[0];
+ for c_idx in 0..c {
+ let dst = &mut dst[c_idx * h_out * w_out..];
+ let src_index = src_index + c_idx * stride[1];
+ for h_idx in 0..h_out {
+ for w_idx in 0..w_out {
+ let mut largest =
+ src[src_index + s_h * h_idx * stride_h + s_w * w_idx * stride_w];
+ for m in 0..k_h {
+ for n in 0..k_w {
+ let m = s_h * h_idx + m;
+ let n = s_w * w_idx + n;
+ if largest < src[src_index + m * stride_h + n * stride_w] {
+ largest = src[src_index + m * stride_h + n * stride_w]
+ }
+ }
+ }
+ dst[h_idx * w_out + w_idx] = largest;
+ }
+ }
+ }
+ }
+ Ok(dst)
+ }
+}
+
struct UpsampleNearest2D(usize, usize);
impl Map1 for UpsampleNearest2D {
@@ -992,19 +1034,14 @@ impl<'a> Map2 for Conv1D<'a> {
let p = self.0;
let inp = &inp[inp_l.start_offset()..];
let k = &k[k_l.start_offset()..];
- let inp_stride = inp_l.stride();
- let (inp_stride0, inp_stride) = if inp_stride.len() == 3 {
- (inp_stride[0], &inp_stride[1..])
- } else {
- (0, inp_stride) // This value never gets used anyway
- };
- let k_stride = k_l.stride();
+ let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
+ let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
let l_out = p.l_out();
let dst_elems = p.c_out * l_out * p.b_size.unwrap_or(1);
let mut dst = vec![T::zero(); dst_elems];
// The output shape is [b_size, c_out, l_out]
for b_idx in 0..p.b_size.unwrap_or(1) {
- let inp_idx = b_idx * inp_stride0;
+ let inp_idx = b_idx * inp_s0;
let dst_idx = b_idx * p.c_out * l_out;
for dst_c_idx in 0..p.c_out {
let dst_idx = dst_idx + dst_c_idx * l_out;
@@ -1016,11 +1053,8 @@ impl<'a> Map2 for Conv1D<'a> {
.saturating_sub(p.padding)
.min(p.l_in - 1);
for src_c_idx in 0..p.c_in {
- let inp_idx =
- inp_idx + src_c_idx * inp_stride[0] + src_l * inp_stride[1];
- let k_idx = dst_c_idx * k_stride[0]
- + src_c_idx * k_stride[1]
- + offset * k_stride[2];
+ let inp_idx = inp_idx + src_c_idx * inp_s1 + src_l * inp_s2;
+ let k_idx = dst_c_idx * k_s0 + src_c_idx * k_s1 + offset * k_s2;
d += inp[inp_idx] * k[k_idx]
}
}
@@ -1045,14 +1079,14 @@ impl<'a> Map2 for Conv2D<'a> {
) -> Result<Vec<T>> {
let p = self.0;
let inp = &inp[inp_l.start_offset()..];
- let inp_stride = inp_l.stride();
+ let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
let k = &k[k_l.start_offset()..];
- let k_stride = k_l.stride();
+ let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?;
let (out_h, out_w) = (p.out_h(), p.out_w());
let mut dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
for b_idx in 0..p.b_size {
- let inp_idx = b_idx * inp_stride[0];
+ let inp_idx = b_idx * inp_s0;
let dst_idx = b_idx * p.c_out * out_h * out_w;
for dst_c_idx in 0..p.c_out {
let dst_idx = dst_idx + dst_c_idx * out_h * out_w;
@@ -1071,13 +1105,13 @@ impl<'a> Map2 for Conv2D<'a> {
.min(p.i_w - 1);
for src_c_idx in 0..p.c_in {
let inp_idx = inp_idx
- + src_c_idx * inp_stride[1]
- + src_h * inp_stride[2]
- + src_w * inp_stride[3];
- let k_idx = dst_c_idx * k_stride[0]
- + src_c_idx * k_stride[1]
- + offset_h * k_stride[2]
- + offset_w * k_stride[3];
+ + src_c_idx * inp_s1
+ + src_h * inp_s2
+ + src_w * inp_s3;
+ let k_idx = dst_c_idx * k_s0
+ + src_c_idx * k_s1
+ + offset_h * k_s2
+ + offset_w * k_s3;
d += inp[inp_idx] * k[k_idx]
}
}
@@ -1672,6 +1706,15 @@ impl BackendStorage for CpuStorage {
AvgPool2D(kernel_size, stride).map(self, layout)
}
+ fn max_pool2d(
+ &self,
+ layout: &Layout,
+ kernel_size: (usize, usize),
+ stride: (usize, usize),
+ ) -> Result<Self> {
+ MaxPool2D(kernel_size, stride).map(self, layout)
+ }
+
fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
UpsampleNearest2D(h, w).map(self, layout)
}
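
Note: the first cpu_backend.rs hunk is a bug fix rather than part of the new op. AvgPool2D previously advanced the pooling window by the kernel size (k_h * h_idx) instead of the stride (s_h * h_idx), so any call with stride != kernel_size averaged the wrong input elements; MaxPool2D uses the corrected indexing from the start. The sketch below shows the same window arithmetic on a contiguous NCHW buffer. It is illustrative only: the helper name and the f32 specialization are not part of candle, and the real implementation above additionally honors arbitrary layout strides and a start offset.

fn max_pool2d_nchw(
    src: &[f32],
    (b_sz, c, h, w): (usize, usize, usize, usize),
    (k_h, k_w): (usize, usize),
    (s_h, s_w): (usize, usize),
) -> Vec<f32> {
    let (h_out, w_out) = ((h - k_h) / s_h + 1, (w - k_w) / s_w + 1);
    let mut dst = vec![0f32; b_sz * c * h_out * w_out];
    for b_idx in 0..b_sz {
        for c_idx in 0..c {
            // Contiguous NCHW: each (batch, channel) plane holds h * w elements.
            let src = &src[(b_idx * c + c_idx) * h * w..];
            let dst = &mut dst[(b_idx * c + c_idx) * h_out * w_out..];
            for h_idx in 0..h_out {
                for w_idx in 0..w_out {
                    // The window origin moves by the stride (s_h, s_w), not by the
                    // kernel size -- exactly the AvgPool2D fix in the hunk above.
                    let mut largest = src[s_h * h_idx * w + s_w * w_idx];
                    for m in 0..k_h {
                        for n in 0..k_w {
                            let v = src[(s_h * h_idx + m) * w + (s_w * w_idx + n)];
                            if largest < v {
                                largest = v;
                            }
                        }
                    }
                    dst[h_idx * w_out + w_idx] = largest;
                }
            }
        }
    }
    dst
}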
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 727ea073..e51cc05d 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1395,6 +1395,10 @@ impl BackendStorage for CudaStorage {
todo!()
}
+ fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
+ todo!()
+ }
+
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
todo!()
}
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index ae4dd09f..870a87cd 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -134,6 +134,10 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
+ fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
+ Err(Error::NotCompiledWithCudaSupport)
+ }
+
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index aea8b733..f99d8adc 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -93,6 +93,13 @@ pub enum Op {
kernel_size: (usize, usize),
stride: (usize, usize),
},
+
+ MaxPool2D {
+ arg: Tensor,
+ kernel_size: (usize, usize),
+ stride: (usize, usize),
+ },
+
UpsampleNearest2D(Tensor),
Cat(Vec<Tensor>, usize),
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index a5e21aad..83d11c09 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -79,20 +79,25 @@ impl From<Vec<usize>> for Shape {
macro_rules! extract_dims {
($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
+ pub fn $fn_name(dims: &[usize]) -> Result<$out_type> {
+ if dims.len() != $cnt {
+ Err(Error::UnexpectedNumberOfDims {
+ expected: $cnt,
+ got: dims.len(),
+ shape: Shape::from(dims),
+ }
+ .bt())
+ } else {
+ Ok($dims(dims))
+ }
+ }
+
impl Shape {
pub fn $fn_name(&self) -> Result<$out_type> {
- if self.0.len() != $cnt {
- Err(Error::UnexpectedNumberOfDims {
- expected: $cnt,
- got: self.0.len(),
- shape: self.clone(),
- }
- .bt())
- } else {
- Ok($dims(&self.0))
- }
+ $fn_name(self.0.as_slice())
}
}
+
impl crate::Tensor {
pub fn $fn_name(&self) -> Result<$out_type> {
self.shape().$fn_name()
@@ -340,7 +345,7 @@ impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
}
}
-extract_dims!(dims0, 0, |_: &Vec<usize>| (), ());
+extract_dims!(dims0, 0, |_: &[usize]| (), ());
extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
extract_dims!(
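
Note: the shape.rs change is a pure refactor in support of the Conv1D/Conv2D hunks above. extract_dims! now also emits a free function taking &[usize], so a layout's stride slice gets the same arity check as a Shape, replacing ad-hoc indexing (and the old "(0, inp_stride) // This value never gets used anyway" fallback). A minimal sketch of the crate-internal call-site pattern; the helper name is hypothetical:

// Hypothetical helper showing how the generated free functions are used
// inside cpu_backend.rs (crate-internal paths, fn added for illustration).
fn conv1d_strides(inp_l: &Layout, k_l: &Layout) -> Result<[usize; 6]> {
    // Errors with Error::UnexpectedNumberOfDims { expected: 3, .. } for any
    // other rank, carrying the offending dims as a Shape plus a backtrace.
    let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
    let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
    Ok([inp_s0, inp_s1, inp_s2, k_s0, k_s1, k_s2])
}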
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 3ed38e6a..791b65dd 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -311,6 +311,24 @@ impl Storage {
}
}
+ pub(crate) fn max_pool2d(
+ &self,
+ layout: &Layout,
+ kernel_size: (usize, usize),
+ stride: (usize, usize),
+ ) -> Result<Self> {
+ match self {
+ Storage::Cpu(storage) => {
+ let storage = storage.max_pool2d(layout, kernel_size, stride)?;
+ Ok(Self::Cpu(storage))
+ }
+ Self::Cuda(storage) => {
+ let storage = storage.max_pool2d(layout, kernel_size, stride)?;
+ Ok(Self::Cuda(storage))
+ }
+ }
+ }
+
pub(crate) fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index adba7376..c94c0390 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -872,6 +872,22 @@ impl Tensor {
Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
}
+ pub fn max_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> {
+ let (n, c, h, w) = self.dims4()?;
+ // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
+ let h_out = (h - kernel_size.0) / stride.0 + 1;
+ let w_out = (w - kernel_size.1) / stride.1 + 1;
+ let op = BackpropOp::new1(self, |arg| Op::MaxPool2D {
+ arg,
+ kernel_size,
+ stride,
+ });
+ let storage = self
+ .storage()
+ .max_pool2d(self.layout(), kernel_size, stride)?;
+ Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
+ }
+
/// Returns the matrix-multiplication of the input tensor with the other provided tensor.
///
/// # Arguments
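
Note: taken together, the new op surfaces as Tensor::max_pool2d, with the output size following the PyTorch formula h_out = (h - k_h) / s_h + 1 (no padding or dilation at this point), CUDA left as a todo!() stub, and backward explicitly unsupported. A minimal usage sketch, assuming the crate is imported as candle_core and using only the API added in this diff:

use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    // A 1x1x4x4 input holding 0..16; a 2x2 kernel with stride 2 keeps the
    // maximum of each quadrant, i.e. [[5, 7], [13, 15]].
    let t = Tensor::arange(0f32, 16f32, &Device::Cpu)?.reshape((1, 1, 4, 4))?;
    let pooled = t.max_pool2d((2, 2), (2, 2))?;
    assert_eq!(pooled.dims4()?, (1, 1, 2, 2));
    Ok(())
}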