summaryrefslogtreecommitdiff
path: root/candle-examples/examples
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples')
-rw-r--r--candle-examples/examples/stable-diffusion/unet_2d_blocks.rs11
-rw-r--r--candle-examples/examples/stable-diffusion/utils.rs8
2 files changed, 4 insertions, 15 deletions
diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
index 4d0c80a5..82d5fad5 100644
--- a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
+++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
@@ -36,7 +36,7 @@ impl Downsample2D {
impl Downsample2D {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match &self.conv {
- None => crate::utils::avg_pool2d(xs), // [2, 2], [2, 2], [0, 0], false, true, None),
+ None => xs.avg_pool2d((2, 2), (2, 2)),
Some(conv) => {
if self.padding == 0 {
let xs = xs
@@ -72,13 +72,10 @@ impl Upsample2D {
fn forward(&self, xs: &Tensor, size: Option<(usize, usize)>) -> Result<Tensor> {
let xs = match size {
None => {
- // The following does not work and it's tricky to pass no fixed
- // dimensions so hack our way around this.
- // xs.upsample_nearest2d(&[], Some(2.), Some(2.)
- let (_bsize, _channels, _h, _w) = xs.dims4()?;
- crate::utils::upsample_nearest2d(xs)? // [2 * h, 2 * w], Some(2.), Some(2.))
+ let (_bsize, _channels, h, w) = xs.dims4()?;
+ xs.upsample_nearest2d(2 * h, 2 * w)?
}
- Some((_h, _w)) => crate::utils::upsample_nearest2d(xs)?, // [h, w], None, None),
+ Some((h, w)) => xs.upsample_nearest2d(h, w)?,
};
self.conv.forward(&xs)
}
diff --git a/candle-examples/examples/stable-diffusion/utils.rs b/candle-examples/examples/stable-diffusion/utils.rs
index 08b78c04..0c95cfef 100644
--- a/candle-examples/examples/stable-diffusion/utils.rs
+++ b/candle-examples/examples/stable-diffusion/utils.rs
@@ -1,13 +1,5 @@
use candle::{Device, Result, Tensor};
-pub fn avg_pool2d(_: &Tensor) -> Result<Tensor> {
- todo!()
-}
-
-pub fn upsample_nearest2d(_: &Tensor) -> Result<Tensor> {
- todo!()
-}
-
pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
if steps < 1 {
candle::bail!("cannot use linspace with steps {steps} <= 1")