Diffstat (limited to 'candle-transformers/src')
-rw-r--r--  candle-transformers/src/models/mod.rs                 |   1
-rw-r--r--  candle-transformers/src/models/wuerstchen/common.rs   | 126
-rw-r--r--  candle-transformers/src/models/wuerstchen/diffnext.rs |  55
-rw-r--r--  candle-transformers/src/models/wuerstchen/mod.rs      |   3
-rw-r--r--  candle-transformers/src/models/wuerstchen/prior.rs    |  94
5 files changed, 279 insertions(+), 0 deletions(-)
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index e2e0bf81..a20254d9 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -9,3 +9,4 @@ pub mod segment_anything;
pub mod stable_diffusion;
pub mod t5;
pub mod whisper;
+pub mod wuerstchen;
diff --git a/candle-transformers/src/models/wuerstchen/common.rs b/candle-transformers/src/models/wuerstchen/common.rs
new file mode 100644
index 00000000..fc731a59
--- /dev/null
+++ b/candle-transformers/src/models/wuerstchen/common.rs
@@ -0,0 +1,126 @@
+use candle::{Module, Result, Tensor, D};
+use candle_nn::VarBuilder;
+
+// https://github.com/huggingface/diffusers/blob/19edca82f1ff194c07317369a92b470dbae97f34/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py#L22
+#[derive(Debug)]
+pub struct WLayerNorm {
+ inner: candle_nn::LayerNorm,
+}
+
+impl WLayerNorm {
+ pub fn new(size: usize, vb: VarBuilder) -> Result<Self> {
+ let cfg = candle_nn::layer_norm::LayerNormConfig {
+ eps: 1e-6,
+ remove_mean: true,
+ affine: false,
+ };
+ let inner = candle_nn::layer_norm(size, cfg, vb)?;
+ Ok(Self { inner })
+ }
+}
+
+impl Module for WLayerNorm {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ xs.permute((0, 2, 3, 1))?
+ .apply(&self.inner)?
+ .permute((0, 3, 1, 2))
+ }
+}
+
+#[derive(Debug)]
+pub struct TimestepBlock {
+ mapper: candle_nn::Linear,
+}
+
+impl TimestepBlock {
+ pub fn new(c: usize, c_timestep: usize, vb: VarBuilder) -> Result<Self> {
+ let mapper = candle_nn::linear(c_timestep, c * 2, vb.pp("mapper"))?;
+ Ok(Self { mapper })
+ }
+
+ pub fn forward(&self, xs: &Tensor, t: &Tensor) -> Result<Tensor> {
+ let ab = self
+ .mapper
+ .forward(t)?
+ .unsqueeze(2)?
+ .unsqueeze(3)?
+ .chunk(2, 1)?;
+ xs.broadcast_mul(&(&ab[0] + 1.)?)?.broadcast_add(&ab[1])
+ }
+}
+
+#[derive(Debug)]
+pub struct GlobalResponseNorm {
+ gamma: Tensor,
+ beta: Tensor,
+}
+
+impl GlobalResponseNorm {
+ pub fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
+ let gamma = vb.get((1, 1, 1, dim), "gamma")?;
+ let beta = vb.get((1, 1, 1, dim), "beta")?;
+ Ok(Self { gamma, beta })
+ }
+}
+
+impl Module for GlobalResponseNorm {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let agg_norm = xs.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
+ let stand_div_norm =
+ agg_norm.broadcast_div(&(agg_norm.mean_keepdim(D::Minus1)? + 1e-6)?)?;
+ xs.broadcast_mul(&stand_div_norm)?
+ .broadcast_mul(&self.gamma)?
+ .broadcast_add(&self.beta)?
+ + xs
+ }
+}
+
+#[derive(Debug)]
+pub struct ResBlock {
+ depthwise: candle_nn::Conv2d,
+ norm: WLayerNorm,
+ channelwise_lin1: candle_nn::Linear,
+ channelwise_grn: GlobalResponseNorm,
+ channelwise_lin2: candle_nn::Linear,
+}
+
+impl ResBlock {
+ pub fn new(c: usize, c_skip: usize, ksize: usize, vb: VarBuilder) -> Result<Self> {
+ let cfg = candle_nn::Conv2dConfig {
+ padding: ksize / 2,
+ groups: c,
+ ..Default::default()
+ };
+ let depthwise = candle_nn::conv2d(c + c_skip, c, ksize, cfg, vb.pp("depthwise"))?;
+ let norm = WLayerNorm::new(c, vb.pp("norm"))?;
+ let channelwise_lin1 = candle_nn::linear(c, c * 4, vb.pp("channelwise.0"))?;
+ let channelwise_grn = GlobalResponseNorm::new(c * 4, vb.pp("channelwise.2"))?;
+ let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?;
+ Ok(Self {
+ depthwise,
+ norm,
+ channelwise_lin1,
+ channelwise_grn,
+ channelwise_lin2,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, x_skip: Option<&Tensor>) -> Result<Tensor> {
+ let x_res = xs;
+ let xs = match x_skip {
+ None => xs.clone(),
+ Some(x_skip) => Tensor::cat(&[xs, x_skip], 1)?,
+ };
+ let xs = xs
+ .apply(&self.depthwise)?
+ .apply(&self.norm)?
+ .permute((0, 2, 3, 1))?;
+ let xs = xs
+ .apply(&self.channelwise_lin1)?
+ .gelu()?
+ .apply(&self.channelwise_grn)?
+ .apply(&self.channelwise_lin2)?
+ .permute((0, 3, 1, 2))?;
+ xs + x_res
+ }
+}
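
A minimal usage sketch for the new ResBlock (not part of the commit): it assumes the candle / candle_nn crate names used above, a zero-initialized VarBuilder in place of real safetensors weights, and made-up sizes (c = 8, 16x16 spatial). It only checks that an NCHW tensor keeps its shape through the block.

    use candle::{DType, Device, Result, Tensor};
    use candle_nn::VarBuilder;
    use candle_transformers::models::wuerstchen::common::ResBlock;

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // Zero weights are enough to exercise the shapes; real use loads safetensors.
        let vb = VarBuilder::zeros(DType::F32, &dev);
        let c = 8;
        let block = ResBlock::new(c, /* c_skip */ 0, /* ksize */ 3, vb)?;
        let xs = Tensor::zeros((1, c, 16, 16), DType::F32, &dev)?;
        // Depthwise conv -> WLayerNorm -> channelwise MLP with GRN, plus residual.
        let ys = block.forward(&xs, None)?;
        assert_eq!(ys.dims(), &[1, c, 16, 16]);
        Ok(())
    }
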
diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs
new file mode 100644
index 00000000..82c973a1
--- /dev/null
+++ b/candle-transformers/src/models/wuerstchen/diffnext.rs
@@ -0,0 +1,55 @@
+#![allow(unused)]
+use super::common::{GlobalResponseNorm, ResBlock, TimestepBlock, WLayerNorm};
+use candle::{DType, Module, Result, Tensor, D};
+use candle_nn::VarBuilder;
+
+#[derive(Debug)]
+pub struct ResBlockStageB {
+ depthwise: candle_nn::Conv2d,
+ norm: WLayerNorm,
+ channelwise_lin1: candle_nn::Linear,
+ channelwise_grn: GlobalResponseNorm,
+ channelwise_lin2: candle_nn::Linear,
+}
+
+impl ResBlockStageB {
+ pub fn new(c: usize, c_skip: usize, ksize: usize, vb: VarBuilder) -> Result<Self> {
+ let cfg = candle_nn::Conv2dConfig {
+ groups: c,
+ padding: ksize / 2,
+ ..Default::default()
+ };
+ let depthwise = candle_nn::conv2d(c, c, ksize, cfg, vb.pp("depthwise"))?;
+ let norm = WLayerNorm::new(c, vb.pp("norm"))?;
+ let channelwise_lin1 = candle_nn::linear(c + c_skip, c * 4, vb.pp("channelwise.0"))?;
+ let channelwise_grn = GlobalResponseNorm::new(4 * c, vb.pp("channelwise.2"))?;
+ let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?;
+ Ok(Self {
+ depthwise,
+ norm,
+ channelwise_lin1,
+ channelwise_grn,
+ channelwise_lin2,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, x_skip: Option<&Tensor>) -> Result<Tensor> {
+ let x_res = xs;
+ let xs = xs.apply(&self.depthwise)?.apply(&self.norm)?;
+ let xs = match x_skip {
+ None => xs.clone(),
+ Some(x_skip) => Tensor::cat(&[&xs, x_skip], 1)?,
+ };
+ let xs = xs
+ .permute((0, 2, 3, 1))?
+ .apply(&self.channelwise_lin1)?
+ .gelu()?
+ .apply(&self.channelwise_grn)?
+ .apply(&self.channelwise_lin2)?
+ .permute((0, 3, 1, 2))?;
+ xs + x_res
+ }
+}
+
+#[derive(Debug)]
+pub struct WDiffNeXt {}
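
The stage-B residual block differs from the common ResBlock in where the skip enters: it is concatenated after the depthwise conv, so only the channelwise MLP sees the extra channels. A hypothetical sketch under the same assumptions as above (zero-initialized weights, made-up sizes), not part of the commit:

    use candle::{DType, Device, Result, Tensor};
    use candle_nn::VarBuilder;
    use candle_transformers::models::wuerstchen::diffnext::ResBlockStageB;

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let vb = VarBuilder::zeros(DType::F32, &dev);
        let (c, c_skip) = (8, 4);
        let block = ResBlockStageB::new(c, c_skip, 3, vb)?;
        let xs = Tensor::zeros((1, c, 16, 16), DType::F32, &dev)?;
        // The skip only has to match the spatial dims; its channels are absorbed
        // by the first channelwise linear (c + c_skip -> 4 * c).
        let skip = Tensor::zeros((1, c_skip, 16, 16), DType::F32, &dev)?;
        let ys = block.forward(&xs, Some(&skip))?;
        assert_eq!(ys.dims(), &[1, c, 16, 16]);
        Ok(())
    }
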
diff --git a/candle-transformers/src/models/wuerstchen/mod.rs b/candle-transformers/src/models/wuerstchen/mod.rs
new file mode 100644
index 00000000..81755dd1
--- /dev/null
+++ b/candle-transformers/src/models/wuerstchen/mod.rs
@@ -0,0 +1,3 @@
+pub mod common;
+pub mod diffnext;
+pub mod prior;
diff --git a/candle-transformers/src/models/wuerstchen/prior.rs b/candle-transformers/src/models/wuerstchen/prior.rs
new file mode 100644
index 00000000..a4e0300c
--- /dev/null
+++ b/candle-transformers/src/models/wuerstchen/prior.rs
@@ -0,0 +1,94 @@
+#![allow(unused)]
+use super::common::{ResBlock, TimestepBlock};
+use candle::{DType, Module, Result, Tensor, D};
+use candle_nn::VarBuilder;
+
+#[derive(Debug)]
+struct Block {
+ res_block: ResBlock,
+ ts_block: TimestepBlock,
+ // TODO: attn_block: super::common::AttnBlock,
+}
+
+#[derive(Debug)]
+pub struct WPrior {
+ projection: candle_nn::Conv2d,
+ cond_mapper_lin1: candle_nn::Linear,
+ cond_mapper_lin2: candle_nn::Linear,
+ blocks: Vec<Block>,
+ out_ln: super::common::WLayerNorm,
+ out_conv: candle_nn::Conv2d,
+ c_r: usize,
+}
+
+impl WPrior {
+ pub fn new(
+ c_in: usize,
+ c: usize,
+ c_cond: usize,
+ c_r: usize,
+ depth: usize,
+ _nhead: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let projection = candle_nn::conv2d(c_in, c, 1, Default::default(), vb.pp("projection"))?;
+ let cond_mapper_lin1 = candle_nn::linear(c_cond, c, vb.pp("cond_mapper.0"))?;
+ let cond_mapper_lin2 = candle_nn::linear(c, c, vb.pp("cond_mapper.2"))?;
+ let out_ln = super::common::WLayerNorm::new(c, vb.pp("out.0"))?;
+ let out_conv = candle_nn::conv2d(c, c_in * 2, 1, Default::default(), vb.pp("out.1"))?;
+ let mut blocks = Vec::with_capacity(depth);
+ for index in 0..depth {
+ let res_block = ResBlock::new(c, 0, 3, vb.pp(format!("blocks.{}", 3 * index)))?;
+ let ts_block = TimestepBlock::new(c, c_r, vb.pp(format!("blocks.{}", 3 * index + 1)))?;
+ blocks.push(Block {
+ res_block,
+ ts_block,
+ })
+ }
+ Ok(Self {
+ projection,
+ cond_mapper_lin1,
+ cond_mapper_lin2,
+ blocks,
+ out_ln,
+ out_conv,
+ c_r,
+ })
+ }
+
+ pub fn gen_r_embedding(&self, r: &Tensor) -> Result<Tensor> {
+ const MAX_POSITIONS: usize = 10000;
+ let r = (r * MAX_POSITIONS as f64)?;
+ let half_dim = self.c_r / 2;
+ let emb = (MAX_POSITIONS as f64).ln() / (half_dim - 1) as f64;
+ let emb = (Tensor::arange(0u32, half_dim as u32, r.device())?.to_dtype(DType::F32)?
+ * -emb)?
+ .exp()?;
+ let emb = r.unsqueeze(1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
+ let emb = Tensor::cat(&[emb.sin()?, emb.cos()?], 1)?;
+ let emb = if self.c_r % 2 == 1 {
+ emb.pad_with_zeros(D::Minus1, 0, 1)?
+ } else {
+ emb
+ };
+ emb.to_dtype(r.dtype())
+ }
+
+ pub fn forward(&self, xs: &Tensor, r: &Tensor, c: &Tensor) -> Result<Tensor> {
+ let x_in = xs;
+ let mut xs = xs.apply(&self.projection)?;
+ // TODO: leaky relu
+ let c_embed = c
+ .apply(&self.cond_mapper_lin1)?
+ .relu()?
+ .apply(&self.cond_mapper_lin2)?;
+ let r_embed = self.gen_r_embedding(r)?;
+ for block in self.blocks.iter() {
+ xs = block.res_block.forward(&xs, None)?;
+ xs = block.ts_block.forward(&xs, &r_embed)?;
+ // TODO: attn
+ }
+ let ab = xs.apply(&self.out_ln)?.apply(&self.out_conv)?.chunk(2, 1)?;
+ (x_in - &ab[0])? / ((&ab[1] - 1.)?.abs()? + 1e-5)
+ }
+}
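
Finally, a hypothetical end-to-end sketch for WPrior (again not part of the commit, with zero-initialized weights and deliberately tiny made-up sizes rather than the real Wuerstchen prior configuration). It exercises the sinusoidal ratio embedding and the output transform (x - a) / (|b - 1| + 1e-5):

    use candle::{DType, Device, Result, Tensor};
    use candle_nn::VarBuilder;
    use candle_transformers::models::wuerstchen::prior::WPrior;

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let vb = VarBuilder::zeros(DType::F32, &dev);
        // Tiny, made-up hyper-parameters; only the shapes matter here.
        let (c_in, c, c_cond, c_r, depth, nhead) = (16, 64, 32, 64, 2, 1);
        let prior = WPrior::new(c_in, c, c_cond, c_r, depth, nhead, vb)?;
        // Noise ratio in [0, 1]; its embedding has c_r features per batch element.
        let r = Tensor::new(&[0.5f32], &dev)?;
        assert_eq!(prior.gen_r_embedding(&r)?.dims(), &[1, c_r]);
        // Latents and text conditioning (the latter is unused while attention is a TODO).
        let xs = Tensor::zeros((1, c_in, 12, 12), DType::F32, &dev)?;
        let cond = Tensor::zeros((1, 77, c_cond), DType::F32, &dev)?;
        // The prior predicts a tensor with the same shape as its latent input.
        let pred = prior.forward(&xs, &r, &cond)?;
        assert_eq!(pred.dims(), &[1, c_in, 12, 12]);
        Ok(())
    }
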