summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- candle-core/src/backprop.rs 12
-rw-r--r-- candle-core/src/op.rs 15
-rw-r--r-- candle-core/src/storage.rs 18
-rw-r--r-- candle-core/src/tensor.rs 29
-rw-r--r-- candle-examples/examples/stable-diffusion/unet_2d_blocks.rs 11
-rw-r--r-- candle-examples/examples/stable-diffusion/utils.rs 8
-rw-r--r-- candle-nn/src/conv.rs 12
7 files changed, 88 insertions, 17 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 2dff0a5a..0eab508e 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -55,6 +55,11 @@ impl Tensor {
kernel: rhs,
..
}
+ | Op::Conv2D {
+ arg: lhs,
+ kernel: rhs,
+ ..
+ }
| Op::CustomOp2(lhs, rhs, _)
| Op::Binary(lhs, rhs, _)
| Op::Gather(lhs, rhs, _)
@@ -81,6 +86,8 @@ impl Tensor {
}
}
Op::Reshape(node)
+ | Op::UpsampleNearest2D(node)
+ | Op::AvgPool2D { arg: node, .. }
| Op::Copy(node)
| Op::Broadcast(node)
| Op::Cmp(node, _)
@@ -163,6 +170,11 @@ impl Tensor {
*f_sum_grad = f_sum_grad.add(&f_grad)?;
}
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
+ Op::Conv2D { .. } => Err(Error::BackwardNotSupported { op: "conv2d" })?,
+ Op::AvgPool2D { .. } => Err(Error::BackwardNotSupported { op: "avg-pool2d" })?,
+ Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
+ op: "upsample-nearest2d",
+ })?,
Op::Gather(arg, indexes, dim) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index b4ebca51..aea8b733 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -80,6 +80,21 @@ pub enum Op {
stride: usize,
},
+ #[allow(dead_code)]
+ Conv2D {
+ arg: Tensor,
+ kernel: Tensor,
+ padding: usize,
+ stride: usize,
+ },
+
+ AvgPool2D {
+ arg: Tensor,
+ kernel_size: (usize, usize),
+ stride: (usize, usize),
+ },
+ UpsampleNearest2D(Tensor),
+
Cat(Vec<Tensor>, usize),
#[allow(dead_code)] // add is currently unused.
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 1e1ef305..cbca4fc4 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -266,6 +266,24 @@ impl Storage {
}
}
+ pub(crate) fn avg_pool2d(
+ &self,
+ _layout: &Layout,
+ _kernel_size: (usize, usize),
+ _stride: (usize, usize),
+ ) -> Result<Self> {
+ todo!()
+ }
+
+ pub(crate) fn upsample_nearest2d(
+ &self,
+ _layout: &Layout,
+ _h: usize,
+ _w: usize,
+ ) -> Result<Self> {
+ todo!()
+ }
+
pub(crate) fn where_cond(
&self,
layout: &Layout,
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index f7bd894a..ffa4bf8c 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -817,6 +817,35 @@ impl Tensor {
Ok(from_storage(storage, out_dims, op, false))
}
+ pub fn conv2d(&self, _kernel: &Self, _padding: usize, _stride: usize) -> Result<Self> {
+ todo!()
+ }
+
+ pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
+ let (n, c, _h, _w) = self.dims4()?;
+ let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
+ let storage = self
+ .storage()
+ .upsample_nearest2d(self.layout(), target_h, target_w)?;
+ Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
+ }
+
+ pub fn avg_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> {
+ let (n, c, h, w) = self.dims4()?;
+ // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
+ let h_out = (h - kernel_size.0) / stride.0 + 1;
+ let w_out = (w - kernel_size.1) / stride.1 + 1;
+ let op = BackpropOp::new1(self, |arg| Op::AvgPool2D {
+ arg,
+ kernel_size,
+ stride,
+ });
+ let storage = self
+ .storage()
+ .avg_pool2d(self.layout(), kernel_size, stride)?;
+ Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
+ }
+
/// Returns the matrix-multiplication of the input tensor with the other provided tensor.
///
/// # Arguments
diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
index 4d0c80a5..82d5fad5 100644
--- a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
+++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
@@ -36,7 +36,7 @@ impl Downsample2D {
impl Downsample2D {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match &self.conv {
- None => crate::utils::avg_pool2d(xs), // [2, 2], [2, 2], [0, 0], false, true, None),
+ None => xs.avg_pool2d((2, 2), (2, 2)),
Some(conv) => {
if self.padding == 0 {
let xs = xs
@@ -72,13 +72,10 @@ impl Upsample2D {
fn forward(&self, xs: &Tensor, size: Option<(usize, usize)>) -> Result<Tensor> {
let xs = match size {
None => {
- // The following does not work and it's tricky to pass no fixed
- // dimensions so hack our way around this.
- // xs.upsample_nearest2d(&[], Some(2.), Some(2.)
- let (_bsize, _channels, _h, _w) = xs.dims4()?;
- crate::utils::upsample_nearest2d(xs)? // [2 * h, 2 * w], Some(2.), Some(2.))
+ let (_bsize, _channels, h, w) = xs.dims4()?;
+ xs.upsample_nearest2d(2 * h, 2 * w)?
}
- Some((_h, _w)) => crate::utils::upsample_nearest2d(xs)?, // [h, w], None, None),
+ Some((h, w)) => xs.upsample_nearest2d(h, w)?,
};
self.conv.forward(&xs)
}
diff --git a/candle-examples/examples/stable-diffusion/utils.rs b/candle-examples/examples/stable-diffusion/utils.rs
index 08b78c04..0c95cfef 100644
--- a/candle-examples/examples/stable-diffusion/utils.rs
+++ b/candle-examples/examples/stable-diffusion/utils.rs
@@ -1,13 +1,5 @@
use candle::{Device, Result, Tensor};
-pub fn avg_pool2d(_: &Tensor) -> Result<Tensor> {
- todo!()
-}
-
-pub fn upsample_nearest2d(_: &Tensor) -> Result<Tensor> {
- todo!()
-}
-
pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
if steps < 1 {
candle::bail!("cannot use linspace with steps {steps} <= 1")
diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs
index 6e1fcf51..67a80417 100644
--- a/candle-nn/src/conv.rs
+++ b/candle-nn/src/conv.rs
@@ -85,8 +85,16 @@ impl Conv2d {
&self.config
}
- pub fn forward(&self, _x: &Tensor) -> Result<Tensor> {
- todo!()
+ pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let x = x.conv2d(&self.weight, self.config.padding, self.config.stride)?;
+ match &self.bias {
+ None => Ok(x),
+ Some(bias) => {
+ let b = bias.dims1()?;
+ let bias = bias.reshape((1, b, 1, 1))?;
+ Ok(x.broadcast_add(&bias)?)
+ }
+ }
}
}