summaryrefslogtreecommitdiff
path: root/candle-examples/examples/stable-diffusion
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/stable-diffusion')
-rw-r--r--candle-examples/examples/stable-diffusion/resnet.rs2
-rw-r--r--candle-examples/examples/stable-diffusion/unet_2d.rs2
-rw-r--r--candle-examples/examples/stable-diffusion/unet_2d_blocks.rs6
-rw-r--r--candle-examples/examples/stable-diffusion/vae.rs4
4 files changed, 10 insertions, 4 deletions
diff --git a/candle-examples/examples/stable-diffusion/resnet.rs b/candle-examples/examples/stable-diffusion/resnet.rs
index 94f436c8..172a9359 100644
--- a/candle-examples/examples/stable-diffusion/resnet.rs
+++ b/candle-examples/examples/stable-diffusion/resnet.rs
@@ -66,6 +66,7 @@ impl ResnetBlock2D {
let conv_cfg = nn::Conv2dConfig {
stride: 1,
padding: 1,
+ groups: 1,
};
let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
@@ -79,6 +80,7 @@ impl ResnetBlock2D {
let conv_cfg = nn::Conv2dConfig {
stride: 1,
padding: 0,
+ groups: 1,
};
Some(conv2d(
in_channels,
diff --git a/candle-examples/examples/stable-diffusion/unet_2d.rs b/candle-examples/examples/stable-diffusion/unet_2d.rs
index 6f568113..eb2dbf10 100644
--- a/candle-examples/examples/stable-diffusion/unet_2d.rs
+++ b/candle-examples/examples/stable-diffusion/unet_2d.rs
@@ -112,8 +112,8 @@ impl UNet2DConditionModel {
let bl_attention_head_dim = config.blocks.last().unwrap().attention_head_dim;
let time_embed_dim = b_channels * 4;
let conv_cfg = nn::Conv2dConfig {
- stride: 1,
padding: 1,
+ ..Default::default()
};
let conv_in = conv2d(in_channels, b_channels, 3, conv_cfg, vs.pp("conv_in"))?;
diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
index b7adb2c0..65341e74 100644
--- a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
+++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
@@ -24,7 +24,11 @@ impl Downsample2D {
padding: usize,
) -> Result<Self> {
let conv = if use_conv {
- let config = nn::Conv2dConfig { stride: 2, padding };
+ let config = nn::Conv2dConfig {
+ stride: 2,
+ padding,
+ ..Default::default()
+ };
let conv = conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
Some(conv)
} else {
diff --git a/candle-examples/examples/stable-diffusion/vae.rs b/candle-examples/examples/stable-diffusion/vae.rs
index abba39fa..aa8e13a0 100644
--- a/candle-examples/examples/stable-diffusion/vae.rs
+++ b/candle-examples/examples/stable-diffusion/vae.rs
@@ -51,8 +51,8 @@ impl Encoder {
config: EncoderConfig,
) -> Result<Self> {
let conv_cfg = nn::Conv2dConfig {
- stride: 1,
padding: 1,
+ ..Default::default()
};
let conv_in = nn::conv2d(
in_channels,
@@ -182,8 +182,8 @@ impl Decoder {
let n_block_out_channels = config.block_out_channels.len();
let last_block_out_channels = *config.block_out_channels.last().unwrap();
let conv_cfg = nn::Conv2dConfig {
- stride: 1,
padding: 1,
+ ..Default::default()
};
let conv_in = nn::conv2d(
in_channels,