summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/wuerstchen/diffnext.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-transformers/src/models/wuerstchen/diffnext.rs')
-rw-r--r--candle-transformers/src/models/wuerstchen/diffnext.rs23
1 files changed, 10 insertions, 13 deletions
diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs
index 001b35d7..70e4ba34 100644
--- a/candle-transformers/src/models/wuerstchen/diffnext.rs
+++ b/candle-transformers/src/models/wuerstchen/diffnext.rs
@@ -19,7 +19,7 @@ impl ResBlockStageB {
..Default::default()
};
let depthwise = candle_nn::conv2d(c, c, ksize, cfg, vb.pp("depthwise"))?;
- let norm = WLayerNorm::new(c, vb.pp("norm"))?;
+ let norm = WLayerNorm::new(c)?;
let channelwise_lin1 = candle_nn::linear(c + c_skip, c * 4, vb.pp("channelwise.0"))?;
let channelwise_grn = GlobalResponseNorm::new(4 * c, vb.pp("channelwise.2"))?;
let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?;
@@ -75,7 +75,7 @@ struct UpBlock {
pub struct WDiffNeXt {
clip_mapper: candle_nn::Linear,
effnet_mappers: Vec<Option<candle_nn::Conv2d>>,
- seq_norm: candle_nn::LayerNorm,
+ seq_norm: WLayerNorm,
embedding_conv: candle_nn::Conv2d,
embedding_ln: WLayerNorm,
down_blocks: Vec<DownBlock>,
@@ -98,7 +98,7 @@ impl WDiffNeXt {
) -> Result<Self> {
const C_HIDDEN: [usize; 4] = [320, 640, 1280, 1280];
const BLOCKS: [usize; 4] = [4, 4, 14, 4];
- const NHEAD: [usize; 4] = [0, 10, 20, 20];
+ const NHEAD: [usize; 4] = [1, 10, 20, 20];
const INJECT_EFFNET: [bool; 4] = [false, true, true, true];
const EFFNET_EMBD: usize = 16;
@@ -133,24 +133,21 @@ impl WDiffNeXt {
};
effnet_mappers.push(c)
}
- let cfg = candle_nn::layer_norm::LayerNormConfig {
- ..Default::default()
- };
- let seq_norm = candle_nn::layer_norm(c_cond, cfg, vb.pp("seq_norm"))?;
- let embedding_ln = WLayerNorm::new(C_HIDDEN[0], vb.pp("embedding.1"))?;
+ let seq_norm = WLayerNorm::new(c_cond)?;
+ let embedding_ln = WLayerNorm::new(C_HIDDEN[0])?;
let embedding_conv = candle_nn::conv2d(
c_in * patch_size * patch_size,
- C_HIDDEN[1],
+ C_HIDDEN[0],
1,
Default::default(),
- vb.pp("embedding.2"),
+ vb.pp("embedding.1"),
)?;
let mut down_blocks = Vec::with_capacity(C_HIDDEN.len());
for (i, &c_hidden) in C_HIDDEN.iter().enumerate() {
let vb = vb.pp("down_blocks").pp(i);
let (layer_norm, conv, start_layer_i) = if i > 0 {
- let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1], vb.pp(0))?;
+ let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
let cfg = candle_nn::Conv2dConfig {
stride: 2,
..Default::default()
@@ -223,7 +220,7 @@ impl WDiffNeXt {
sub_blocks.push(sub_block)
}
let (layer_norm, conv) = if i > 0 {
- let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1], vb.pp(layer_i))?;
+ let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
layer_i += 1;
let cfg = candle_nn::Conv2dConfig {
stride: 2,
@@ -242,7 +239,7 @@ impl WDiffNeXt {
up_blocks.push(up_block)
}
- let clf_ln = WLayerNorm::new(C_HIDDEN[0], vb.pp("clf.0"))?;
+ let clf_ln = WLayerNorm::new(C_HIDDEN[0])?;
let clf_conv = candle_nn::conv2d(
C_HIDDEN[0],
2 * c_out * patch_size * patch_size,