path: root/candle-transformers/src/models/wuerstchen/diffnext.rs
author    Evgeny Igumnov <igumnovnsk@gmail.com>    2023-09-22 11:01:23 +0600
committer GitHub <noreply@github.com>    2023-09-22 11:01:23 +0600
commit    4ac6039a42b8125f7888709fb718bfd41a73f2ac (patch)
tree      f4bc165de51f258a9bf58cac4150c99e512fba01 /candle-transformers/src/models/wuerstchen/diffnext.rs
parent    52a60ca3ad3f7e7b6da8e915a5a052d5bef10999 (diff)
parent    a96878f2357fbcebf9db8747dcbb55bc8200d8ab (diff)
Merge branch 'main' into book-trainin-simplified
Diffstat (limited to 'candle-transformers/src/models/wuerstchen/diffnext.rs')
-rw-r--r--  candle-transformers/src/models/wuerstchen/diffnext.rs  396
1 file changed, 396 insertions, 0 deletions
diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs
new file mode 100644
index 00000000..64a48c8a
--- /dev/null
+++ b/candle-transformers/src/models/wuerstchen/diffnext.rs
@@ -0,0 +1,396 @@
+use super::common::{AttnBlock, GlobalResponseNorm, LayerNormNoWeights, TimestepBlock, WLayerNorm};
+use candle::{DType, Module, Result, Tensor, D};
+use candle_nn::VarBuilder;
+
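+// Residual block for stage B of Würstchen: a depthwise convolution followed by a channelwise
+// MLP (linear -> GELU -> global response norm -> linear) applied with channels last.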
+#[derive(Debug)]
+pub struct ResBlockStageB {
+    depthwise: candle_nn::Conv2d,
+    norm: WLayerNorm,
+    channelwise_lin1: candle_nn::Linear,
+    channelwise_grn: GlobalResponseNorm,
+    channelwise_lin2: candle_nn::Linear,
+}
+
+impl ResBlockStageB {
+    pub fn new(c: usize, c_skip: usize, ksize: usize, vb: VarBuilder) -> Result<Self> {
+        let cfg = candle_nn::Conv2dConfig {
+            groups: c,
+            padding: ksize / 2,
+            ..Default::default()
+        };
+        let depthwise = candle_nn::conv2d(c, c, ksize, cfg, vb.pp("depthwise"))?;
+        let norm = WLayerNorm::new(c)?;
+        let channelwise_lin1 = candle_nn::linear(c + c_skip, c * 4, vb.pp("channelwise.0"))?;
+        let channelwise_grn = GlobalResponseNorm::new(4 * c, vb.pp("channelwise.2"))?;
+        let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?;
+        Ok(Self {
+            depthwise,
+            norm,
+            channelwise_lin1,
+            channelwise_grn,
+            channelwise_lin2,
+        })
+    }
+
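+    // The optional `x_skip` is concatenated along the channel dimension; channels are then moved
+    // to the last axis so the channelwise linear layers act per spatial position.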
+    pub fn forward(&self, xs: &Tensor, x_skip: Option<&Tensor>) -> Result<Tensor> {
+        let x_res = xs;
+        let xs = xs.apply(&self.depthwise)?.apply(&self.norm)?;
+        let xs = match x_skip {
+            None => xs.clone(),
+            Some(x_skip) => Tensor::cat(&[&xs, x_skip], 1)?,
+        };
+        let xs = xs
+            .permute((0, 2, 3, 1))?
+            .contiguous()?
+            .apply(&self.channelwise_lin1)?
+            .gelu()?
+            .apply(&self.channelwise_grn)?
+            .apply(&self.channelwise_lin2)?
+            .permute((0, 3, 1, 2))?;
+        xs + x_res
+    }
+}
+
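+// One sub-block of a stage: a residual block, timestep conditioning, and (for all stages except
+// the first, highest-resolution one) cross-attention over the conditioning sequence.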
+#[derive(Debug)]
+struct SubBlock {
+    res_block: ResBlockStageB,
+    ts_block: TimestepBlock,
+    attn_block: Option<AttnBlock>,
+}
+
+#[derive(Debug)]
+struct DownBlock {
+    layer_norm: Option<WLayerNorm>,
+    conv: Option<candle_nn::Conv2d>,
+    sub_blocks: Vec<SubBlock>,
+}
+
+#[derive(Debug)]
+struct UpBlock {
+    sub_blocks: Vec<SubBlock>,
+    layer_norm: Option<WLayerNorm>,
+    conv: Option<candle_nn::ConvTranspose2d>,
+}
+
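+// The DiffNeXt model: a U-Net style network over patch-unshuffled inputs, conditioned on a noise
+// level `r`, optional CLIP embeddings, and an `effnet` feature map injected at selected stages.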
+#[derive(Debug)]
+pub struct WDiffNeXt {
+    clip_mapper: candle_nn::Linear,
+    effnet_mappers: Vec<Option<candle_nn::Conv2d>>,
+    seq_norm: LayerNormNoWeights,
+    embedding_conv: candle_nn::Conv2d,
+    embedding_ln: WLayerNorm,
+    down_blocks: Vec<DownBlock>,
+    up_blocks: Vec<UpBlock>,
+    clf_ln: WLayerNorm,
+    clf_conv: candle_nn::Conv2d,
+    c_r: usize,
+    patch_size: usize,
+}
+
+impl WDiffNeXt {
+    #[allow(clippy::too_many_arguments)]
+    pub fn new(
+        c_in: usize,
+        c_out: usize,
+        c_r: usize,
+        c_cond: usize,
+        clip_embd: usize,
+        patch_size: usize,
+        use_flash_attn: bool,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        const C_HIDDEN: [usize; 4] = [320, 640, 1280, 1280];
+        const BLOCKS: [usize; 4] = [4, 4, 14, 4];
+        const NHEAD: [usize; 4] = [1, 10, 20, 20];
+        const INJECT_EFFNET: [bool; 4] = [false, true, true, true];
+        const EFFNET_EMBD: usize = 16;
+
+        let clip_mapper = candle_nn::linear(clip_embd, c_cond, vb.pp("clip_mapper"))?;
+        let mut effnet_mappers = Vec::with_capacity(2 * INJECT_EFFNET.len());
+        let vb_e = vb.pp("effnet_mappers");
+        for (i, &inject) in INJECT_EFFNET.iter().enumerate() {
+            let c = if inject {
+                Some(candle_nn::conv2d(
+                    EFFNET_EMBD,
+                    c_cond,
+                    1,
+                    Default::default(),
+                    vb_e.pp(i),
+                )?)
+            } else {
+                None
+            };
+            effnet_mappers.push(c)
+        }
+        for (i, &inject) in INJECT_EFFNET.iter().rev().enumerate() {
+            let c = if inject {
+                Some(candle_nn::conv2d(
+                    EFFNET_EMBD,
+                    c_cond,
+                    1,
+                    Default::default(),
+                    vb_e.pp(i + INJECT_EFFNET.len()),
+                )?)
+            } else {
+                None
+            };
+            effnet_mappers.push(c)
+        }
+        let seq_norm = LayerNormNoWeights::new(c_cond)?;
+        let embedding_ln = WLayerNorm::new(C_HIDDEN[0])?;
+        let embedding_conv = candle_nn::conv2d(
+            c_in * patch_size * patch_size,
+            C_HIDDEN[0],
+            1,
+            Default::default(),
+            vb.pp("embedding.1"),
+        )?;
+
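+        // Downsampling stages: every stage after the first starts with a LayerNorm and a
+        // stride-2 convolution that halves the spatial resolution.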
+        let mut down_blocks = Vec::with_capacity(C_HIDDEN.len());
+        for (i, &c_hidden) in C_HIDDEN.iter().enumerate() {
+            let vb = vb.pp("down_blocks").pp(i);
+            let (layer_norm, conv, start_layer_i) = if i > 0 {
+                let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
+                let cfg = candle_nn::Conv2dConfig {
+                    stride: 2,
+                    ..Default::default()
+                };
+                let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp("0.1"))?;
+                (Some(layer_norm), Some(conv), 1)
+            } else {
+                (None, None, 0)
+            };
+            let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
+            let mut layer_i = start_layer_i;
+            for _j in 0..BLOCKS[i] {
+                let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 };
+                let res_block = ResBlockStageB::new(c_hidden, c_skip, 3, vb.pp(layer_i))?;
+                layer_i += 1;
+                let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
+                layer_i += 1;
+                let attn_block = if i == 0 {
+                    None
+                } else {
+                    let attn_block = AttnBlock::new(
+                        c_hidden,
+                        c_cond,
+                        NHEAD[i],
+                        true,
+                        use_flash_attn,
+                        vb.pp(layer_i),
+                    )?;
+                    layer_i += 1;
+                    Some(attn_block)
+                };
+                let sub_block = SubBlock {
+                    res_block,
+                    ts_block,
+                    attn_block,
+                };
+                sub_blocks.push(sub_block)
+            }
+            let down_block = DownBlock {
+                layer_norm,
+                conv,
+                sub_blocks,
+            };
+            down_blocks.push(down_block)
+        }
+
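+        // Upsampling stages, built in reverse stage order; all but the final (highest-resolution)
+        // block end with a LayerNorm and a stride-2 transposed convolution that doubles the
+        // spatial resolution.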
+        let mut up_blocks = Vec::with_capacity(C_HIDDEN.len());
+        for (i, &c_hidden) in C_HIDDEN.iter().enumerate().rev() {
+            let vb = vb.pp("up_blocks").pp(C_HIDDEN.len() - 1 - i);
+            let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
+            let mut layer_i = 0;
+            for j in 0..BLOCKS[i] {
+                let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 };
+                let c_skip_res = if i < BLOCKS.len() - 1 && j == 0 {
+                    c_hidden + c_skip
+                } else {
+                    c_skip
+                };
+                let res_block = ResBlockStageB::new(c_hidden, c_skip_res, 3, vb.pp(layer_i))?;
+                layer_i += 1;
+                let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
+                layer_i += 1;
+                let attn_block = if i == 0 {
+                    None
+                } else {
+                    let attn_block = AttnBlock::new(
+                        c_hidden,
+                        c_cond,
+                        NHEAD[i],
+                        true,
+                        use_flash_attn,
+                        vb.pp(layer_i),
+                    )?;
+                    layer_i += 1;
+                    Some(attn_block)
+                };
+                let sub_block = SubBlock {
+                    res_block,
+                    ts_block,
+                    attn_block,
+                };
+                sub_blocks.push(sub_block)
+            }
+            let (layer_norm, conv) = if i > 0 {
+                let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
+                let cfg = candle_nn::ConvTranspose2dConfig {
+                    stride: 2,
+                    ..Default::default()
+                };
+                let conv = candle_nn::conv_transpose2d(
+                    c_hidden,
+                    C_HIDDEN[i - 1],
+                    2,
+                    cfg,
+                    vb.pp(layer_i).pp(1),
+                )?;
+                (Some(layer_norm), Some(conv))
+            } else {
+                (None, None)
+            };
+            let up_block = UpBlock {
+                layer_norm,
+                conv,
+                sub_blocks,
+            };
+            up_blocks.push(up_block)
+        }
+
+        let clf_ln = WLayerNorm::new(C_HIDDEN[0])?;
+        let clf_conv = candle_nn::conv2d(
+            C_HIDDEN[0],
+            2 * c_out * patch_size * patch_size,
+            1,
+            Default::default(),
+            vb.pp("clf.1"),
+        )?;
+        Ok(Self {
+            clip_mapper,
+            effnet_mappers,
+            seq_norm,
+            embedding_conv,
+            embedding_ln,
+            down_blocks,
+            up_blocks,
+            clf_ln,
+            clf_conv,
+            c_r,
+            patch_size,
+        })
+    }
+
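+    // Sinusoidal embedding of the noise level `r`, producing `c_r` channels (zero-padded when
+    // `c_r` is odd).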
+    fn gen_r_embedding(&self, r: &Tensor) -> Result<Tensor> {
+        const MAX_POSITIONS: usize = 10000;
+        let r = (r * MAX_POSITIONS as f64)?;
+        let half_dim = self.c_r / 2;
+        let emb = (MAX_POSITIONS as f64).ln() / (half_dim - 1) as f64;
+        let emb = (Tensor::arange(0u32, half_dim as u32, r.device())?.to_dtype(DType::F32)?
+            * -emb)?
+            .exp()?;
+        let emb = r.unsqueeze(1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
+        let emb = Tensor::cat(&[emb.sin()?, emb.cos()?], 1)?;
+        let emb = if self.c_r % 2 == 1 {
+            emb.pad_with_zeros(D::Minus1, 0, 1)?
+        } else {
+            emb
+        };
+        emb.to_dtype(r.dtype())
+    }
+
+    fn gen_c_embeddings(&self, clip: &Tensor) -> Result<Tensor> {
+        clip.apply(&self.clip_mapper)?.apply(&self.seq_norm)
+    }
+
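+    // Forward pass: embed the patch-unshuffled input, run the down- and up-sampling paths with
+    // effnet / timestep / CLIP conditioning, then compute the output from the classifier head.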
+    pub fn forward(
+        &self,
+        xs: &Tensor,
+        r: &Tensor,
+        effnet: &Tensor,
+        clip: Option<&Tensor>,
+    ) -> Result<Tensor> {
+        const EPS: f64 = 1e-3;
+
+        let r_embed = self.gen_r_embedding(r)?;
+        let clip = match clip {
+            None => None,
+            Some(clip) => Some(self.gen_c_embeddings(clip)?),
+        };
+        let x_in = xs;
+
+        let mut xs = xs
+            .apply(&|xs: &_| candle_nn::ops::pixel_unshuffle(xs, self.patch_size))?
+            .apply(&self.embedding_conv)?
+            .apply(&self.embedding_ln)?;
+
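+        // Downsampling path; each stage's output is kept as a skip connection for the way up.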
+        let mut level_outputs = Vec::new();
+        for (i, down_block) in self.down_blocks.iter().enumerate() {
+            if let Some(ln) = &down_block.layer_norm {
+                xs = xs.apply(ln)?
+            }
+            if let Some(conv) = &down_block.conv {
+                xs = xs.apply(conv)?
+            }
+            let skip = match &self.effnet_mappers[i] {
+                None => None,
+                Some(m) => {
+                    let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?;
+                    Some(m.forward(&effnet)?)
+                }
+            };
+            for block in down_block.sub_blocks.iter() {
+                xs = block.res_block.forward(&xs, skip.as_ref())?;
+                xs = block.ts_block.forward(&xs, &r_embed)?;
+                if let Some(attn_block) = &block.attn_block {
+                    xs = attn_block.forward(&xs, clip.as_ref().unwrap())?;
+                }
+            }
+            level_outputs.push(xs.clone())
+        }
+        level_outputs.reverse();
+        let mut xs = level_outputs[0].clone();
+
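+        // Upsampling path: the first sub-block of each stage below the deepest receives the
+        // matching down-path output, concatenated with the mapped effnet features when present.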
+        for (i, up_block) in self.up_blocks.iter().enumerate() {
+            let effnet_c = match &self.effnet_mappers[self.down_blocks.len() + i] {
+                None => None,
+                Some(m) => {
+                    let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?;
+                    Some(m.forward(&effnet)?)
+                }
+            };
+            for (j, block) in up_block.sub_blocks.iter().enumerate() {
+                let skip = if j == 0 && i > 0 {
+                    Some(&level_outputs[i])
+                } else {
+                    None
+                };
+                let skip = match (skip, effnet_c.as_ref()) {
+                    (Some(skip), Some(effnet_c)) => Some(Tensor::cat(&[skip, effnet_c], 1)?),
+                    (None, Some(skip)) | (Some(skip), None) => Some(skip.clone()),
+                    (None, None) => None,
+                };
+                xs = block.res_block.forward(&xs, skip.as_ref())?;
+                xs = block.ts_block.forward(&xs, &r_embed)?;
+                if let Some(attn_block) = &block.attn_block {
+                    xs = attn_block.forward(&xs, clip.as_ref().unwrap())?;
+                }
+            }
+            if let Some(ln) = &up_block.layer_norm {
+                xs = xs.apply(ln)?
+            }
+            if let Some(conv) = &up_block.conv {
+                xs = xs.apply(conv)?
+            }
+        }
+
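+        // Classifier head: split the output into `a` and `b`, squash `b` into (EPS, 1 - EPS) with
+        // a sigmoid, and return `(x_in - a) / b`.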
+        let ab = xs
+            .apply(&self.clf_ln)?
+            .apply(&self.clf_conv)?
+            .apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, self.patch_size))?
+            .chunk(2, 1)?;
+        let b = ((candle_nn::ops::sigmoid(&ab[1])? * (1. - EPS * 2.))? + EPS)?;
+        (x_in - &ab[0])? / b
+}