diff options
-rw-r--r-- | candle-core/src/backprop.rs | 12 | ||||
-rw-r--r-- | candle-core/src/op.rs | 15 | ||||
-rw-r--r-- | candle-core/src/storage.rs | 18 | ||||
-rw-r--r-- | candle-core/src/tensor.rs | 29 | ||||
-rw-r--r-- | candle-examples/examples/stable-diffusion/unet_2d_blocks.rs | 11 | ||||
-rw-r--r-- | candle-examples/examples/stable-diffusion/utils.rs | 8 | ||||
-rw-r--r-- | candle-nn/src/conv.rs | 12 |
7 files changed, 88 insertions, 17 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index 2dff0a5a..0eab508e 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -55,6 +55,11 @@ impl Tensor { kernel: rhs, .. } + | Op::Conv2D { + arg: lhs, + kernel: rhs, + .. + } | Op::CustomOp2(lhs, rhs, _) | Op::Binary(lhs, rhs, _) | Op::Gather(lhs, rhs, _) @@ -81,6 +86,8 @@ impl Tensor { } } Op::Reshape(node) + | Op::UpsampleNearest2D(node) + | Op::AvgPool2D { arg: node, .. } | Op::Copy(node) | Op::Broadcast(node) | Op::Cmp(node, _) @@ -163,6 +170,11 @@ impl Tensor { *f_sum_grad = f_sum_grad.add(&f_grad)?; } Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?, + Op::Conv2D { .. } => Err(Error::BackwardNotSupported { op: "conv2d" })?, + Op::AvgPool2D { .. } => Err(Error::BackwardNotSupported { op: "avg-pool2d" })?, + Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported { + op: "upsample-nearest2d", + })?, Op::Gather(arg, indexes, dim) => { let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?; diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index b4ebca51..aea8b733 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -80,6 +80,21 @@ pub enum Op { stride: usize, }, + #[allow(dead_code)] + Conv2D { + arg: Tensor, + kernel: Tensor, + padding: usize, + stride: usize, + }, + + AvgPool2D { + arg: Tensor, + kernel_size: (usize, usize), + stride: (usize, usize), + }, + UpsampleNearest2D(Tensor), + Cat(Vec<Tensor>, usize), #[allow(dead_code)] // add is currently unused. diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index 1e1ef305..cbca4fc4 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -266,6 +266,24 @@ impl Storage { } } + pub(crate) fn avg_pool2d( + &self, + _layout: &Layout, + _kernel_size: (usize, usize), + _stride: (usize, usize), + ) -> Result<Self> { + todo!() + } + + pub(crate) fn upsample_nearest2d( + &self, + _layout: &Layout, + _h: usize, + _w: usize, + ) -> Result<Self> { + todo!() + } + pub(crate) fn where_cond( &self, layout: &Layout, diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index f7bd894a..ffa4bf8c 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -817,6 +817,35 @@ impl Tensor { Ok(from_storage(storage, out_dims, op, false)) } + pub fn conv2d(&self, _kernel: &Self, _padding: usize, _stride: usize) -> Result<Self> { + todo!() + } + + pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> { + let (n, c, _h, _w) = self.dims4()?; + let op = BackpropOp::new1(self, Op::UpsampleNearest2D); + let storage = self + .storage() + .upsample_nearest2d(self.layout(), target_h, target_w)?; + Ok(from_storage(storage, (n, c, target_h, target_w), op, false)) + } + + pub fn avg_pool2d(&self, kernel_size: (usize, usize), stride: (usize, usize)) -> Result<Self> { + let (n, c, h, w) = self.dims4()?; + // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d + let h_out = (h - kernel_size.0) / stride.0 + 1; + let w_out = (w - kernel_size.1) / stride.1 + 1; + let op = BackpropOp::new1(self, |arg| Op::AvgPool2D { + arg, + kernel_size, + stride, + }); + let storage = self + .storage() + .avg_pool2d(self.layout(), kernel_size, stride)?; + Ok(from_storage(storage, (n, c, h_out, w_out), op, false)) + } + /// Returns the matrix-multiplication of the input tensor with the other provided tensor. /// /// # Arguments diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs index 4d0c80a5..82d5fad5 100644 --- a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs +++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs @@ -36,7 +36,7 @@ impl Downsample2D { impl Downsample2D { fn forward(&self, xs: &Tensor) -> Result<Tensor> { match &self.conv { - None => crate::utils::avg_pool2d(xs), // [2, 2], [2, 2], [0, 0], false, true, None), + None => xs.avg_pool2d((2, 2), (2, 2)), Some(conv) => { if self.padding == 0 { let xs = xs @@ -72,13 +72,10 @@ impl Upsample2D { fn forward(&self, xs: &Tensor, size: Option<(usize, usize)>) -> Result<Tensor> { let xs = match size { None => { - // The following does not work and it's tricky to pass no fixed - // dimensions so hack our way around this. - // xs.upsample_nearest2d(&[], Some(2.), Some(2.) - let (_bsize, _channels, _h, _w) = xs.dims4()?; - crate::utils::upsample_nearest2d(xs)? // [2 * h, 2 * w], Some(2.), Some(2.)) + let (_bsize, _channels, h, w) = xs.dims4()?; + xs.upsample_nearest2d(2 * h, 2 * w)? } - Some((_h, _w)) => crate::utils::upsample_nearest2d(xs)?, // [h, w], None, None), + Some((h, w)) => xs.upsample_nearest2d(h, w)?, }; self.conv.forward(&xs) } diff --git a/candle-examples/examples/stable-diffusion/utils.rs b/candle-examples/examples/stable-diffusion/utils.rs index 08b78c04..0c95cfef 100644 --- a/candle-examples/examples/stable-diffusion/utils.rs +++ b/candle-examples/examples/stable-diffusion/utils.rs @@ -1,13 +1,5 @@ use candle::{Device, Result, Tensor}; -pub fn avg_pool2d(_: &Tensor) -> Result<Tensor> { - todo!() -} - -pub fn upsample_nearest2d(_: &Tensor) -> Result<Tensor> { - todo!() -} - pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> { if steps < 1 { candle::bail!("cannot use linspace with steps {steps} <= 1") diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs index 6e1fcf51..67a80417 100644 --- a/candle-nn/src/conv.rs +++ b/candle-nn/src/conv.rs @@ -85,8 +85,16 @@ impl Conv2d { &self.config } - pub fn forward(&self, _x: &Tensor) -> Result<Tensor> { - todo!() + pub fn forward(&self, x: &Tensor) -> Result<Tensor> { + let x = x.conv2d(&self.weight, self.config.padding, self.config.stride)?; + match &self.bias { + None => Ok(x), + Some(bias) => { + let b = bias.dims1()?; + let bias = bias.reshape((1, b, 1, 1))?; + Ok(x.broadcast_add(&bias)?) + } + } } } |