diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-07 17:15:38 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-07 16:15:38 +0100 |
commit | 2345b8ce3f8ebab6e04d6ea25f7c809efb037995 (patch) | |
tree | a1c74ed8d29d1f14d329eab6e1900749b041bbdd /candle-examples/examples/stable-diffusion/unet_2d_blocks.rs | |
parent | f53a333ea91233b41dd946c2c30213c79b4d1cb3 (diff) | |
download | candle-2345b8ce3f8ebab6e04d6ea25f7c809efb037995.tar.gz candle-2345b8ce3f8ebab6e04d6ea25f7c809efb037995.tar.bz2 candle-2345b8ce3f8ebab6e04d6ea25f7c809efb037995.zip |
Skeleton for the avg-pool2d and upsample-nearest2d ops. (#337)
* Skeleton for the avg-pool2d and upsample-nearest2d ops.
* Preliminary conv2d support.
Diffstat (limited to 'candle-examples/examples/stable-diffusion/unet_2d_blocks.rs')
-rw-r--r-- | candle-examples/examples/stable-diffusion/unet_2d_blocks.rs | 11 |
1 files changed, 4 insertions, 7 deletions
diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs index 4d0c80a5..82d5fad5 100644 --- a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs +++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs @@ -36,7 +36,7 @@ impl Downsample2D { impl Downsample2D { fn forward(&self, xs: &Tensor) -> Result<Tensor> { match &self.conv { - None => crate::utils::avg_pool2d(xs), // [2, 2], [2, 2], [0, 0], false, true, None), + None => xs.avg_pool2d((2, 2), (2, 2)), Some(conv) => { if self.padding == 0 { let xs = xs @@ -72,13 +72,10 @@ impl Upsample2D { fn forward(&self, xs: &Tensor, size: Option<(usize, usize)>) -> Result<Tensor> { let xs = match size { None => { - // The following does not work and it's tricky to pass no fixed - // dimensions so hack our way around this. - // xs.upsample_nearest2d(&[], Some(2.), Some(2.) - let (_bsize, _channels, _h, _w) = xs.dims4()?; - crate::utils::upsample_nearest2d(xs)? // [2 * h, 2 * w], Some(2.), Some(2.)) + let (_bsize, _channels, h, w) = xs.dims4()?; + xs.upsample_nearest2d(2 * h, 2 * w)? } - Some((_h, _w)) => crate::utils::upsample_nearest2d(xs)?, // [h, w], None, None), + Some((h, w)) => xs.upsample_nearest2d(h, w)?, }; self.conv.forward(&xs) } |