summary | refs | log | tree | commit | diff
path: root/candle-transformers/src/models/wuerstchen
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2023-09-15 11:14:02 +0200
committer: GitHub <noreply@github.com> 2023-09-15 10:14:02 +0100
commit: 2746f2c4bec61d8f2c7947a297c46a74a3f799db (patch)
tree: 60a9d10a7504063a3a84f7a9ff8f10b3e0258a30 /candle-transformers/src/models/wuerstchen
parent: 81a36b8713721e2cfd098fad04972bf823ef0d5d (diff)
downloadcandle-2746f2c4bec61d8f2c7947a297c46a74a3f799db.tar.gz
candle-2746f2c4bec61d8f2c7947a297c46a74a3f799db.tar.bz2
candle-2746f2c4bec61d8f2c7947a297c46a74a3f799db.zip
DiffNeXt/unet (#859)
* DiffNeXt/unet * Start adding the vae. * VAE residual block. * VAE forward pass. * Add pixel shuffling. * Actually use pixel shuffling.
Diffstat (limited to 'candle-transformers/src/models/wuerstchen')
-rw-r--r-- candle-transformers/src/models/wuerstchen/diffnext.rs 77
-rw-r--r-- candle-transformers/src/models/wuerstchen/mod.rs 1
-rw-r--r-- candle-transformers/src/models/wuerstchen/paella_vq.rs 111
3 files changed, 180 insertions, 9 deletions
diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs
index 7289a54d..33ca192e 100644
--- a/candle-transformers/src/models/wuerstchen/diffnext.rs
+++ b/candle-transformers/src/models/wuerstchen/diffnext.rs
@@ -75,7 +75,7 @@ struct UpBlock {
#[derive(Debug)]
pub struct WDiffNeXt {
clip_mapper: candle_nn::Linear,
- effnet_mappers: Vec<candle_nn::Conv2d>,
+ effnet_mappers: Vec<Option<candle_nn::Conv2d>>,
seq_norm: candle_nn::LayerNorm,
embedding_conv: candle_nn::Conv2d,
embedding_ln: WLayerNorm,
@@ -84,6 +84,7 @@ pub struct WDiffNeXt {
clf_ln: WLayerNorm,
clf_conv: candle_nn::Conv2d,
c_r: usize,
+ patch_size: usize,
}
impl WDiffNeXt {
@@ -102,6 +103,7 @@ impl WDiffNeXt {
const INJECT_EFFNET: [bool; 4] = [false, true, true, true];
let clip_mapper = candle_nn::linear(clip_embd, c_cond, vb.pp("clip_mapper"))?;
+ // TODO: populate effnet_mappers
let effnet_mappers = vec![];
let cfg = candle_nn::layer_norm::LayerNormConfig {
..Default::default()
@@ -232,6 +234,7 @@ impl WDiffNeXt {
clf_ln,
clf_conv,
c_r,
+ patch_size,
})
}
@@ -273,14 +276,70 @@ impl WDiffNeXt {
};
let x_in = xs;
- // TODO: pixel unshuffle.
- let xs = xs.apply(&self.embedding_conv)?.apply(&self.embedding_ln)?;
- // TODO: down blocks
- let level_outputs = xs.clone();
- // TODO: up blocks
- let xs = level_outputs;
- // TODO: pxel shuffle
- let ab = xs.apply(&self.clf_ln)?.apply(&self.clf_conv)?.chunk(1, 2)?;
+ let mut xs = xs
+ .apply(&|xs: &_| candle_nn::ops::pixel_unshuffle(xs, self.patch_size))?
+ .apply(&self.embedding_conv)?
+ .apply(&self.embedding_ln)?;
+
+ let mut level_outputs = Vec::new();
+ for (i, down_block) in self.down_blocks.iter().enumerate() {
+ if let Some(ln) = &down_block.layer_norm {
+ xs = xs.apply(ln)?
+ }
+ if let Some(conv) = &down_block.conv {
+ xs = xs.apply(conv)?
+ }
+ let skip = match &self.effnet_mappers[i] {
+ None => None,
+ Some(m) => {
+ let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?;
+ Some(m.forward(&effnet)?)
+ }
+ };
+ for block in down_block.sub_blocks.iter() {
+ xs = block.res_block.forward(&xs, skip.as_ref())?;
+ xs = block.ts_block.forward(&xs, &r_embed)?;
+ if let Some(attn_block) = &block.attn_block {
+ xs = attn_block.forward(&xs, clip.as_ref().unwrap())?;
+ }
+ }
+ level_outputs.push(xs.clone())
+ }
+ level_outputs.reverse();
+
+ for (i, up_block) in self.up_blocks.iter().enumerate() {
+ let skip = match &self.effnet_mappers[self.down_blocks.len() + i] {
+ None => None,
+ Some(m) => {
+ let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?;
+ Some(m.forward(&effnet)?)
+ }
+ };
+ for (j, block) in up_block.sub_blocks.iter().enumerate() {
+ let skip = if j == 0 && i > 0 {
+ Some(&level_outputs[i])
+ } else {
+ None
+ };
+ xs = block.res_block.forward(&xs, skip)?;
+ xs = block.ts_block.forward(&xs, &r_embed)?;
+ if let Some(attn_block) = &block.attn_block {
+ xs = attn_block.forward(&xs, clip.as_ref().unwrap())?;
+ }
+ }
+ if let Some(ln) = &up_block.layer_norm {
+ xs = xs.apply(ln)?
+ }
+ if let Some(conv) = &up_block.conv {
+ xs = xs.apply(conv)?
+ }
+ }
+
+ let ab = xs
+ .apply(&self.clf_ln)?
+ .apply(&self.clf_conv)?
+ .apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, self.patch_size))?
+ .chunk(1, 2)?;
let b = ((candle_nn::ops::sigmoid(&ab[1])? * (1. - EPS * 2.))? + EPS)?;
(x_in - &ab[0])? / b
}
diff --git a/candle-transformers/src/models/wuerstchen/mod.rs b/candle-transformers/src/models/wuerstchen/mod.rs
index 81755dd1..435bdac2 100644
--- a/candle-transformers/src/models/wuerstchen/mod.rs
+++ b/candle-transformers/src/models/wuerstchen/mod.rs
@@ -1,3 +1,4 @@
pub mod common;
pub mod diffnext;
+pub mod paella_vq;
pub mod prior;
diff --git a/candle-transformers/src/models/wuerstchen/paella_vq.rs b/candle-transformers/src/models/wuerstchen/paella_vq.rs
new file mode 100644
index 00000000..6301b7a1
--- /dev/null
+++ b/candle-transformers/src/models/wuerstchen/paella_vq.rs
@@ -0,0 +1,111 @@
+#![allow(unused)]
+use super::common::{AttnBlock, ResBlock, TimestepBlock};
+use candle::{DType, Module, Result, Tensor, D};
+use candle_nn::VarBuilder;
+
+/// Residual block made of a depthwise-convolution branch followed by a
+/// channel-wise (linear -> GELU -> linear) branch.
+///
+/// Each branch is preceded by its own layer norm and is modulated by the
+/// scalars stored in `gammas`.
+#[derive(Debug)]
+struct MixingResidualBlock {
+    norm1: candle_nn::LayerNorm,
+    depthwise_conv: candle_nn::Conv2d,
+    norm2: candle_nn::LayerNorm,
+    channelwise_lin1: candle_nn::Linear,
+    channelwise_lin2: candle_nn::Linear,
+    /// The six values read from the `gammas` variable.
+    gammas: Vec<f32>,
+}
+
+impl MixingResidualBlock {
+    /// Builds the block from the variables found under `vb`.
+    ///
+    /// * `inp` - channel count used by the layer norms, the depthwise
+    ///   convolution and the input/output of the channel-wise branch.
+    /// * `embed_dim` - hidden width between the two channel-wise linear layers.
+    ///
+    /// # Errors
+    ///
+    /// Propagates any error raised while fetching a variable from `vb` or
+    /// while converting `gammas` to a `Vec<f32>`.
+    pub fn new(inp: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> {
+        let norm_cfg = candle_nn::LayerNormConfig {
+            affine: false,
+            eps: 1e-6,
+            remove_mean: true,
+        };
+        let norm1 = candle_nn::layer_norm(inp, norm_cfg, vb.pp("norm1"))?;
+        // The second norm has its own prefix; it previously reused "norm1".
+        let norm2 = candle_nn::layer_norm(inp, norm_cfg, vb.pp("norm2"))?;
+        // `groups == inp` gives one filter group per input channel.
+        let conv_cfg = candle_nn::Conv2dConfig {
+            groups: inp,
+            ..Default::default()
+        };
+        let depthwise_conv = candle_nn::conv2d(inp, inp, 3, conv_cfg, vb.pp("depthwise.1"))?;
+        let channelwise_lin1 = candle_nn::linear(inp, embed_dim, vb.pp("channelwise.0"))?;
+        let channelwise_lin2 = candle_nn::linear(embed_dim, inp, vb.pp("channelwise.2"))?;
+        let gammas = vb.get(6, "gammas")?.to_vec1::<f32>()?;
+        Ok(Self {
+            norm1,
+            depthwise_conv,
+            norm2,
+            channelwise_lin1,
+            channelwise_lin2,
+            gammas,
+        })
+    }
+}
+
+impl Module for MixingResidualBlock {
+    /// Applies the depthwise branch and then the channel-wise branch, each
+    /// added back onto its input after being scaled by a `gammas` entry.
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        // Fetch a modulation scalar widened to f64 for the tensor arithmetic.
+        let gamma = |idx: usize| self.gammas[idx] as f64;
+
+        // Axis 1 is moved last for the layer norm, then the layout is restored
+        // before the modulation and the depthwise convolution.
+        let depthwise = xs
+            .permute((0, 2, 3, 1))?
+            .apply(&self.norm1)?
+            .permute((0, 3, 1, 2))?
+            .affine(1. + gamma(0), gamma(1))?
+            .apply(&self.depthwise_conv)?;
+        // TODO: Add the ReplicationPad2d
+        let xs = (xs + depthwise * gamma(2))?;
+
+        // Same norm/modulation pattern, followed by the channel-wise layers
+        // which also run with axis 1 moved last.
+        let channelwise = xs
+            .permute((0, 2, 3, 1))?
+            .apply(&self.norm2)?
+            .permute((0, 3, 1, 2))?
+            .affine(1. + gamma(3), gamma(4))?
+            .permute((0, 2, 3, 1))?
+            .apply(&self.channelwise_lin1)?
+            .gelu()?
+            .apply(&self.channelwise_lin2)?
+            .permute((0, 3, 1, 2))?;
+        xs + channelwise * gamma(5)
+    }
+}
+
+/// Encoder/decoder built from stacks of [`MixingResidualBlock`]s.
+///
+/// `encode` runs the `in_block_conv` / `down_blocks` / `down_blocks_*` path and
+/// `decode` runs the `up_blocks_conv` / `up_blocks` / `out_block_conv` path.
+#[derive(Debug)]
+struct PaellaVQ {
+    in_block_conv: candle_nn::Conv2d,
+    out_block_conv: candle_nn::Conv2d,
+    /// Each entry is an optional convolution applied before its mixing block.
+    down_blocks: Vec<(Option<candle_nn::Conv2d>, MixingResidualBlock)>,
+    down_blocks_conv: candle_nn::Conv2d,
+    down_blocks_bn: candle_nn::BatchNorm,
+    up_blocks_conv: candle_nn::Conv2d,
+    /// Each entry is a mixing block followed by an optional transposed convolution.
+    up_blocks: Vec<(MixingResidualBlock, Option<candle_nn::ConvTranspose2d>)>,
+}
+
+impl PaellaVQ {
+    /// Pixel-unshuffles `xs` by a factor of 2, then runs the input convolution,
+    /// every down block, and the final convolution plus batch norm.
+    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
+        let mut hidden = candle_nn::ops::pixel_unshuffle(xs, 2)?.apply(&self.in_block_conv)?;
+        for (maybe_conv, block) in self.down_blocks.iter() {
+            if let Some(conv) = maybe_conv {
+                hidden = hidden.apply(conv)?
+            }
+            hidden = hidden.apply(block)?
+        }
+        // TODO: quantizer
+        hidden
+            .apply(&self.down_blocks_conv)?
+            .apply(&self.down_blocks_bn)
+    }
+
+    /// Runs the first up convolution, every up block, and the output
+    /// convolution, then pixel-shuffles the result by a factor of 2.
+    pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
+        let mut hidden = xs.apply(&self.up_blocks_conv)?;
+        for (block, maybe_conv) in self.up_blocks.iter() {
+            hidden = hidden.apply(block)?;
+            if let Some(conv) = maybe_conv {
+                hidden = hidden.apply(conv)?
+            }
+        }
+        let projected = hidden.apply(&self.out_block_conv)?;
+        candle_nn::ops::pixel_shuffle(&projected, 2)
+    }
+}
+
+impl Module for PaellaVQ {
+    /// Runs `encode` on `xs` and feeds the result straight into `decode`.
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let encoded = self.encode(xs)?;
+        self.decode(&encoded)
+    }
+}