diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-17 22:08:11 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-17 22:08:11 +0100 |
commit | c2b866172abaf1d4b8d75273c4f4e28a16d872f0 (patch) | |
tree | e3bb84dc8a8cc2119e4268a30842c954c6cb6449 /candle-transformers/src/models/wuerstchen | |
parent | 06cc329e715cbb820343f9849a4a45c818cb8c5e (diff) | |
download | candle-c2b866172abaf1d4b8d75273c4f4e28a16d872f0.tar.gz candle-c2b866172abaf1d4b8d75273c4f4e28a16d872f0.tar.bz2 candle-c2b866172abaf1d4b8d75273c4f4e28a16d872f0.zip |
More Wuerstchen fixes. (#882)
* More Wuerstchen fixes.
* More shape fixes.
* Add more of the prior specific bits.
* Broadcast add.
* Fix the clip config.
* Add some masking options to the clip model.
Diffstat (limited to 'candle-transformers/src/models/wuerstchen')
-rw-r--r-- | candle-transformers/src/models/wuerstchen/common.rs | 6 | ||||
-rw-r--r-- | candle-transformers/src/models/wuerstchen/diffnext.rs | 25 |
2 files changed, 18 insertions, 13 deletions
diff --git a/candle-transformers/src/models/wuerstchen/common.rs b/candle-transformers/src/models/wuerstchen/common.rs index ee318d27..5337fdc6 100644 --- a/candle-transformers/src/models/wuerstchen/common.rs +++ b/candle-transformers/src/models/wuerstchen/common.rs @@ -75,9 +75,9 @@ impl Module for GlobalResponseNorm { let agg_norm = xs.sqr()?.sum_keepdim((1, 2))?; let stand_div_norm = agg_norm.broadcast_div(&(agg_norm.mean_keepdim(D::Minus1)? + 1e-6)?)?; - (xs.broadcast_mul(&stand_div_norm)? - .broadcast_mul(&self.gamma) - + &self.beta)? + xs.broadcast_mul(&stand_div_norm)? + .broadcast_mul(&self.gamma)? + .broadcast_add(&self.beta)? + xs } } diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs index 70e4ba34..664251ed 100644 --- a/candle-transformers/src/models/wuerstchen/diffnext.rs +++ b/candle-transformers/src/models/wuerstchen/diffnext.rs @@ -68,7 +68,7 @@ struct DownBlock { struct UpBlock { sub_blocks: Vec<SubBlock>, layer_norm: Option<WLayerNorm>, - conv: Option<candle_nn::Conv2d>, + conv: Option<candle_nn::ConvTranspose2d>, } #[derive(Debug)] @@ -152,20 +152,20 @@ impl WDiffNeXt { stride: 2, ..Default::default() }; - let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp(1))?; - (Some(layer_norm), Some(conv), 2) + let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp("0.1"))?; + (Some(layer_norm), Some(conv), 1) } else { (None, None, 0) }; let mut sub_blocks = Vec::with_capacity(BLOCKS[i]); let mut layer_i = start_layer_i; - for j in 0..BLOCKS[i] { + for _j in 0..BLOCKS[i] { let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 }; let res_block = ResBlockStageB::new(c_hidden, c_skip, 3, vb.pp(layer_i))?; layer_i += 1; let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?; layer_i += 1; - let attn_block = if j == 0 { + let attn_block = if i == 0 { None } else { let attn_block = @@ -190,7 +190,7 @@ impl WDiffNeXt { let mut up_blocks = 
Vec::with_capacity(C_HIDDEN.len()); for (i, &c_hidden) in C_HIDDEN.iter().enumerate().rev() { - let vb = vb.pp("up_blocks").pp(i); + let vb = vb.pp("up_blocks").pp(C_HIDDEN.len() - 1 - i); let mut sub_blocks = Vec::with_capacity(BLOCKS[i]); let mut layer_i = 0; for j in 0..BLOCKS[i] { @@ -204,7 +204,7 @@ impl WDiffNeXt { layer_i += 1; let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?; layer_i += 1; - let attn_block = if j == 0 { + let attn_block = if i == 0 { None } else { let attn_block = @@ -221,12 +221,17 @@ impl WDiffNeXt { } let (layer_norm, conv) = if i > 0 { let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?; - layer_i += 1; - let cfg = candle_nn::Conv2dConfig { + let cfg = candle_nn::ConvTranspose2dConfig { stride: 2, ..Default::default() }; - let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp(layer_i))?; + let conv = candle_nn::conv_transpose2d( + c_hidden, + C_HIDDEN[i - 1], + 2, + cfg, + vb.pp(layer_i).pp(1), + )?; (Some(layer_norm), Some(conv)) } else { (None, None) |