diff options
Diffstat (limited to 'candle-transformers/src/models/wuerstchen/diffnext.rs')
-rw-r--r-- | candle-transformers/src/models/wuerstchen/diffnext.rs | 23 |
1 files changed, 10 insertions, 13 deletions
diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs index 001b35d7..70e4ba34 100644 --- a/candle-transformers/src/models/wuerstchen/diffnext.rs +++ b/candle-transformers/src/models/wuerstchen/diffnext.rs @@ -19,7 +19,7 @@ impl ResBlockStageB { ..Default::default() }; let depthwise = candle_nn::conv2d(c, c, ksize, cfg, vb.pp("depthwise"))?; - let norm = WLayerNorm::new(c, vb.pp("norm"))?; + let norm = WLayerNorm::new(c)?; let channelwise_lin1 = candle_nn::linear(c + c_skip, c * 4, vb.pp("channelwise.0"))?; let channelwise_grn = GlobalResponseNorm::new(4 * c, vb.pp("channelwise.2"))?; let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?; @@ -75,7 +75,7 @@ struct UpBlock { pub struct WDiffNeXt { clip_mapper: candle_nn::Linear, effnet_mappers: Vec<Option<candle_nn::Conv2d>>, - seq_norm: candle_nn::LayerNorm, + seq_norm: WLayerNorm, embedding_conv: candle_nn::Conv2d, embedding_ln: WLayerNorm, down_blocks: Vec<DownBlock>, @@ -98,7 +98,7 @@ impl WDiffNeXt { ) -> Result<Self> { const C_HIDDEN: [usize; 4] = [320, 640, 1280, 1280]; const BLOCKS: [usize; 4] = [4, 4, 14, 4]; - const NHEAD: [usize; 4] = [0, 10, 20, 20]; + const NHEAD: [usize; 4] = [1, 10, 20, 20]; const INJECT_EFFNET: [bool; 4] = [false, true, true, true]; const EFFNET_EMBD: usize = 16; @@ -133,24 +133,21 @@ impl WDiffNeXt { }; effnet_mappers.push(c) } - let cfg = candle_nn::layer_norm::LayerNormConfig { - ..Default::default() - }; - let seq_norm = candle_nn::layer_norm(c_cond, cfg, vb.pp("seq_norm"))?; - let embedding_ln = WLayerNorm::new(C_HIDDEN[0], vb.pp("embedding.1"))?; + let seq_norm = WLayerNorm::new(c_cond)?; + let embedding_ln = WLayerNorm::new(C_HIDDEN[0])?; let embedding_conv = candle_nn::conv2d( c_in * patch_size * patch_size, - C_HIDDEN[1], + C_HIDDEN[0], 1, Default::default(), - vb.pp("embedding.2"), + vb.pp("embedding.1"), )?; let mut down_blocks = Vec::with_capacity(C_HIDDEN.len()); for (i, &c_hidden) in C_HIDDEN.iter().enumerate() { let vb = vb.pp("down_blocks").pp(i); let (layer_norm, conv, start_layer_i) = if i > 0 { - let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1], vb.pp(0))?; + let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?; let cfg = candle_nn::Conv2dConfig { stride: 2, ..Default::default() @@ -223,7 +220,7 @@ impl WDiffNeXt { sub_blocks.push(sub_block) } let (layer_norm, conv) = if i > 0 { - let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1], vb.pp(layer_i))?; + let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?; layer_i += 1; let cfg = candle_nn::Conv2dConfig { stride: 2, @@ -242,7 +239,7 @@ impl WDiffNeXt { up_blocks.push(up_block) } - let clf_ln = WLayerNorm::new(C_HIDDEN[0], vb.pp("clf.0"))?; + let clf_ln = WLayerNorm::new(C_HIDDEN[0])?; let clf_conv = candle_nn::conv2d( C_HIDDEN[0], 2 * c_out * patch_size * patch_size, |