diff options

| author | Evgeny Igumnov <igumnovnsk@gmail.com> | 2023-09-22 11:01:23 +0600 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-09-22 11:01:23 +0600 |
| commit | 4ac6039a42b8125f7888709fb718bfd41a73f2ac (patch) | |
| tree | f4bc165de51f258a9bf58cac4150c99e512fba01 /candle-transformers/src/models/wuerstchen/diffnext.rs | |
| parent | 52a60ca3ad3f7e7b6da8e915a5a052d5bef10999 (diff) | |
| parent | a96878f2357fbcebf9db8747dcbb55bc8200d8ab (diff) | |
| download | candle-4ac6039a42b8125f7888709fb718bfd41a73f2ac.tar.gz candle-4ac6039a42b8125f7888709fb718bfd41a73f2ac.tar.bz2 candle-4ac6039a42b8125f7888709fb718bfd41a73f2ac.zip | |
Merge branch 'main' into book-trainin-simplified
Diffstat (limited to 'candle-transformers/src/models/wuerstchen/diffnext.rs')
| -rw-r--r-- | candle-transformers/src/models/wuerstchen/diffnext.rs | 396 |
1 file changed, 396 insertions, 0 deletions
diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs
new file mode 100644
index 00000000..64a48c8a
--- /dev/null
+++ b/candle-transformers/src/models/wuerstchen/diffnext.rs
@@ -0,0 +1,396 @@
+use super::common::{AttnBlock, GlobalResponseNorm, LayerNormNoWeights, TimestepBlock, WLayerNorm};
+use candle::{DType, Module, Result, Tensor, D};
+use candle_nn::VarBuilder;
+
+#[derive(Debug)]
+pub struct ResBlockStageB {
+    depthwise: candle_nn::Conv2d,
+    norm: WLayerNorm,
+    channelwise_lin1: candle_nn::Linear,
+    channelwise_grn: GlobalResponseNorm,
+    channelwise_lin2: candle_nn::Linear,
+}
+
+impl ResBlockStageB {
+    pub fn new(c: usize, c_skip: usize, ksize: usize, vb: VarBuilder) -> Result<Self> {
+        let cfg = candle_nn::Conv2dConfig {
+            groups: c,
+            padding: ksize / 2,
+            ..Default::default()
+        };
+        let depthwise = candle_nn::conv2d(c, c, ksize, cfg, vb.pp("depthwise"))?;
+        let norm = WLayerNorm::new(c)?;
+        let channelwise_lin1 = candle_nn::linear(c + c_skip, c * 4, vb.pp("channelwise.0"))?;
+        let channelwise_grn = GlobalResponseNorm::new(4 * c, vb.pp("channelwise.2"))?;
+        let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?;
+        Ok(Self {
+            depthwise,
+            norm,
+            channelwise_lin1,
+            channelwise_grn,
+            channelwise_lin2,
+        })
+    }
+
+    pub fn forward(&self, xs: &Tensor, x_skip: Option<&Tensor>) -> Result<Tensor> {
+        let x_res = xs;
+        let xs = xs.apply(&self.depthwise)?.apply(&self.norm)?;
+        let xs = match x_skip {
+            None => xs.clone(),
+            Some(x_skip) => Tensor::cat(&[&xs, x_skip], 1)?,
+        };
+        let xs = xs
+            .permute((0, 2, 3, 1))?
+            .contiguous()?
+            .apply(&self.channelwise_lin1)?
+            .gelu()?
+            .apply(&self.channelwise_grn)?
+            .apply(&self.channelwise_lin2)?
+            .permute((0, 3, 1, 2))?;
+        xs + x_res
+    }
+}
+
+#[derive(Debug)]
+struct SubBlock {
+    res_block: ResBlockStageB,
+    ts_block: TimestepBlock,
+    attn_block: Option<AttnBlock>,
+}
+
+#[derive(Debug)]
+struct DownBlock {
+    layer_norm: Option<WLayerNorm>,
+    conv: Option<candle_nn::Conv2d>,
+    sub_blocks: Vec<SubBlock>,
+}
+
+#[derive(Debug)]
+struct UpBlock {
+    sub_blocks: Vec<SubBlock>,
+    layer_norm: Option<WLayerNorm>,
+    conv: Option<candle_nn::ConvTranspose2d>,
+}
+
+#[derive(Debug)]
+pub struct WDiffNeXt {
+    clip_mapper: candle_nn::Linear,
+    effnet_mappers: Vec<Option<candle_nn::Conv2d>>,
+    seq_norm: LayerNormNoWeights,
+    embedding_conv: candle_nn::Conv2d,
+    embedding_ln: WLayerNorm,
+    down_blocks: Vec<DownBlock>,
+    up_blocks: Vec<UpBlock>,
+    clf_ln: WLayerNorm,
+    clf_conv: candle_nn::Conv2d,
+    c_r: usize,
+    patch_size: usize,
+}
+
+impl WDiffNeXt {
+    #[allow(clippy::too_many_arguments)]
+    pub fn new(
+        c_in: usize,
+        c_out: usize,
+        c_r: usize,
+        c_cond: usize,
+        clip_embd: usize,
+        patch_size: usize,
+        use_flash_attn: bool,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        const C_HIDDEN: [usize; 4] = [320, 640, 1280, 1280];
+        const BLOCKS: [usize; 4] = [4, 4, 14, 4];
+        const NHEAD: [usize; 4] = [1, 10, 20, 20];
+        const INJECT_EFFNET: [bool; 4] = [false, true, true, true];
+        const EFFNET_EMBD: usize = 16;
+
+        let clip_mapper = candle_nn::linear(clip_embd, c_cond, vb.pp("clip_mapper"))?;
+        let mut effnet_mappers = Vec::with_capacity(2 * INJECT_EFFNET.len());
+        let vb_e = vb.pp("effnet_mappers");
+        for (i, &inject) in INJECT_EFFNET.iter().enumerate() {
+            let c = if inject {
+                Some(candle_nn::conv2d(
+                    EFFNET_EMBD,
+                    c_cond,
+                    1,
+                    Default::default(),
+                    vb_e.pp(i),
+                )?)
+            } else {
+                None
+            };
+            effnet_mappers.push(c)
+        }
+        for (i, &inject) in INJECT_EFFNET.iter().rev().enumerate() {
+            let c = if inject {
+                Some(candle_nn::conv2d(
+                    EFFNET_EMBD,
+                    c_cond,
+                    1,
+                    Default::default(),
+                    vb_e.pp(i + INJECT_EFFNET.len()),
+                )?)
+            } else {
+                None
+            };
+            effnet_mappers.push(c)
+        }
+        let seq_norm = LayerNormNoWeights::new(c_cond)?;
+        let embedding_ln = WLayerNorm::new(C_HIDDEN[0])?;
+        let embedding_conv = candle_nn::conv2d(
+            c_in * patch_size * patch_size,
+            C_HIDDEN[0],
+            1,
+            Default::default(),
+            vb.pp("embedding.1"),
+        )?;
+
+        let mut down_blocks = Vec::with_capacity(C_HIDDEN.len());
+        for (i, &c_hidden) in C_HIDDEN.iter().enumerate() {
+            let vb = vb.pp("down_blocks").pp(i);
+            let (layer_norm, conv, start_layer_i) = if i > 0 {
+                let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
+                let cfg = candle_nn::Conv2dConfig {
+                    stride: 2,
+                    ..Default::default()
+                };
+                let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp("0.1"))?;
+                (Some(layer_norm), Some(conv), 1)
+            } else {
+                (None, None, 0)
+            };
+            let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
+            let mut layer_i = start_layer_i;
+            for _j in 0..BLOCKS[i] {
+                let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 };
+                let res_block = ResBlockStageB::new(c_hidden, c_skip, 3, vb.pp(layer_i))?;
+                layer_i += 1;
+                let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
+                layer_i += 1;
+                let attn_block = if i == 0 {
+                    None
+                } else {
+                    let attn_block = AttnBlock::new(
+                        c_hidden,
+                        c_cond,
+                        NHEAD[i],
+                        true,
+                        use_flash_attn,
+                        vb.pp(layer_i),
+                    )?;
+                    layer_i += 1;
+                    Some(attn_block)
+                };
+                let sub_block = SubBlock {
+                    res_block,
+                    ts_block,
+                    attn_block,
+                };
+                sub_blocks.push(sub_block)
+            }
+            let down_block = DownBlock {
+                layer_norm,
+                conv,
+                sub_blocks,
+            };
+            down_blocks.push(down_block)
+        }
+
+        let mut up_blocks = Vec::with_capacity(C_HIDDEN.len());
+        for (i, &c_hidden) in C_HIDDEN.iter().enumerate().rev() {
+            let vb = vb.pp("up_blocks").pp(C_HIDDEN.len() - 1 - i);
+            let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
+            let mut layer_i = 0;
+            for j in 0..BLOCKS[i] {
+                let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 };
+                let c_skip_res = if i < BLOCKS.len() - 1 && j == 0 {
+                    c_hidden + c_skip
+                } else {
+                    c_skip
+                };
+                let res_block = ResBlockStageB::new(c_hidden, c_skip_res, 3, vb.pp(layer_i))?;
+                layer_i += 1;
+                let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
+                layer_i += 1;
+                let attn_block = if i == 0 {
+                    None
+                } else {
+                    let attn_block = AttnBlock::new(
+                        c_hidden,
+                        c_cond,
+                        NHEAD[i],
+                        true,
+                        use_flash_attn,
+                        vb.pp(layer_i),
+                    )?;
+                    layer_i += 1;
+                    Some(attn_block)
+                };
+                let sub_block = SubBlock {
+                    res_block,
+                    ts_block,
+                    attn_block,
+                };
+                sub_blocks.push(sub_block)
+            }
+            let (layer_norm, conv) = if i > 0 {
+                let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
+                let cfg = candle_nn::ConvTranspose2dConfig {
+                    stride: 2,
+                    ..Default::default()
+                };
+                let conv = candle_nn::conv_transpose2d(
+                    c_hidden,
+                    C_HIDDEN[i - 1],
+                    2,
+                    cfg,
+                    vb.pp(layer_i).pp(1),
+                )?;
+                (Some(layer_norm), Some(conv))
+            } else {
+                (None, None)
+            };
+            let up_block = UpBlock {
+                layer_norm,
+                conv,
+                sub_blocks,
+            };
+            up_blocks.push(up_block)
+        }
+
+        let clf_ln = WLayerNorm::new(C_HIDDEN[0])?;
+        let clf_conv = candle_nn::conv2d(
+            C_HIDDEN[0],
+            2 * c_out * patch_size * patch_size,
+            1,
+            Default::default(),
+            vb.pp("clf.1"),
+        )?;
+        Ok(Self {
+            clip_mapper,
+            effnet_mappers,
+            seq_norm,
+            embedding_conv,
+            embedding_ln,
+            down_blocks,
+            up_blocks,
+            clf_ln,
+            clf_conv,
+            c_r,
+            patch_size,
+        })
+    }
+
+    fn gen_r_embedding(&self, r: &Tensor) -> Result<Tensor> {
+        const MAX_POSITIONS: usize = 10000;
+        let r = (r * MAX_POSITIONS as f64)?;
+        let half_dim = self.c_r / 2;
+        let emb = (MAX_POSITIONS as f64).ln() / (half_dim - 1) as f64;
+        let emb = (Tensor::arange(0u32, half_dim as u32, r.device())?.to_dtype(DType::F32)?
+            * -emb)?
+            .exp()?;
+        let emb = r.unsqueeze(1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
+        let emb = Tensor::cat(&[emb.sin()?, emb.cos()?], 1)?;
+        let emb = if self.c_r % 2 == 1 {
+            emb.pad_with_zeros(D::Minus1, 0, 1)?
+        } else {
+            emb
+        };
+        emb.to_dtype(r.dtype())
+    }
+
+    fn gen_c_embeddings(&self, clip: &Tensor) -> Result<Tensor> {
+        clip.apply(&self.clip_mapper)?.apply(&self.seq_norm)
+    }
+
+    pub fn forward(
+        &self,
+        xs: &Tensor,
+        r: &Tensor,
+        effnet: &Tensor,
+        clip: Option<&Tensor>,
+    ) -> Result<Tensor> {
+        const EPS: f64 = 1e-3;
+
+        let r_embed = self.gen_r_embedding(r)?;
+        let clip = match clip {
+            None => None,
+            Some(clip) => Some(self.gen_c_embeddings(clip)?),
+        };
+        let x_in = xs;
+
+        let mut xs = xs
+            .apply(&|xs: &_| candle_nn::ops::pixel_unshuffle(xs, self.patch_size))?
+            .apply(&self.embedding_conv)?
+            .apply(&self.embedding_ln)?;
+
+        let mut level_outputs = Vec::new();
+        for (i, down_block) in self.down_blocks.iter().enumerate() {
+            if let Some(ln) = &down_block.layer_norm {
+                xs = xs.apply(ln)?
+            }
+            if let Some(conv) = &down_block.conv {
+                xs = xs.apply(conv)?
+            }
+            let skip = match &self.effnet_mappers[i] {
+                None => None,
+                Some(m) => {
+                    let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?;
+                    Some(m.forward(&effnet)?)
+                }
+            };
+            for block in down_block.sub_blocks.iter() {
+                xs = block.res_block.forward(&xs, skip.as_ref())?;
+                xs = block.ts_block.forward(&xs, &r_embed)?;
+                if let Some(attn_block) = &block.attn_block {
+                    xs = attn_block.forward(&xs, clip.as_ref().unwrap())?;
+                }
+            }
+            level_outputs.push(xs.clone())
+        }
+        level_outputs.reverse();
+        let mut xs = level_outputs[0].clone();
+
+        for (i, up_block) in self.up_blocks.iter().enumerate() {
+            let effnet_c = match &self.effnet_mappers[self.down_blocks.len() + i] {
+                None => None,
+                Some(m) => {
+                    let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?;
+                    Some(m.forward(&effnet)?)
+                }
+            };
+            for (j, block) in up_block.sub_blocks.iter().enumerate() {
+                let skip = if j == 0 && i > 0 {
+                    Some(&level_outputs[i])
+                } else {
+                    None
+                };
+                let skip = match (skip, effnet_c.as_ref()) {
+                    (Some(skip), Some(effnet_c)) => Some(Tensor::cat(&[skip, effnet_c], 1)?),
+                    (None, Some(skip)) | (Some(skip), None) => Some(skip.clone()),
+                    (None, None) => None,
+                };
+                xs = block.res_block.forward(&xs, skip.as_ref())?;
+                xs = block.ts_block.forward(&xs, &r_embed)?;
+                if let Some(attn_block) = &block.attn_block {
+                    xs = attn_block.forward(&xs, clip.as_ref().unwrap())?;
+                }
+            }
+            if let Some(ln) = &up_block.layer_norm {
+                xs = xs.apply(ln)?
+            }
+            if let Some(conv) = &up_block.conv {
+                xs = xs.apply(conv)?
+            }
+        }
+
+        let ab = xs
+            .apply(&self.clf_ln)?
+            .apply(&self.clf_conv)?
+            .apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, self.patch_size))?
+            .chunk(2, 1)?;
+        let b = ((candle_nn::ops::sigmoid(&ab[1])? * (1. - EPS * 2.))? + EPS)?;
+        (x_in - &ab[0])? / b
+    }
+}
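For orientation, here is a minimal calling sketch of the new module. Everything in it is an illustrative assumption rather than something taken from this commit: the hyper-parameter values, the tensor shapes, and the `main` harness are made up, `VarBuilder::zeros` stands in for real weights, and the import path assumes `diffnext` is exported from `candle_transformers::models::wuerstchen`. One real constraint worth noting: `forward` unwraps the CLIP embedding inside every attention block, so `clip` must be `Some(..)` in practice.

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::wuerstchen::diffnext::WDiffNeXt;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Zero-initialized weights: enough to exercise shapes, not to denoise.
    let vb = VarBuilder::zeros(DType::F32, &device);
    // Assumed hyper-parameters, for illustration only.
    let model = WDiffNeXt::new(
        4,     // c_in
        4,     // c_out
        64,    // c_r: width of the sinusoidal timestep embedding
        1024,  // c_cond
        1024,  // clip_embd
        2,     // patch_size
        false, // use_flash_attn
        vb,
    )?;
    let xs = Tensor::zeros((1, 4, 64, 64), DType::F32, &device)?; // noisy latent
    let r = Tensor::zeros(1, DType::F32, &device)?; // timestep ratio per batch element
    // 16 channels matches the EFFNET_EMBD constant; the spatial size is free,
    // since forward interpolates it to each level's resolution.
    let effnet = Tensor::zeros((1, 16, 12, 12), DType::F32, &device)?;
    let clip = Tensor::zeros((1, 77, 1024), DType::F32, &device)?; // (batch, seq, clip_embd)
    let denoised = model.forward(&xs, &r, &effnet, Some(&clip))?;
    assert_eq!(denoised.dims(), xs.dims()); // output matches the input latent shape
    Ok(())
}
```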
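The `embedding` stage relies on space-to-depth patching: `pixel_unshuffle` folds each `patch_size × patch_size` spatial block into channels, which is why `embedding_conv` takes `c_in * patch_size * patch_size` input channels, and symmetrically `clf_conv` emits `2 * c_out * patch_size * patch_size` channels before `pixel_shuffle` restores the resolution. A small shape check, with example dimensions assumed:

```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // (batch, c_in, h, w) with an assumed c_in = 4 and patch_size = 2.
    let xs = Tensor::zeros((1, 4, 64, 64), DType::F32, &device)?;
    let patched = candle_nn::ops::pixel_unshuffle(&xs, 2)?;
    // Channels grow by patch_size^2, spatial dims shrink by patch_size.
    assert_eq!(patched.dims(), &[1, 16, 32, 32]);
    // pixel_shuffle inverts the transform.
    let restored = candle_nn::ops::pixel_shuffle(&patched, 2)?;
    assert_eq!(restored.dims(), &[1, 4, 64, 64]);
    Ok(())
}
```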
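Finally, the classifier head predicts two maps `a` and `b` and returns `(x_in - a) / b`, with `b` squashed into `[EPS, 1 - EPS]` so the division can never blow up. A toy check of that squashing, independent of the model:

```rust
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    const EPS: f64 = 1e-3;
    let device = Device::Cpu;
    // Even extreme logits land strictly inside (0, 1): b ≈ EPS and b ≈ 1 - EPS.
    let logits = Tensor::new(&[-50f32, 0., 50.], &device)?;
    let b = ((candle_nn::ops::sigmoid(&logits)? * (1. - EPS * 2.))? + EPS)?;
    println!("{:?}", b.to_vec1::<f32>()?); // ~[0.001, 0.5, 0.999]
    Ok(())
}
```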