diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-14 23:24:56 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-14 22:24:56 +0100 |
commit | 130fe5a087715fc4d7bf9b581ca7c11378736ac5 (patch) | |
tree | effd5d92b1dddace769b8e1944eab97a8364c84f /candle-transformers/src/models/wuerstchen | |
parent | 91ec546febee4c6333cd65d95e8fd09e94499024 (diff) | |
download | candle-130fe5a087715fc4d7bf9b581ca7c11378736ac5.tar.gz candle-130fe5a087715fc4d7bf9b581ca7c11378736ac5.tar.bz2 candle-130fe5a087715fc4d7bf9b581ca7c11378736ac5.zip |
Add the upblocks. (#853)
Diffstat (limited to 'candle-transformers/src/models/wuerstchen')
-rw-r--r-- | candle-transformers/src/models/wuerstchen/diffnext.rs | 53 | ||||
-rw-r--r-- | candle-transformers/src/models/wuerstchen/prior.rs | 3 |
2 files changed, 52 insertions, 4 deletions
diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs index 5e49437c..7289a54d 100644 --- a/candle-transformers/src/models/wuerstchen/diffnext.rs +++ b/candle-transformers/src/models/wuerstchen/diffnext.rs @@ -161,8 +161,57 @@ impl WDiffNeXt { down_blocks.push(down_block) } - // TODO: populate. - let up_blocks = Vec::with_capacity(C_HIDDEN.len()); + let mut up_blocks = Vec::with_capacity(C_HIDDEN.len()); + for (i, &c_hidden) in C_HIDDEN.iter().enumerate().rev() { + let vb = vb.pp("up_blocks").pp(i); + let mut sub_blocks = Vec::with_capacity(BLOCKS[i]); + let mut layer_i = 0; + for j in 0..BLOCKS[i] { + let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 }; + let c_skip_res = if i < BLOCKS.len() - 1 && j == 0 { + c_hidden + c_skip + } else { + c_skip + }; + let res_block = ResBlockStageB::new(c_hidden, c_skip_res, 3, vb.pp(layer_i))?; + layer_i += 1; + let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?; + layer_i += 1; + let attn_block = if j == 0 { + None + } else { + let attn_block = + AttnBlock::new(c_hidden, c_cond, NHEAD[i], true, vb.pp(layer_i))?; + layer_i += 1; + Some(attn_block) + }; + let sub_block = SubBlock { + res_block, + ts_block, + attn_block, + }; + sub_blocks.push(sub_block) + } + let (layer_norm, conv, start_layer_i) = if i > 0 { + let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1], vb.pp(layer_i))?; + layer_i += 1; + let cfg = candle_nn::Conv2dConfig { + stride: 2, + ..Default::default() + }; + let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp(layer_i))?; + layer_i += 1; + (Some(layer_norm), Some(conv), 2) + } else { + (None, None, 0) + }; + let up_block = UpBlock { + layer_norm, + conv, + sub_blocks, + }; + up_blocks.push(up_block) + } let clf_ln = WLayerNorm::new(C_HIDDEN[0], vb.pp("clf.0"))?; let clf_conv = candle_nn::conv2d( diff --git a/candle-transformers/src/models/wuerstchen/prior.rs b/candle-transformers/src/models/wuerstchen/prior.rs index eea70a02..5dd03778 100644 --- a/candle-transformers/src/models/wuerstchen/prior.rs +++ b/candle-transformers/src/models/wuerstchen/prior.rs @@ -85,10 +85,9 @@ impl WPrior { pub fn forward(&self, xs: &Tensor, r: &Tensor, c: &Tensor) -> Result<Tensor> { let x_in = xs; let mut xs = xs.apply(&self.projection)?; - // TODO: leaky relu let c_embed = c .apply(&self.cond_mapper_lin1)? - .relu()? + .apply(&|xs: &_| candle_nn::ops::leaky_relu(xs, 0.2))? .apply(&self.cond_mapper_lin2)?; let r_embed = self.gen_r_embedding(r)?; for block in self.blocks.iter() { |