summary | refs | log | tree | commit | diff
path: root/candle-transformers/src/models/wuerstchen
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2023-09-17 22:08:11 +0100
committer: GitHub <noreply@github.com> 2023-09-17 22:08:11 +0100
commit: c2b866172abaf1d4b8d75273c4f4e28a16d872f0 (patch)
tree: e3bb84dc8a8cc2119e4268a30842c954c6cb6449 /candle-transformers/src/models/wuerstchen
parent: 06cc329e715cbb820343f9849a4a45c818cb8c5e (diff)
download: candle-c2b866172abaf1d4b8d75273c4f4e28a16d872f0.tar.gz
candle-c2b866172abaf1d4b8d75273c4f4e28a16d872f0.tar.bz2
candle-c2b866172abaf1d4b8d75273c4f4e28a16d872f0.zip
More Wuerstchen fixes. (#882)
* More Wuerstchen fixes.
* More shape fixes.
* Add more of the prior specific bits.
* Broadcast add.
* Fix the clip config.
* Add some masking options to the clip model.
Diffstat (limited to 'candle-transformers/src/models/wuerstchen')
-rw-r--r-- candle-transformers/src/models/wuerstchen/common.rs | 6
-rw-r--r-- candle-transformers/src/models/wuerstchen/diffnext.rs | 25
2 files changed, 18 insertions, 13 deletions
diff --git a/candle-transformers/src/models/wuerstchen/common.rs b/candle-transformers/src/models/wuerstchen/common.rs
index ee318d27..5337fdc6 100644
--- a/candle-transformers/src/models/wuerstchen/common.rs
+++ b/candle-transformers/src/models/wuerstchen/common.rs
@@ -75,9 +75,9 @@ impl Module for GlobalResponseNorm {
let agg_norm = xs.sqr()?.sum_keepdim((1, 2))?;
let stand_div_norm =
agg_norm.broadcast_div(&(agg_norm.mean_keepdim(D::Minus1)? + 1e-6)?)?;
- (xs.broadcast_mul(&stand_div_norm)?
- .broadcast_mul(&self.gamma)
- + &self.beta)?
+ xs.broadcast_mul(&stand_div_norm)?
+ .broadcast_mul(&self.gamma)?
+ .broadcast_add(&self.beta)?
+ xs
}
}
diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs
index 70e4ba34..664251ed 100644
--- a/candle-transformers/src/models/wuerstchen/diffnext.rs
+++ b/candle-transformers/src/models/wuerstchen/diffnext.rs
@@ -68,7 +68,7 @@ struct DownBlock {
struct UpBlock {
sub_blocks: Vec<SubBlock>,
layer_norm: Option<WLayerNorm>,
- conv: Option<candle_nn::Conv2d>,
+ conv: Option<candle_nn::ConvTranspose2d>,
}
#[derive(Debug)]
@@ -152,20 +152,20 @@ impl WDiffNeXt {
stride: 2,
..Default::default()
};
- let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp(1))?;
- (Some(layer_norm), Some(conv), 2)
+ let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp("0.1"))?;
+ (Some(layer_norm), Some(conv), 1)
} else {
(None, None, 0)
};
let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
let mut layer_i = start_layer_i;
- for j in 0..BLOCKS[i] {
+ for _j in 0..BLOCKS[i] {
let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 };
let res_block = ResBlockStageB::new(c_hidden, c_skip, 3, vb.pp(layer_i))?;
layer_i += 1;
let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
layer_i += 1;
- let attn_block = if j == 0 {
+ let attn_block = if i == 0 {
None
} else {
let attn_block =
@@ -190,7 +190,7 @@ impl WDiffNeXt {
let mut up_blocks = Vec::with_capacity(C_HIDDEN.len());
for (i, &c_hidden) in C_HIDDEN.iter().enumerate().rev() {
- let vb = vb.pp("up_blocks").pp(i);
+ let vb = vb.pp("up_blocks").pp(C_HIDDEN.len() - 1 - i);
let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
let mut layer_i = 0;
for j in 0..BLOCKS[i] {
@@ -204,7 +204,7 @@ impl WDiffNeXt {
layer_i += 1;
let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
layer_i += 1;
- let attn_block = if j == 0 {
+ let attn_block = if i == 0 {
None
} else {
let attn_block =
@@ -221,12 +221,17 @@ impl WDiffNeXt {
}
let (layer_norm, conv) = if i > 0 {
let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
- layer_i += 1;
- let cfg = candle_nn::Conv2dConfig {
+ let cfg = candle_nn::ConvTranspose2dConfig {
stride: 2,
..Default::default()
};
- let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp(layer_i))?;
+ let conv = candle_nn::conv_transpose2d(
+ c_hidden,
+ C_HIDDEN[i - 1],
+ 2,
+ cfg,
+ vb.pp(layer_i).pp(1),
+ )?;
(Some(layer_norm), Some(conv))
} else {
(None, None)