summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/wuerstchen/diffnext.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-transformers/src/models/wuerstchen/diffnext.rs')
-rw-r--r--candle-transformers/src/models/wuerstchen/diffnext.rs19
1 files changed, 11 insertions, 8 deletions
diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs
index 74e1836c..001b35d7 100644
--- a/candle-transformers/src/models/wuerstchen/diffnext.rs
+++ b/candle-transformers/src/models/wuerstchen/diffnext.rs
@@ -1,5 +1,4 @@
-#![allow(unused)]
-use super::common::{AttnBlock, GlobalResponseNorm, ResBlock, TimestepBlock, WLayerNorm};
+use super::common::{AttnBlock, GlobalResponseNorm, TimestepBlock, WLayerNorm};
use candle::{DType, Module, Result, Tensor, D};
use candle_nn::VarBuilder;
@@ -223,7 +222,7 @@ impl WDiffNeXt {
};
sub_blocks.push(sub_block)
}
- let (layer_norm, conv, start_layer_i) = if i > 0 {
+ let (layer_norm, conv) = if i > 0 {
let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1], vb.pp(layer_i))?;
layer_i += 1;
let cfg = candle_nn::Conv2dConfig {
@@ -231,10 +230,9 @@ impl WDiffNeXt {
..Default::default()
};
let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp(layer_i))?;
- layer_i += 1;
- (Some(layer_norm), Some(conv), 2)
+ (Some(layer_norm), Some(conv))
} else {
- (None, None, 0)
+ (None, None)
};
let up_block = UpBlock {
layer_norm,
@@ -337,7 +335,7 @@ impl WDiffNeXt {
level_outputs.reverse();
for (i, up_block) in self.up_blocks.iter().enumerate() {
- let skip = match &self.effnet_mappers[self.down_blocks.len() + i] {
+ let effnet_c = match &self.effnet_mappers[self.down_blocks.len() + i] {
None => None,
Some(m) => {
let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?;
@@ -350,7 +348,12 @@ impl WDiffNeXt {
} else {
None
};
- xs = block.res_block.forward(&xs, skip)?;
+ let skip = match (skip, effnet_c.as_ref()) {
+ (Some(skip), Some(effnet_c)) => Some(Tensor::cat(&[skip, effnet_c], 1)?),
+ (None, Some(skip)) | (Some(skip), None) => Some(skip.clone()),
+ (None, None) => None,
+ };
+ xs = block.res_block.forward(&xs, skip.as_ref())?;
xs = block.ts_block.forward(&xs, &r_embed)?;
if let Some(attn_block) = &block.attn_block {
xs = attn_block.forward(&xs, clip.as_ref().unwrap())?;