diff options
Diffstat (limited to 'candle-examples/examples')
-rw-r--r-- | candle-examples/examples/musicgen/encodec_model.rs | 12 | ||||
-rw-r--r-- | candle-examples/examples/stable-diffusion/resnet.rs | 2 | ||||
-rw-r--r-- | candle-examples/examples/stable-diffusion/unet_2d.rs | 2 | ||||
-rw-r--r-- | candle-examples/examples/stable-diffusion/unet_2d_blocks.rs | 6 | ||||
-rw-r--r-- | candle-examples/examples/stable-diffusion/vae.rs | 4 | ||||
-rw-r--r-- | candle-examples/examples/whisper/model.rs | 2 | ||||
-rw-r--r-- | candle-examples/examples/yolo-v3/darknet.rs | 6 | ||||
-rw-r--r-- | candle-examples/examples/yolo-v8/main.rs | 6 |
8 files changed, 32 insertions, 8 deletions
diff --git a/candle-examples/examples/musicgen/encodec_model.rs b/candle-examples/examples/musicgen/encodec_model.rs index 9c966497..e7712bf3 100644 --- a/candle-examples/examples/musicgen/encodec_model.rs +++ b/candle-examples/examples/musicgen/encodec_model.rs @@ -274,14 +274,22 @@ impl EncodecConv1d { in_c, out_c, kernel_size, - Conv1dConfig { padding: 0, stride }, + Conv1dConfig { + padding: 0, + stride, + groups: 1, + }, vb.pp("conv"), )?, NormType::None => conv1d( in_c, out_c, kernel_size, - Conv1dConfig { padding: 0, stride }, + Conv1dConfig { + padding: 0, + stride, + groups: 1, + }, vb.pp("conv"), )?, }; diff --git a/candle-examples/examples/stable-diffusion/resnet.rs b/candle-examples/examples/stable-diffusion/resnet.rs index 94f436c8..172a9359 100644 --- a/candle-examples/examples/stable-diffusion/resnet.rs +++ b/candle-examples/examples/stable-diffusion/resnet.rs @@ -66,6 +66,7 @@ impl ResnetBlock2D { let conv_cfg = nn::Conv2dConfig { stride: 1, padding: 1, + groups: 1, }; let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?; let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?; @@ -79,6 +80,7 @@ impl ResnetBlock2D { let conv_cfg = nn::Conv2dConfig { stride: 1, padding: 0, + groups: 1, }; Some(conv2d( in_channels, diff --git a/candle-examples/examples/stable-diffusion/unet_2d.rs b/candle-examples/examples/stable-diffusion/unet_2d.rs index 6f568113..eb2dbf10 100644 --- a/candle-examples/examples/stable-diffusion/unet_2d.rs +++ b/candle-examples/examples/stable-diffusion/unet_2d.rs @@ -112,8 +112,8 @@ impl UNet2DConditionModel { let bl_attention_head_dim = config.blocks.last().unwrap().attention_head_dim; let time_embed_dim = b_channels * 4; let conv_cfg = nn::Conv2dConfig { - stride: 1, padding: 1, + ..Default::default() }; let conv_in = conv2d(in_channels, b_channels, 3, conv_cfg, vs.pp("conv_in"))?; diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs index b7adb2c0..65341e74 100644 --- a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs +++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs @@ -24,7 +24,11 @@ impl Downsample2D { padding: usize, ) -> Result<Self> { let conv = if use_conv { - let config = nn::Conv2dConfig { stride: 2, padding }; + let config = nn::Conv2dConfig { + stride: 2, + padding, + ..Default::default() + }; let conv = conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?; Some(conv) } else { diff --git a/candle-examples/examples/stable-diffusion/vae.rs b/candle-examples/examples/stable-diffusion/vae.rs index abba39fa..aa8e13a0 100644 --- a/candle-examples/examples/stable-diffusion/vae.rs +++ b/candle-examples/examples/stable-diffusion/vae.rs @@ -51,8 +51,8 @@ impl Encoder { config: EncoderConfig, ) -> Result<Self> { let conv_cfg = nn::Conv2dConfig { - stride: 1, padding: 1, + ..Default::default() }; let conv_in = nn::conv2d( in_channels, @@ -182,8 +182,8 @@ impl Decoder { let n_block_out_channels = config.block_out_channels.len(); let last_block_out_channels = *config.block_out_channels.last().unwrap(); let conv_cfg = nn::Conv2dConfig { - stride: 1, padding: 1, + ..Default::default() }; let conv_in = nn::conv2d( in_channels, diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs index 553bd93b..4ccc79f7 100644 --- a/candle-examples/examples/whisper/model.rs +++ b/candle-examples/examples/whisper/model.rs @@ -308,10 +308,12 @@ impl AudioEncoder { let cfg1 = Conv1dConfig { padding: 1, stride: 1, + groups: 1, }; let cfg2 = Conv1dConfig { padding: 1, stride: 2, + groups: 1, }; let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?; let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?; diff --git a/candle-examples/examples/yolo-v3/darknet.rs b/candle-examples/examples/yolo-v3/darknet.rs index d0392308..de8fcf09 100644 --- a/candle-examples/examples/yolo-v3/darknet.rs +++ b/candle-examples/examples/yolo-v3/darknet.rs @@ -128,7 +128,11 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl) } Some(_) | None => (None, true), }; - let conv_cfg = candle_nn::Conv2dConfig { stride, padding }; + let conv_cfg = candle_nn::Conv2dConfig { + stride, + padding, + groups: 1, + }; let conv = if bias { conv2d(p, filters, size, conv_cfg, vb.pp(&format!("conv_{index}")))? } else { diff --git a/candle-examples/examples/yolo-v8/main.rs b/candle-examples/examples/yolo-v8/main.rs index 616e04ed..3b9c1ce9 100644 --- a/candle-examples/examples/yolo-v8/main.rs +++ b/candle-examples/examples/yolo-v8/main.rs @@ -101,7 +101,11 @@ impl ConvBlock { padding: Option<usize>, ) -> Result<Self> { let padding = padding.unwrap_or(k / 2); - let cfg = Conv2dConfig { padding, stride }; + let cfg = Conv2dConfig { + padding, + stride, + groups: 1, + }; let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?; let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?; Ok(Self { conv, bn }) |