author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-14 22:16:31 +0200
committer | GitHub <noreply@github.com> | 2023-09-14 22:16:31 +0200
commit | 91ec546febee4c6333cd65d95e8fd09e94499024
tree | d0f9422bb447018412744243076c0ade935fcb22 /candle-transformers/src/models/wuerstchen
parent | 0a647875ec7f9861ede7fa54713af50b4207ffb7
More DiffNeXt. (#847)
* More DiffNeXt.
* Down blocks.
Diffstat (limited to 'candle-transformers/src/models/wuerstchen')
-rw-r--r-- | candle-transformers/src/models/wuerstchen/diffnext.rs | 146
1 file changed, 144 insertions(+), 2 deletions(-)
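In the constructor added in the diff below, each down block lays out its weights by integer index under `down_blocks.{i}`: stages after the first spend indices 0 and 1 on a `WLayerNorm` and a stride-2 downsampling conv, then every sub-block consumes two or three consecutive indices (res block, timestep block, and an attention block for the stages with a nonzero head count). As a reading aid, here is a hypothetical helper (names assumed, not part of the commit) that enumerates the resulting `VarBuilder` prefixes for one stage:

```rust
// Hypothetical reading aid (not part of the commit): enumerate the
// VarBuilder prefixes that WDiffNeXt::new visits for one down-block stage,
// mirroring the `layer_i` bookkeeping in the diff below.
fn down_block_prefixes(stage: usize, blocks: usize, has_attn: bool) -> Vec<String> {
    let mut paths = Vec::new();
    // Stages after the first prepend a layer norm (.0) and a stride-2 conv (.1).
    let mut layer_i = if stage > 0 {
        paths.push(format!("down_blocks.{stage}.0"));
        paths.push(format!("down_blocks.{stage}.1"));
        2
    } else {
        0
    };
    for _ in 0..blocks {
        paths.push(format!("down_blocks.{stage}.{layer_i}")); // ResBlockStageB
        layer_i += 1;
        paths.push(format!("down_blocks.{stage}.{layer_i}")); // TimestepBlock
        layer_i += 1;
        if has_attn {
            paths.push(format!("down_blocks.{stage}.{layer_i}")); // AttnBlock
            layer_i += 1;
        }
    }
    paths
}

fn main() {
    // Stage 1: 4 sub-blocks with attention -> down_blocks.1.0 ..= down_blocks.1.13.
    for path in down_block_prefixes(1, 4, true) {
        println!("{path}");
    }
}
```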
```diff
diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs
index 8e5099f6..5e49437c 100644
--- a/candle-transformers/src/models/wuerstchen/diffnext.rs
+++ b/candle-transformers/src/models/wuerstchen/diffnext.rs
@@ -1,5 +1,5 @@
 #![allow(unused)]
-use super::common::{GlobalResponseNorm, ResBlock, TimestepBlock, WLayerNorm};
+use super::common::{AttnBlock, GlobalResponseNorm, ResBlock, TimestepBlock, WLayerNorm};
 use candle::{DType, Module, Result, Tensor, D};
 use candle_nn::VarBuilder;
 
@@ -52,24 +52,54 @@ impl ResBlockStageB {
 }
 
 #[derive(Debug)]
+struct SubBlock {
+    res_block: ResBlockStageB,
+    ts_block: TimestepBlock,
+    attn_block: Option<AttnBlock>,
+}
+
+#[derive(Debug)]
+struct DownBlock {
+    layer_norm: Option<WLayerNorm>,
+    conv: Option<candle_nn::Conv2d>,
+    sub_blocks: Vec<SubBlock>,
+}
+
+#[derive(Debug)]
+struct UpBlock {
+    sub_blocks: Vec<SubBlock>,
+    layer_norm: Option<WLayerNorm>,
+    conv: Option<candle_nn::Conv2d>,
+}
+
+#[derive(Debug)]
 pub struct WDiffNeXt {
     clip_mapper: candle_nn::Linear,
     effnet_mappers: Vec<candle_nn::Conv2d>,
     seq_norm: candle_nn::LayerNorm,
     embedding_conv: candle_nn::Conv2d,
     embedding_ln: WLayerNorm,
+    down_blocks: Vec<DownBlock>,
+    up_blocks: Vec<UpBlock>,
+    clf_ln: WLayerNorm,
+    clf_conv: candle_nn::Conv2d,
+    c_r: usize,
 }
 
 impl WDiffNeXt {
     pub fn new(
         c_in: usize,
         c_out: usize,
-        vb: VarBuilder,
+        c_r: usize,
         c_cond: usize,
         clip_embd: usize,
         patch_size: usize,
+        vb: VarBuilder,
     ) -> Result<Self> {
         const C_HIDDEN: [usize; 4] = [320, 640, 1280, 1280];
+        const BLOCKS: [usize; 4] = [4, 4, 14, 4];
+        const NHEAD: [usize; 4] = [0, 10, 20, 20];
+        const INJECT_EFFNET: [bool; 4] = [false, true, true, true];
 
         let clip_mapper = candle_nn::linear(clip_embd, c_cond, vb.pp("clip_mapper"))?;
         let effnet_mappers = vec![];
@@ -85,12 +115,124 @@ impl WDiffNeXt {
             Default::default(),
             vb.pp("embedding.2"),
         )?;
+
+        let mut down_blocks = Vec::with_capacity(C_HIDDEN.len());
+        for (i, &c_hidden) in C_HIDDEN.iter().enumerate() {
+            let vb = vb.pp("down_blocks").pp(i);
+            let (layer_norm, conv, start_layer_i) = if i > 0 {
+                let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1], vb.pp(0))?;
+                let cfg = candle_nn::Conv2dConfig {
+                    stride: 2,
+                    ..Default::default()
+                };
+                let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp(1))?;
+                (Some(layer_norm), Some(conv), 2)
+            } else {
+                (None, None, 0)
+            };
+            let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
+            let mut layer_i = start_layer_i;
+            for j in 0..BLOCKS[i] {
+                let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 };
+                let res_block = ResBlockStageB::new(c_hidden, c_skip, 3, vb.pp(layer_i))?;
+                layer_i += 1;
+                let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
+                layer_i += 1;
+                let attn_block = if i == 0 {
+                    None
+                } else {
+                    let attn_block =
+                        AttnBlock::new(c_hidden, c_cond, NHEAD[i], true, vb.pp(layer_i))?;
+                    layer_i += 1;
+                    Some(attn_block)
+                };
+                let sub_block = SubBlock {
+                    res_block,
+                    ts_block,
+                    attn_block,
+                };
+                sub_blocks.push(sub_block)
+            }
+            let down_block = DownBlock {
+                layer_norm,
+                conv,
+                sub_blocks,
+            };
+            down_blocks.push(down_block)
+        }
+
+        // TODO: populate.
+        let up_blocks = Vec::with_capacity(C_HIDDEN.len());
+
+        let clf_ln = WLayerNorm::new(C_HIDDEN[0], vb.pp("clf.0"))?;
+        let clf_conv = candle_nn::conv2d(
+            C_HIDDEN[0],
+            2 * c_out * patch_size * patch_size,
+            1,
+            Default::default(),
+            vb.pp("clf.1"),
+        )?;
         Ok(Self {
             clip_mapper,
             effnet_mappers,
             seq_norm,
             embedding_conv,
             embedding_ln,
+            down_blocks,
+            up_blocks,
+            clf_ln,
+            clf_conv,
+            c_r,
         })
     }
+
+    fn gen_r_embedding(&self, r: &Tensor) -> Result<Tensor> {
+        const MAX_POSITIONS: usize = 10000;
+        let r = (r * MAX_POSITIONS as f64)?;
+        let half_dim = self.c_r / 2;
+        let emb = (MAX_POSITIONS as f64).ln() / (half_dim - 1) as f64;
+        let emb = (Tensor::arange(0u32, half_dim as u32, r.device())?.to_dtype(DType::F32)?
+            * -emb)?
+        .exp()?;
+        let emb = r.unsqueeze(1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
+        let emb = Tensor::cat(&[emb.sin()?, emb.cos()?], 1)?;
+        let emb = if self.c_r % 2 == 1 {
+            emb.pad_with_zeros(D::Minus1, 0, 1)?
+        } else {
+            emb
+        };
+        emb.to_dtype(r.dtype())
+    }
+
+    fn gen_c_embeddings(&self, clip: &Tensor) -> Result<Tensor> {
+        clip.apply(&self.clip_mapper)?.apply(&self.seq_norm)
+    }
+
+    pub fn forward(
+        &self,
+        xs: &Tensor,
+        r: &Tensor,
+        effnet: &Tensor,
+        clip: Option<&Tensor>,
+    ) -> Result<Tensor> {
+        const EPS: f64 = 1e-3;
+
+        let r_embed = self.gen_r_embedding(r)?;
+        let clip = match clip {
+            None => None,
+            Some(clip) => Some(self.gen_c_embeddings(clip)?),
+        };
+        let x_in = xs;
+
+        // TODO: pixel unshuffle.
+        let xs = xs.apply(&self.embedding_conv)?.apply(&self.embedding_ln)?;
+        // TODO: down blocks.
+        let level_outputs = xs.clone();
+        // TODO: up blocks.
+        let xs = level_outputs;
+        // TODO: pixel shuffle.
+        let ab = xs.apply(&self.clf_ln)?.apply(&self.clf_conv)?.chunk(2, 1)?;
+        let b = ((candle_nn::ops::sigmoid(&ab[1])? * (1. - EPS * 2.))? + EPS)?;
+        (x_in - &ab[0])? / b
+    }
 }
```
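For context, a minimal usage sketch of the reworked constructor, whose parameter order is now `(c_in, c_out, c_r, c_cond, clip_embd, patch_size, vb)`. Everything below is an illustrative assumption rather than code from this commit: the module path, the zero-filled `VarBuilder` standing in for a real checkpoint, and all sizes are invented for a shape-level smoke test, with `patch_size = 1` chosen because `forward` does not yet apply the pixel unshuffle that larger patch sizes require.

```rust
// Sketch only: assumed module path and sizes, zero weights instead of a
// real checkpoint. At this commit, forward still skips the down/up blocks.
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::wuerstchen::diffnext::WDiffNeXt;

fn main() -> candle::Result<()> {
    let device = Device::Cpu;
    let vb = VarBuilder::zeros(DType::F32, &device);
    // Assumed: c_in = 4, c_out = 4, c_r = 64, c_cond = 1280,
    // clip_embd = 1024, patch_size = 1.
    let model = WDiffNeXt::new(4, 4, 64, 1280, 1024, 1, vb)?;

    let xs = Tensor::zeros((1, 4, 64, 64), DType::F32, &device)?; // latents
    let r = Tensor::zeros((1,), DType::F32, &device)?; // timestep ratio in [0, 1]
    let effnet = Tensor::zeros((1, 1280, 16, 16), DType::F32, &device)?; // unused so far
    let out = model.forward(&xs, &r, &effnet, None)?; // no CLIP conditioning
    println!("{:?}", out.shape()); // same shape as xs: (1, 4, 64, 64)
    Ok(())
}
```

`gen_r_embedding` is the standard sinusoidal timestep embedding: the ratio `r` is scaled by 10000 and multiplied by `c_r / 2` geometrically spaced frequencies, the sines and cosines are concatenated along the channel dimension, and a single zero is padded on when `c_r` is odd.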