use super::common::{AttnBlock, GlobalResponseNorm, LayerNormNoWeights, TimestepBlock, WLayerNorm};
use candle::{DType, Module, Result, Tensor, D};
use candle_nn::VarBuilder;

/// Residual block used on every resolution level: a depthwise conv + layer norm feeding a
/// per-pixel channelwise MLP (linear -> GELU -> global response norm -> linear).
#[derive(Debug)]
pub struct ResBlockStageB {
    depthwise: candle_nn::Conv2d,
    norm: WLayerNorm,
    channelwise_lin1: candle_nn::Linear,
    channelwise_grn: GlobalResponseNorm,
    channelwise_lin2: candle_nn::Linear,
}

impl ResBlockStageB {
    pub fn new(c: usize, c_skip: usize, ksize: usize, vb: VarBuilder) -> Result<Self> {
        let cfg = candle_nn::Conv2dConfig {
            groups: c,
            padding: ksize / 2,
            ..Default::default()
        };
        let depthwise = candle_nn::conv2d(c, c, ksize, cfg, vb.pp("depthwise"))?;
        let norm = WLayerNorm::new(c)?;
        let channelwise_lin1 = candle_nn::linear(c + c_skip, c * 4, vb.pp("channelwise.0"))?;
        let channelwise_grn = GlobalResponseNorm::new(4 * c, vb.pp("channelwise.2"))?;
        let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?;
        Ok(Self {
            depthwise,
            norm,
            channelwise_lin1,
            channelwise_grn,
            channelwise_lin2,
        })
    }

    pub fn forward(&self, xs: &Tensor, x_skip: Option<&Tensor>) -> Result<Tensor> {
        let x_res = xs;
        let xs = xs.apply(&self.depthwise)?.apply(&self.norm)?;
        let xs = match x_skip {
            None => xs.clone(),
            Some(x_skip) => Tensor::cat(&[&xs, x_skip], 1)?,
        };
        // The channelwise MLP expects channels-last data, hence the permutes around it.
        let xs = xs
            .permute((0, 2, 3, 1))?
            .contiguous()?
            .apply(&self.channelwise_lin1)?
            .gelu()?
            .apply(&self.channelwise_grn)?
            .apply(&self.channelwise_lin2)?
            .permute((0, 3, 1, 2))?;
        xs + x_res
    }
}
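// A minimal shape-check sketch, not part of the original file: with zero-initialized
// weights from `VarBuilder::zeros`, the block should map a `(b, c, h, w)` tensor back to
// the same shape, since the depthwise conv pads by `ksize / 2` and the channelwise MLP is
// applied per pixel. The channel/spatial sizes below are arbitrary.
#[cfg(test)]
mod res_block_stage_b_tests {
    use super::*;
    use candle::Device;

    #[test]
    fn res_block_keeps_shape() -> Result<()> {
        let dev = Device::Cpu;
        let vb = VarBuilder::zeros(DType::F32, &dev);
        // c = 8 channels, no skip input, 3x3 depthwise kernel.
        let block = ResBlockStageB::new(8, 0, 3, vb)?;
        let xs = Tensor::zeros((1, 8, 16, 16), DType::F32, &dev)?;
        let ys = block.forward(&xs, None)?;
        assert_eq!(ys.dims4()?, (1, 8, 16, 16));
        Ok(())
    }
}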
// One processing unit inside a down/up block: residual block, timestep conditioning, and
// an optional cross-attention block over the text conditioning.
#[derive(Debug)]
struct SubBlock {
    res_block: ResBlockStageB,
    ts_block: TimestepBlock,
    attn_block: Option<AttnBlock>,
}

#[derive(Debug)]
struct DownBlock {
    layer_norm: Option<WLayerNorm>,
    conv: Option<candle_nn::Conv2d>,
    sub_blocks: Vec<SubBlock>,
}

#[derive(Debug)]
struct UpBlock {
    sub_blocks: Vec<SubBlock>,
    layer_norm: Option<WLayerNorm>,
    conv: Option<candle_nn::ConvTranspose2d>,
}

/// UNet-style denoiser: the patched input is embedded, pushed through downsampling and
/// upsampling blocks conditioned on the noise level, CLIP embeddings and EfficientNet
/// features, and finally decoded into a shift/scale pair used to renormalize the input.
#[derive(Debug)]
pub struct WDiffNeXt {
    clip_mapper: candle_nn::Linear,
    effnet_mappers: Vec<Option<candle_nn::Conv2d>>,
    seq_norm: LayerNormNoWeights,
    embedding_conv: candle_nn::Conv2d,
    embedding_ln: WLayerNorm,
    down_blocks: Vec<DownBlock>,
    up_blocks: Vec<UpBlock>,
    clf_ln: WLayerNorm,
    clf_conv: candle_nn::Conv2d,
    c_r: usize,
    patch_size: usize,
}

impl WDiffNeXt {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        c_in: usize,
        c_out: usize,
        c_r: usize,
        c_cond: usize,
        clip_embd: usize,
        patch_size: usize,
        use_flash_attn: bool,
        vb: VarBuilder,
    ) -> Result<Self> {
        // Per-level channel widths, sub-block counts, attention heads, and whether the
        // EfficientNet conditioning is injected at that level.
        const C_HIDDEN: [usize; 4] = [320, 640, 1280, 1280];
        const BLOCKS: [usize; 4] = [4, 4, 14, 4];
        const NHEAD: [usize; 4] = [1, 10, 20, 20];
        const INJECT_EFFNET: [bool; 4] = [false, true, true, true];
        const EFFNET_EMBD: usize = 16;

        let clip_mapper = candle_nn::linear(clip_embd, c_cond, vb.pp("clip_mapper"))?;
        // One optional 1x1 conv per down level and per up level, projecting the 16-channel
        // EfficientNet features to `c_cond`.
        let mut effnet_mappers = Vec::with_capacity(2 * INJECT_EFFNET.len());
        let vb_e = vb.pp("effnet_mappers");
        for (i, &inject) in INJECT_EFFNET.iter().enumerate() {
            let c = if inject {
                Some(candle_nn::conv2d(
                    EFFNET_EMBD,
                    c_cond,
                    1,
                    Default::default(),
                    vb_e.pp(i),
                )?)
            } else {
                None
            };
            effnet_mappers.push(c)
        }
        for (i, &inject) in INJECT_EFFNET.iter().rev().enumerate() {
            let c = if inject {
                Some(candle_nn::conv2d(
                    EFFNET_EMBD,
                    c_cond,
                    1,
                    Default::default(),
                    vb_e.pp(i + INJECT_EFFNET.len()),
                )?)
            } else {
                None
            };
            effnet_mappers.push(c)
        }
        let seq_norm = LayerNormNoWeights::new(c_cond)?;
        let embedding_ln = WLayerNorm::new(C_HIDDEN[0])?;
        let embedding_conv = candle_nn::conv2d(
            c_in * patch_size * patch_size,
            C_HIDDEN[0],
            1,
            Default::default(),
            vb.pp("embedding.1"),
        )?;

        // Down path: every level after the first starts with a layer norm + stride-2 conv.
        let mut down_blocks = Vec::with_capacity(C_HIDDEN.len());
        for (i, &c_hidden) in C_HIDDEN.iter().enumerate() {
            let vb = vb.pp("down_blocks").pp(i);
            let (layer_norm, conv, start_layer_i) = if i > 0 {
                let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
                let cfg = candle_nn::Conv2dConfig {
                    stride: 2,
                    ..Default::default()
                };
                let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp("0.1"))?;
                (Some(layer_norm), Some(conv), 1)
            } else {
                (None, None, 0)
            };
            let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
            let mut layer_i = start_layer_i;
            for _j in 0..BLOCKS[i] {
                let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 };
                let res_block = ResBlockStageB::new(c_hidden, c_skip, 3, vb.pp(layer_i))?;
                layer_i += 1;
                let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
                layer_i += 1;
                let attn_block = if i == 0 {
                    None
                } else {
                    let attn_block = AttnBlock::new(
                        c_hidden,
                        c_cond,
                        NHEAD[i],
                        true,
                        use_flash_attn,
                        vb.pp(layer_i),
                    )?;
                    layer_i += 1;
                    Some(attn_block)
                };
                let sub_block = SubBlock {
                    res_block,
                    ts_block,
                    attn_block,
                };
                sub_blocks.push(sub_block)
            }
            let down_block = DownBlock {
                layer_norm,
                conv,
                sub_blocks,
            };
            down_blocks.push(down_block)
        }

        // Up path: mirrors the down path; every level except the deepest also receives the
        // skip connection from the matching down level in its first sub-block.
        let mut up_blocks = Vec::with_capacity(C_HIDDEN.len());
        for (i, &c_hidden) in C_HIDDEN.iter().enumerate().rev() {
            let vb = vb.pp("up_blocks").pp(C_HIDDEN.len() - 1 - i);
            let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
            let mut layer_i = 0;
            for j in 0..BLOCKS[i] {
                let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 };
                let c_skip_res = if i < BLOCKS.len() - 1 && j == 0 {
                    c_hidden + c_skip
                } else {
                    c_skip
                };
                let res_block = ResBlockStageB::new(c_hidden, c_skip_res, 3, vb.pp(layer_i))?;
                layer_i += 1;
                let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
                layer_i += 1;
                let attn_block = if i == 0 {
                    None
                } else {
                    let attn_block = AttnBlock::new(
                        c_hidden,
                        c_cond,
                        NHEAD[i],
                        true,
                        use_flash_attn,
                        vb.pp(layer_i),
                    )?;
                    layer_i += 1;
                    Some(attn_block)
                };
                let sub_block = SubBlock {
                    res_block,
                    ts_block,
                    attn_block,
                };
                sub_blocks.push(sub_block)
            }
            let (layer_norm, conv) = if i > 0 {
                let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
                let cfg = candle_nn::ConvTranspose2dConfig {
                    stride: 2,
                    ..Default::default()
                };
                let conv = candle_nn::conv_transpose2d(
                    c_hidden,
                    C_HIDDEN[i - 1],
                    2,
                    cfg,
                    vb.pp(layer_i).pp(1),
                )?;
                (Some(layer_norm), Some(conv))
            } else {
                (None, None)
            };
            let up_block = UpBlock {
                layer_norm,
                conv,
                sub_blocks,
            };
            up_blocks.push(up_block)
        }

        let clf_ln = WLayerNorm::new(C_HIDDEN[0])?;
        let clf_conv = candle_nn::conv2d(
            C_HIDDEN[0],
            2 * c_out * patch_size * patch_size,
            1,
            Default::default(),
            vb.pp("clf.1"),
        )?;
        Ok(Self {
            clip_mapper,
            effnet_mappers,
            seq_norm,
            embedding_conv,
            embedding_ln,
            down_blocks,
            up_blocks,
            clf_ln,
            clf_conv,
            c_r,
            patch_size,
        })
    }
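    // Sinusoidal embedding of the scalar noise level `r`, in the style of transformer
    // position embeddings: `r` is scaled by 10_000, multiplied by `c_r / 2` geometrically
    // spaced frequencies, and the sin/cos halves are concatenated (with one zero pad when
    // `c_r` is odd) into a `(batch, c_r)` tensor cast back to the input dtype.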
    fn gen_r_embedding(&self, r: &Tensor) -> Result<Tensor> {
        const MAX_POSITIONS: usize = 10000;
        let r = (r * MAX_POSITIONS as f64)?;
        let half_dim = self.c_r / 2;
        let emb = (MAX_POSITIONS as f64).ln() / (half_dim - 1) as f64;
        let emb = (Tensor::arange(0u32, half_dim as u32, r.device())?.to_dtype(DType::F32)? * -emb)?
            .exp()?;
        let emb = r.unsqueeze(1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
        let emb = Tensor::cat(&[emb.sin()?, emb.cos()?], 1)?;
        let emb = if self.c_r % 2 == 1 {
            emb.pad_with_zeros(D::Minus1, 0, 1)?
        } else {
            emb
        };
        emb.to_dtype(r.dtype())
    }

    fn gen_c_embeddings(&self, clip: &Tensor) -> Result<Tensor> {
        clip.apply(&self.clip_mapper)?.apply(&self.seq_norm)
    }

    pub fn forward(
        &self,
        xs: &Tensor,
        r: &Tensor,
        effnet: &Tensor,
        clip: Option<&Tensor>,
    ) -> Result<Tensor> {
        const EPS: f64 = 1e-3;
        let r_embed = self.gen_r_embedding(r)?;
        let clip = match clip {
            None => None,
            Some(clip) => Some(self.gen_c_embeddings(clip)?),
        };
        let x_in = xs;

        // Patch the input via pixel_unshuffle and embed it to C_HIDDEN[0] channels.
        let mut xs = xs
            .apply(&|xs: &_| candle_nn::ops::pixel_unshuffle(xs, self.patch_size))?
            .apply(&self.embedding_conv)?
            .apply(&self.embedding_ln)?;

        // Down path: run each level and remember its output for the skip connections.
        let mut level_outputs = Vec::new();
        for (i, down_block) in self.down_blocks.iter().enumerate() {
            if let Some(ln) = &down_block.layer_norm {
                xs = xs.apply(ln)?
            }
            if let Some(conv) = &down_block.conv {
                xs = xs.apply(conv)?
            }
            let skip = match &self.effnet_mappers[i] {
                None => None,
                Some(m) => {
                    let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?;
                    Some(m.forward(&effnet)?)
                }
            };
            for block in down_block.sub_blocks.iter() {
                xs = block.res_block.forward(&xs, skip.as_ref())?;
                xs = block.ts_block.forward(&xs, &r_embed)?;
                if let Some(attn_block) = &block.attn_block {
                    xs = attn_block.forward(&xs, clip.as_ref().unwrap())?;
                }
            }
            level_outputs.push(xs.clone())
        }
        level_outputs.reverse();
        let mut xs = level_outputs[0].clone();

        // Up path: fuse the stored level outputs and the EfficientNet conditioning back in.
        for (i, up_block) in self.up_blocks.iter().enumerate() {
            let effnet_c = match &self.effnet_mappers[self.down_blocks.len() + i] {
                None => None,
                Some(m) => {
                    let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?;
                    Some(m.forward(&effnet)?)
                }
            };
            for (j, block) in up_block.sub_blocks.iter().enumerate() {
                let skip = if j == 0 && i > 0 {
                    Some(&level_outputs[i])
                } else {
                    None
                };
                let skip = match (skip, effnet_c.as_ref()) {
                    (Some(skip), Some(effnet_c)) => Some(Tensor::cat(&[skip, effnet_c], 1)?),
                    (None, Some(skip)) | (Some(skip), None) => Some(skip.clone()),
                    (None, None) => None,
                };
                xs = block.res_block.forward(&xs, skip.as_ref())?;
                xs = block.ts_block.forward(&xs, &r_embed)?;
                if let Some(attn_block) = &block.attn_block {
                    xs = attn_block.forward(&xs, clip.as_ref().unwrap())?;
                }
            }
            if let Some(ln) = &up_block.layer_norm {
                xs = xs.apply(ln)?
            }
            if let Some(conv) = &up_block.conv {
                xs = xs.apply(conv)?
            }
        }

        // Classifier head: un-patch and split into a shift `a` and a scale `b`, then
        // renormalize the original input as `(x - a) / b`.
        let ab = xs
            .apply(&self.clf_ln)?
            .apply(&self.clf_conv)?
            .apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, self.patch_size))?
            .chunk(2, 1)?;
        let b = ((candle_nn::ops::sigmoid(&ab[1])? * (1. - EPS * 2.))? + EPS)?;
        (x_in - &ab[0])? / b
    }
}
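// Illustrative usage sketch, not part of the original file: how the denoiser might be
// instantiated and run for a single prediction step. The hyper-parameters below are
// assumptions for demonstration only; real checkpoints fix their own values, and `vb`
// must point at matching weights (or `VarBuilder::zeros` for a shape-only dry run).
#[allow(dead_code)]
fn example_step(
    vb: VarBuilder,
    latents: &Tensor, // (b, c_in, h, w); h and w divisible by patch_size * 8 so skips line up
    r: &Tensor,       // (b,) noise level
    effnet: &Tensor,  // (b, 16, h', w') EfficientNet features, resized internally per level
    clip: &Tensor,    // (b, seq_len, clip_embd) text conditioning
) -> Result<Tensor> {
    let model = WDiffNeXt::new(
        4,     // c_in: latent channels (assumed)
        4,     // c_out (assumed)
        64,    // c_r: width of the sinusoidal noise-level embedding (assumed)
        1024,  // c_cond: conditioning width for attention and EfficientNet mappers (assumed)
        1024,  // clip_embd: width of the incoming CLIP embeddings (assumed)
        2,     // patch_size for pixel_unshuffle / pixel_shuffle (assumed)
        false, // use_flash_attn
        vb,
    )?;
    model.forward(latents, r, effnet, Some(clip))
}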