diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-15 11:14:02 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-15 10:14:02 +0100 |
commit | 2746f2c4bec61d8f2c7947a297c46a74a3f799db (patch) | |
tree | 60a9d10a7504063a3a84f7a9ff8f10b3e0258a30 /candle-transformers/src/models/wuerstchen | |
parent | 81a36b8713721e2cfd098fad04972bf823ef0d5d (diff) | |
download | candle-2746f2c4bec61d8f2c7947a297c46a74a3f799db.tar.gz candle-2746f2c4bec61d8f2c7947a297c46a74a3f799db.tar.bz2 candle-2746f2c4bec61d8f2c7947a297c46a74a3f799db.zip |
DiffNeXt/unet (#859)
* DiffNeXt/unet
* Start adding the vae.
* VAE residual block.
* VAE forward pass.
* Add pixel shuffling.
* Actually use pixel shuffling.
Diffstat (limited to 'candle-transformers/src/models/wuerstchen')
-rw-r--r-- | candle-transformers/src/models/wuerstchen/diffnext.rs | 77 | ||||
-rw-r--r-- | candle-transformers/src/models/wuerstchen/mod.rs | 1 | ||||
-rw-r--r-- | candle-transformers/src/models/wuerstchen/paella_vq.rs | 111 |
3 files changed, 180 insertions, 9 deletions
diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs index 7289a54d..33ca192e 100644 --- a/candle-transformers/src/models/wuerstchen/diffnext.rs +++ b/candle-transformers/src/models/wuerstchen/diffnext.rs @@ -75,7 +75,7 @@ struct UpBlock { #[derive(Debug)] pub struct WDiffNeXt { clip_mapper: candle_nn::Linear, - effnet_mappers: Vec<candle_nn::Conv2d>, + effnet_mappers: Vec<Option<candle_nn::Conv2d>>, seq_norm: candle_nn::LayerNorm, embedding_conv: candle_nn::Conv2d, embedding_ln: WLayerNorm, @@ -84,6 +84,7 @@ pub struct WDiffNeXt { clf_ln: WLayerNorm, clf_conv: candle_nn::Conv2d, c_r: usize, + patch_size: usize, } impl WDiffNeXt { @@ -102,6 +103,7 @@ impl WDiffNeXt { const INJECT_EFFNET: [bool; 4] = [false, true, true, true]; let clip_mapper = candle_nn::linear(clip_embd, c_cond, vb.pp("clip_mapper"))?; + // TODO: populate effnet_mappers let effnet_mappers = vec![]; let cfg = candle_nn::layer_norm::LayerNormConfig { ..Default::default() @@ -232,6 +234,7 @@ impl WDiffNeXt { clf_ln, clf_conv, c_r, + patch_size, }) } @@ -273,14 +276,70 @@ impl WDiffNeXt { }; let x_in = xs; - // TODO: pixel unshuffle. - let xs = xs.apply(&self.embedding_conv)?.apply(&self.embedding_ln)?; - // TODO: down blocks - let level_outputs = xs.clone(); - // TODO: up blocks - let xs = level_outputs; - // TODO: pxel shuffle - let ab = xs.apply(&self.clf_ln)?.apply(&self.clf_conv)?.chunk(1, 2)?; + let mut xs = xs + .apply(&|xs: &_| candle_nn::ops::pixel_unshuffle(xs, self.patch_size))? + .apply(&self.embedding_conv)? + .apply(&self.embedding_ln)?; + + let mut level_outputs = Vec::new(); + for (i, down_block) in self.down_blocks.iter().enumerate() { + if let Some(ln) = &down_block.layer_norm { + xs = xs.apply(ln)? + } + if let Some(conv) = &down_block.conv { + xs = xs.apply(conv)? + } + let skip = match &self.effnet_mappers[i] { + None => None, + Some(m) => { + let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?; + Some(m.forward(&effnet)?) + } + }; + for block in down_block.sub_blocks.iter() { + xs = block.res_block.forward(&xs, skip.as_ref())?; + xs = block.ts_block.forward(&xs, &r_embed)?; + if let Some(attn_block) = &block.attn_block { + xs = attn_block.forward(&xs, clip.as_ref().unwrap())?; + } + } + level_outputs.push(xs.clone()) + } + level_outputs.reverse(); + + for (i, up_block) in self.up_blocks.iter().enumerate() { + let skip = match &self.effnet_mappers[self.down_blocks.len() + i] { + None => None, + Some(m) => { + let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?; + Some(m.forward(&effnet)?) + } + }; + for (j, block) in up_block.sub_blocks.iter().enumerate() { + let skip = if j == 0 && i > 0 { + Some(&level_outputs[i]) + } else { + None + }; + xs = block.res_block.forward(&xs, skip)?; + xs = block.ts_block.forward(&xs, &r_embed)?; + if let Some(attn_block) = &block.attn_block { + xs = attn_block.forward(&xs, clip.as_ref().unwrap())?; + } + } + if let Some(ln) = &up_block.layer_norm { + xs = xs.apply(ln)? + } + if let Some(conv) = &up_block.conv { + xs = xs.apply(conv)? + } + } + + let ab = xs + .apply(&self.clf_ln)? + .apply(&self.clf_conv)? + .apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, self.patch_size))? + .chunk(1, 2)?; let b = ((candle_nn::ops::sigmoid(&ab[1])? * (1. - EPS * 2.))? + EPS)?; (x_in - &ab[0])? / b } diff --git a/candle-transformers/src/models/wuerstchen/mod.rs b/candle-transformers/src/models/wuerstchen/mod.rs index 81755dd1..435bdac2 100644 --- a/candle-transformers/src/models/wuerstchen/mod.rs +++ b/candle-transformers/src/models/wuerstchen/mod.rs @@ -1,3 +1,4 @@ pub mod common; pub mod diffnext; +pub mod paella_vq; pub mod prior; diff --git a/candle-transformers/src/models/wuerstchen/paella_vq.rs b/candle-transformers/src/models/wuerstchen/paella_vq.rs new file mode 100644 index 00000000..6301b7a1 --- /dev/null +++ b/candle-transformers/src/models/wuerstchen/paella_vq.rs @@ -0,0 +1,111 @@ +#![allow(unused)] +use super::common::{AttnBlock, ResBlock, TimestepBlock}; +use candle::{DType, Module, Result, Tensor, D}; +use candle_nn::VarBuilder; + +#[derive(Debug)] +struct MixingResidualBlock { + norm1: candle_nn::LayerNorm, + depthwise_conv: candle_nn::Conv2d, + norm2: candle_nn::LayerNorm, + channelwise_lin1: candle_nn::Linear, + channelwise_lin2: candle_nn::Linear, + gammas: Vec<f32>, +} + +impl MixingResidualBlock { + pub fn new(inp: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> { + let cfg = candle_nn::LayerNormConfig { + affine: false, + eps: 1e-6, + remove_mean: true, + }; + let norm1 = candle_nn::layer_norm(inp, cfg, vb.pp("norm1"))?; + let norm2 = candle_nn::layer_norm(inp, cfg, vb.pp("norm1"))?; + let cfg = candle_nn::Conv2dConfig { + groups: inp, + ..Default::default() + }; + let depthwise_conv = candle_nn::conv2d(inp, inp, 3, cfg, vb.pp("depthwise.1"))?; + let channelwise_lin1 = candle_nn::linear(inp, embed_dim, vb.pp("channelwise.0"))?; + let channelwise_lin2 = candle_nn::linear(embed_dim, inp, vb.pp("channelwise.2"))?; + let gammas = vb.get(6, "gammas")?.to_vec1::<f32>()?; + Ok(Self { + norm1, + depthwise_conv, + norm2, + channelwise_lin1, + channelwise_lin2, + gammas, + }) + } +} + +impl Module for MixingResidualBlock { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let mods = &self.gammas; + let x_temp = xs + .permute((0, 2, 3, 1))? + .apply(&self.norm1)? + .permute((0, 3, 1, 2))? + .affine(1. + mods[0] as f64, mods[1] as f64)?; + // TODO: Add the ReplicationPad2d + let xs = (xs + x_temp.apply(&self.depthwise_conv)? * mods[2] as f64)?; + let x_temp = xs + .permute((0, 2, 3, 1))? + .apply(&self.norm2)? + .permute((0, 3, 1, 2))? + .affine(1. + mods[3] as f64, mods[4] as f64)?; + let x_temp = x_temp + .permute((0, 2, 3, 1))? + .apply(&self.channelwise_lin1)? + .gelu()? + .apply(&self.channelwise_lin2)? + .permute((0, 3, 1, 2))?; + xs + x_temp * mods[5] as f64 + } +} + +#[derive(Debug)] +struct PaellaVQ { + in_block_conv: candle_nn::Conv2d, + out_block_conv: candle_nn::Conv2d, + down_blocks: Vec<(Option<candle_nn::Conv2d>, MixingResidualBlock)>, + down_blocks_conv: candle_nn::Conv2d, + down_blocks_bn: candle_nn::BatchNorm, + up_blocks_conv: candle_nn::Conv2d, + up_blocks: Vec<(MixingResidualBlock, Option<candle_nn::ConvTranspose2d>)>, +} + +impl PaellaVQ { + pub fn encode(&self, xs: &Tensor) -> Result<Tensor> { + let mut xs = candle_nn::ops::pixel_unshuffle(xs, 2)?.apply(&self.in_block_conv)?; + for down_block in self.down_blocks.iter() { + if let Some(conv) = &down_block.0 { + xs = xs.apply(conv)? + } + xs = xs.apply(&down_block.1)? + } + xs.apply(&self.down_blocks_conv)? + .apply(&self.down_blocks_bn) + // TODO: quantizer + } + + pub fn decode(&self, xs: &Tensor) -> Result<Tensor> { + let mut xs = xs.apply(&self.up_blocks_conv)?; + for up_block in self.up_blocks.iter() { + xs = xs.apply(&up_block.0)?; + if let Some(conv) = &up_block.1 { + xs = xs.apply(conv)? + } + } + xs.apply(&self.out_block_conv)? + .apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, 2)) + } +} + +impl Module for PaellaVQ { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + self.decode(&self.encode(xs)?) + } +} |