summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/wuerstchen
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-14 23:24:56 +0200
committerGitHub <noreply@github.com>2023-09-14 22:24:56 +0100
commit130fe5a087715fc4d7bf9b581ca7c11378736ac5 (patch)
treeeffd5d92b1dddace769b8e1944eab97a8364c84f /candle-transformers/src/models/wuerstchen
parent91ec546febee4c6333cd65d95e8fd09e94499024 (diff)
downloadcandle-130fe5a087715fc4d7bf9b581ca7c11378736ac5.tar.gz
candle-130fe5a087715fc4d7bf9b581ca7c11378736ac5.tar.bz2
candle-130fe5a087715fc4d7bf9b581ca7c11378736ac5.zip
Add the upblocks. (#853)
Diffstat (limited to 'candle-transformers/src/models/wuerstchen')
-rw-r--r--candle-transformers/src/models/wuerstchen/diffnext.rs53
-rw-r--r--candle-transformers/src/models/wuerstchen/prior.rs3
2 files changed, 52 insertions, 4 deletions
diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs
index 5e49437c..7289a54d 100644
--- a/candle-transformers/src/models/wuerstchen/diffnext.rs
+++ b/candle-transformers/src/models/wuerstchen/diffnext.rs
@@ -161,8 +161,57 @@ impl WDiffNeXt {
down_blocks.push(down_block)
}
- // TODO: populate.
- let up_blocks = Vec::with_capacity(C_HIDDEN.len());
+ let mut up_blocks = Vec::with_capacity(C_HIDDEN.len());
+ for (i, &c_hidden) in C_HIDDEN.iter().enumerate().rev() {
+ let vb = vb.pp("up_blocks").pp(i);
+ let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
+ let mut layer_i = 0;
+ for j in 0..BLOCKS[i] {
+ let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 };
+ let c_skip_res = if i < BLOCKS.len() - 1 && j == 0 {
+ c_hidden + c_skip
+ } else {
+ c_skip
+ };
+ let res_block = ResBlockStageB::new(c_hidden, c_skip_res, 3, vb.pp(layer_i))?;
+ layer_i += 1;
+ let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
+ layer_i += 1;
+ let attn_block = if j == 0 {
+ None
+ } else {
+ let attn_block =
+ AttnBlock::new(c_hidden, c_cond, NHEAD[i], true, vb.pp(layer_i))?;
+ layer_i += 1;
+ Some(attn_block)
+ };
+ let sub_block = SubBlock {
+ res_block,
+ ts_block,
+ attn_block,
+ };
+ sub_blocks.push(sub_block)
+ }
+ let (layer_norm, conv, start_layer_i) = if i > 0 {
+ let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1], vb.pp(layer_i))?;
+ layer_i += 1;
+ let cfg = candle_nn::Conv2dConfig {
+ stride: 2,
+ ..Default::default()
+ };
+ let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp(layer_i))?;
+ layer_i += 1;
+ (Some(layer_norm), Some(conv), 2)
+ } else {
+ (None, None, 0)
+ };
+ let up_block = UpBlock {
+ layer_norm,
+ conv,
+ sub_blocks,
+ };
+ up_blocks.push(up_block)
+ }
let clf_ln = WLayerNorm::new(C_HIDDEN[0], vb.pp("clf.0"))?;
let clf_conv = candle_nn::conv2d(
diff --git a/candle-transformers/src/models/wuerstchen/prior.rs b/candle-transformers/src/models/wuerstchen/prior.rs
index eea70a02..5dd03778 100644
--- a/candle-transformers/src/models/wuerstchen/prior.rs
+++ b/candle-transformers/src/models/wuerstchen/prior.rs
@@ -85,10 +85,9 @@ impl WPrior {
pub fn forward(&self, xs: &Tensor, r: &Tensor, c: &Tensor) -> Result<Tensor> {
let x_in = xs;
let mut xs = xs.apply(&self.projection)?;
- // TODO: leaky relu
let c_embed = c
.apply(&self.cond_mapper_lin1)?
- .relu()?
+ .apply(&|xs: &_| candle_nn::ops::leaky_relu(xs, 0.2))?
.apply(&self.cond_mapper_lin2)?;
let r_embed = self.gen_r_embedding(r)?;
for block in self.blocks.iter() {