author     Laurent Mazare <laurent.mazare@gmail.com>  2023-08-06 18:49:43 +0200
committer  GitHub <noreply@github.com>  2023-08-06 17:49:43 +0100
commit     d34039e35267b3f4de83770f8da4ea31491bcec5 (patch)
tree       6efc4796859a04223d2211303cc12b3a97321fcc /candle-examples/examples
parent     93cfe5642f473889d1df62ccb8f1740f77523dd3 (diff)
Add a stable diffusion example (#328)
* Start adding a stable-diffusion example.
* Proper computation of the causal mask (a standalone sketch follows this list).
* Add the chunk operation.
* Work in progress: port the attention module.
* Add some dummy modules for conv2d and group-norm, get the attention module to compile.
* Re-enable the 2d convolution.
* Add the embeddings module.
* Add the resnet module.
* Add the unet blocks.
* Add the unet.
* Add the variational auto-encoder.
* Use the pad function from utils.
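
Note on the causal mask (a standalone sketch, not part of the patch): the CLIP text encoder in clip.rs below builds the mask eagerly as a (seq_len, seq_len) tensor and broadcasts it over the batch before adding it to the raw attention scores. The patch stores the mask entries as 0/1 values; the common additive convention sketched here instead uses f32::MIN for future positions so that they vanish after the softmax. Treat that constant as an assumption, not the patch's exact code.

use candle::{Device, Result, Tensor};

// Build an additive causal mask: 0 where position j may be attended to
// (j <= i), a large negative value where it may not (j > i).
fn causal_mask(bsz: usize, seq_len: usize, device: &Device) -> Result<Tensor> {
    let mask: Vec<f32> = (0..seq_len)
        .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::MIN } else { 0. }))
        .collect();
    let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
    // Broadcast to one mask per batch element, matching the attention scores.
    mask.broadcast_as((bsz, seq_len, seq_len))
}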
Diffstat (limited to 'candle-examples/examples')
-rw-r--r--  candle-examples/examples/stable-diffusion/attention.rs      |  445
-rw-r--r--  candle-examples/examples/stable-diffusion/clip.rs           |  304
-rw-r--r--  candle-examples/examples/stable-diffusion/embeddings.rs     |   65
-rw-r--r--  candle-examples/examples/stable-diffusion/main.rs           |   30
-rw-r--r--  candle-examples/examples/stable-diffusion/resnet.rs         |  129
-rw-r--r--  candle-examples/examples/stable-diffusion/unet_2d.rs        |  383
-rw-r--r--  candle-examples/examples/stable-diffusion/unet_2d_blocks.rs |  809
-rw-r--r--  candle-examples/examples/stable-diffusion/utils.rs          |   17
-rw-r--r--  candle-examples/examples/stable-diffusion/vae.rs            |  378
9 files changed, 2560 insertions(+), 0 deletions(-)
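
One piece worth seeing in isolation before the per-file diffs: embeddings.rs implements the usual sinusoidal timestep embedding (Timesteps::forward), which unet_2d.rs projects through TimestepEmbedding and hands to every ResNet block. A plain-f32 sketch of the same computation, assuming the default freq_shift of 0 and an even channel count (the tensor version in the patch also pads the odd case):

// Geometric frequency ladder from 1 down to ~1/10000, multiplied by the
// timestep; the cos/sin halves are concatenated (cos first when
// flip_sin_to_cos is set, as in the default UNet config).
fn timestep_embedding(t: f32, num_channels: usize, flip_sin_to_cos: bool) -> Vec<f32> {
    let half_dim = num_channels / 2;
    let freqs: Vec<f32> = (0..half_dim)
        .map(|i| (-(10000f32).ln() * i as f32 / half_dim as f32).exp())
        .collect();
    let (sin, cos): (Vec<f32>, Vec<f32>) =
        freqs.iter().map(|f| ((t * f).sin(), (t * f).cos())).unzip();
    if flip_sin_to_cos {
        cos.into_iter().chain(sin).collect()
    } else {
        sin.into_iter().chain(cos).collect()
    }
}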
diff --git a/candle-examples/examples/stable-diffusion/attention.rs b/candle-examples/examples/stable-diffusion/attention.rs new file mode 100644 index 00000000..83e7ef34 --- /dev/null +++ b/candle-examples/examples/stable-diffusion/attention.rs @@ -0,0 +1,445 @@ +#![allow(dead_code)] +//! Attention Based Building Blocks +use candle::{IndexOp, Result, Tensor, D}; +use candle_nn as nn; + +#[derive(Debug)] +struct GeGlu { + proj: nn::Linear, +} + +impl GeGlu { + fn new(vs: nn::VarBuilder, dim_in: usize, dim_out: usize) -> Result<Self> { + let proj = nn::linear(dim_in, dim_out * 2, vs.pp("proj"))?; + Ok(Self { proj }) + } +} + +impl GeGlu { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?; + &hidden_states_and_gate[0] * hidden_states_and_gate[1].gelu()? + } +} + +/// A feed-forward layer. +#[derive(Debug)] +struct FeedForward { + project_in: GeGlu, + linear: nn::Linear, +} + +impl FeedForward { + // The glu parameter in the python code is unused? + // https://github.com/huggingface/diffusers/blob/d3d22ce5a894becb951eec03e663951b28d45135/src/diffusers/models/attention.py#L347 + /// Creates a new feed-forward layer based on some given input dimension, some + /// output dimension, and a multiplier to be used for the intermediary layer. + fn new(vs: nn::VarBuilder, dim: usize, dim_out: Option<usize>, mult: usize) -> Result<Self> { + let inner_dim = dim * mult; + let dim_out = dim_out.unwrap_or(dim); + let vs = vs.pp("net"); + let project_in = GeGlu::new(vs.pp("0"), dim, inner_dim)?; + let linear = nn::linear(inner_dim, dim_out, vs.pp("2"))?; + Ok(Self { project_in, linear }) + } +} + +impl FeedForward { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let xs = self.project_in.forward(xs)?; + self.linear.forward(&xs) + } +} + +#[derive(Debug)] +struct CrossAttention { + to_q: nn::Linear, + to_k: nn::Linear, + to_v: nn::Linear, + to_out: nn::Linear, + heads: usize, + scale: f64, + slice_size: Option<usize>, +} + +impl CrossAttention { + // Defaults should be heads = 8, dim_head = 64, context_dim = None + fn new( + vs: nn::VarBuilder, + query_dim: usize, + context_dim: Option<usize>, + heads: usize, + dim_head: usize, + slice_size: Option<usize>, + ) -> Result<Self> { + let inner_dim = dim_head * heads; + let context_dim = context_dim.unwrap_or(query_dim); + let scale = 1.0 / f64::sqrt(dim_head as f64); + let to_q = nn::linear_no_bias(query_dim, inner_dim, vs.pp("to_q"))?; + let to_k = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_k"))?; + let to_v = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_v"))?; + let to_out = nn::linear(inner_dim, query_dim, vs.pp("to_out.0"))?; + Ok(Self { + to_q, + to_k, + to_v, + to_out, + heads, + scale, + slice_size, + }) + } + + fn reshape_heads_to_batch_dim(&self, xs: &Tensor) -> Result<Tensor> { + let (batch_size, seq_len, dim) = xs.dims3()?; + xs.reshape((batch_size, seq_len, self.heads, dim / self.heads))? + .transpose(1, 2)? + .reshape((batch_size * self.heads, seq_len, dim / self.heads)) + } + + fn reshape_batch_dim_to_heads(&self, xs: &Tensor) -> Result<Tensor> { + let (batch_size, seq_len, dim) = xs.dims3()?; + xs.reshape((batch_size / self.heads, self.heads, seq_len, dim))? + .transpose(1, 2)? 
+ .reshape((batch_size / self.heads, seq_len, dim * self.heads)) + } + + fn sliced_attention( + &self, + query: &Tensor, + key: &Tensor, + value: &Tensor, + slice_size: usize, + ) -> Result<Tensor> { + let batch_size_attention = query.dim(0)?; + let mut hidden_states = Vec::with_capacity(batch_size_attention / slice_size); + + for i in 0..batch_size_attention / slice_size { + let start_idx = i * slice_size; + let end_idx = (i + 1) * slice_size; + + let xs = query + .i(start_idx..end_idx)? + .matmul(&(key.i(start_idx..end_idx)?.t()? * self.scale)?)?; + let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(&value.i(start_idx..end_idx)?)?; + hidden_states.push(xs) + } + let hidden_states = Tensor::stack(&hidden_states, 0)?; + self.reshape_batch_dim_to_heads(&hidden_states) + } + + fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> { + let xs = query.matmul(&(key.transpose(D::Minus1, D::Minus2)? * self.scale)?)?; + let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(value)?; + self.reshape_batch_dim_to_heads(&xs) + } + + fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> { + let query = self.to_q.forward(xs)?; + let context = context.unwrap_or(xs); + let key = self.to_k.forward(context)?; + let value = self.to_v.forward(context)?; + let query = self.reshape_heads_to_batch_dim(&query)?; + let key = self.reshape_heads_to_batch_dim(&key)?; + let value = self.reshape_heads_to_batch_dim(&value)?; + let xs = match self.slice_size { + None => self.attention(&query, &key, &value)?, + Some(slice_size) => { + if query.dim(0)? / slice_size <= 1 { + self.attention(&query, &key, &value)? + } else { + self.sliced_attention(&query, &key, &value, slice_size)? + } + } + }; + self.to_out.forward(&xs) + } +} + +/// A basic Transformer block. +#[derive(Debug)] +struct BasicTransformerBlock { + attn1: CrossAttention, + ff: FeedForward, + attn2: CrossAttention, + norm1: nn::LayerNorm, + norm2: nn::LayerNorm, + norm3: nn::LayerNorm, +} + +impl BasicTransformerBlock { + fn new( + vs: nn::VarBuilder, + dim: usize, + n_heads: usize, + d_head: usize, + context_dim: Option<usize>, + sliced_attention_size: Option<usize>, + ) -> Result<Self> { + let attn1 = CrossAttention::new( + vs.pp("attn1"), + dim, + None, + n_heads, + d_head, + sliced_attention_size, + )?; + let ff = FeedForward::new(vs.pp("ff"), dim, None, 4)?; + let attn2 = CrossAttention::new( + vs.pp("attn2"), + dim, + context_dim, + n_heads, + d_head, + sliced_attention_size, + )?; + let norm1 = nn::layer_norm(dim, 1e-5, vs.pp("norm1"))?; + let norm2 = nn::layer_norm(dim, 1e-5, vs.pp("norm2"))?; + let norm3 = nn::layer_norm(dim, 1e-5, vs.pp("norm3"))?; + Ok(Self { + attn1, + ff, + attn2, + norm1, + norm2, + norm3, + }) + } + + fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> { + let xs = (self.attn1.forward(&self.norm1.forward(xs)?, None)? + xs)?; + let xs = (self.attn2.forward(&self.norm2.forward(&xs)?, context)? + xs)?; + self.ff.forward(&self.norm3.forward(&xs)?)? 
+ xs + } +} + +#[derive(Debug, Clone, Copy)] +pub struct SpatialTransformerConfig { + pub depth: usize, + pub num_groups: usize, + pub context_dim: Option<usize>, + pub sliced_attention_size: Option<usize>, + pub use_linear_projection: bool, +} + +impl Default for SpatialTransformerConfig { + fn default() -> Self { + Self { + depth: 1, + num_groups: 32, + context_dim: None, + sliced_attention_size: None, + use_linear_projection: false, + } + } +} + +#[derive(Debug)] +enum Proj { + Conv2d(nn::Conv2d), + Linear(nn::Linear), +} + +// Aka Transformer2DModel +#[derive(Debug)] +pub struct SpatialTransformer { + norm: nn::GroupNorm, + proj_in: Proj, + transformer_blocks: Vec<BasicTransformerBlock>, + proj_out: Proj, + pub config: SpatialTransformerConfig, +} + +impl SpatialTransformer { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + n_heads: usize, + d_head: usize, + config: SpatialTransformerConfig, + ) -> Result<Self> { + let inner_dim = n_heads * d_head; + let norm = nn::group_norm(config.num_groups, in_channels, 1e-6, vs.pp("norm"))?; + let proj_in = if config.use_linear_projection { + Proj::Linear(nn::linear(in_channels, inner_dim, vs.pp("proj_in"))?) + } else { + Proj::Conv2d(nn::conv2d( + in_channels, + inner_dim, + 1, + Default::default(), + vs.pp("proj_in"), + )?) + }; + let mut transformer_blocks = vec![]; + let vs_tb = vs.pp("transformer_blocks"); + for index in 0..config.depth { + let tb = BasicTransformerBlock::new( + vs_tb.pp(&index.to_string()), + inner_dim, + n_heads, + d_head, + config.context_dim, + config.sliced_attention_size, + )?; + transformer_blocks.push(tb) + } + let proj_out = if config.use_linear_projection { + Proj::Linear(nn::linear(in_channels, inner_dim, vs.pp("proj_out"))?) + } else { + Proj::Conv2d(nn::conv2d( + inner_dim, + in_channels, + 1, + Default::default(), + vs.pp("proj_out"), + )?) + }; + Ok(Self { + norm, + proj_in, + transformer_blocks, + proj_out, + config, + }) + } + + pub fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> { + let (batch, _channel, height, weight) = xs.dims4()?; + let residual = xs; + let xs = self.norm.forward(xs)?; + let (inner_dim, xs) = match &self.proj_in { + Proj::Conv2d(p) => { + let xs = p.forward(&xs)?; + let inner_dim = xs.dim(1)?; + let xs = xs + .transpose(1, 2)? + .t()? + .reshape((batch, height * weight, inner_dim))?; + (inner_dim, xs) + } + Proj::Linear(p) => { + let inner_dim = xs.dim(1)?; + let xs = xs + .transpose(1, 2)? + .t()? + .reshape((batch, height * weight, inner_dim))?; + (inner_dim, p.forward(&xs)?) + } + }; + let mut xs = xs; + for block in self.transformer_blocks.iter() { + xs = block.forward(&xs, context)? + } + let xs = match &self.proj_out { + Proj::Conv2d(p) => p.forward( + &xs.reshape((batch, height, weight, inner_dim))? + .t()? + .transpose(1, 2)?, + )?, + Proj::Linear(p) => p + .forward(&xs)? + .reshape((batch, height, weight, inner_dim))? + .t()? + .transpose(1, 2)?, + }; + xs + residual + } +} + +/// Configuration for an attention block. 
+#[derive(Debug, Clone, Copy)] +pub struct AttentionBlockConfig { + pub num_head_channels: Option<usize>, + pub num_groups: usize, + pub rescale_output_factor: f64, + pub eps: f64, +} + +impl Default for AttentionBlockConfig { + fn default() -> Self { + Self { + num_head_channels: None, + num_groups: 32, + rescale_output_factor: 1., + eps: 1e-5, + } + } +} + +#[derive(Debug)] +pub struct AttentionBlock { + group_norm: nn::GroupNorm, + query: nn::Linear, + key: nn::Linear, + value: nn::Linear, + proj_attn: nn::Linear, + channels: usize, + num_heads: usize, + config: AttentionBlockConfig, +} + +impl AttentionBlock { + pub fn new(vs: nn::VarBuilder, channels: usize, config: AttentionBlockConfig) -> Result<Self> { + let num_head_channels = config.num_head_channels.unwrap_or(channels); + let num_heads = channels / num_head_channels; + let group_norm = + nn::group_norm(config.num_groups, channels, config.eps, vs.pp("group_norm"))?; + let query = nn::linear(channels, channels, vs.pp("query"))?; + let key = nn::linear(channels, channels, vs.pp("key"))?; + let value = nn::linear(channels, channels, vs.pp("value"))?; + let proj_attn = nn::linear(channels, channels, vs.pp("proj_attn"))?; + Ok(Self { + group_norm, + query, + key, + value, + proj_attn, + channels, + num_heads, + config, + }) + } + + fn transpose_for_scores(&self, xs: Tensor) -> Result<Tensor> { + let (batch, t, h_times_d) = xs.dims3()?; + xs.reshape((batch, t, self.num_heads, h_times_d / self.num_heads))? + .transpose(1, 2) + } +} + +impl AttentionBlock { + pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let residual = xs; + let (batch, channel, height, width) = xs.dims4()?; + let xs = self + .group_norm + .forward(xs)? + .reshape((batch, channel, height * width))? + .transpose(1, 2)?; + + let query_proj = self.query.forward(&xs)?; + let key_proj = self.key.forward(&xs)?; + let value_proj = self.value.forward(&xs)?; + + let query_states = self.transpose_for_scores(query_proj)?; + let key_states = self.transpose_for_scores(key_proj)?; + let value_states = self.transpose_for_scores(value_proj)?; + + let scale = f64::powf((self.channels as f64) / (self.num_heads as f64), -0.25); + let attention_scores = + // TODO: Check that this needs two multiplication by `scale`. + (query_states * scale)?.matmul(&(key_states.t()? * scale)?)?; + let attention_probs = nn::ops::softmax(&attention_scores, D::Minus1)?; + + let xs = attention_probs.matmul(&value_states)?; + let xs = xs.transpose(1, 2)?.contiguous()?; + let xs = xs.flatten_from(D::Minus2)?; + let xs = self + .proj_attn + .forward(&xs)? + .t()? + .reshape((batch, channel, height, width))?; + (xs + residual)? / self.config.rescale_output_factor + } +} diff --git a/candle-examples/examples/stable-diffusion/clip.rs b/candle-examples/examples/stable-diffusion/clip.rs new file mode 100644 index 00000000..86313c52 --- /dev/null +++ b/candle-examples/examples/stable-diffusion/clip.rs @@ -0,0 +1,304 @@ +#![allow(dead_code)] +//! Contrastive Language-Image Pre-Training +//! +//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on +//! pairs of images with related texts. +//! +//! 
https://github.com/openai/CLIP +use candle::{Device, Result, Tensor, D}; + +#[derive(Debug, Clone, Copy)] +pub enum Activation { + QuickGelu, + Gelu, +} + +impl Activation { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + match self { + Activation::QuickGelu => xs * crate::utils::sigmoid(&(xs * 1.702f64)?)?, + Activation::Gelu => xs.gelu(), + } + } +} + +#[derive(Debug, Clone)] +pub struct Config { + vocab_size: usize, + embed_dim: usize, // aka config.hidden_size + activation: Activation, // aka config.hidden_act + intermediate_size: usize, + max_position_embeddings: usize, + // The character to use for padding, use EOS when not set. + pad_with: Option<String>, + num_hidden_layers: usize, + num_attention_heads: usize, + #[allow(dead_code)] + projection_dim: usize, +} + +impl Config { + // The config details can be found in the "text_config" section of this json file: + // https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json + pub fn v1_5() -> Self { + Self { + vocab_size: 49408, + embed_dim: 768, + intermediate_size: 3072, + max_position_embeddings: 77, + pad_with: None, + num_hidden_layers: 12, + num_attention_heads: 12, + projection_dim: 768, + activation: Activation::QuickGelu, + } + } + + // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/text_encoder/config.json + pub fn v2_1() -> Self { + Self { + vocab_size: 49408, + embed_dim: 1024, + intermediate_size: 4096, + max_position_embeddings: 77, + pad_with: Some("!".to_string()), + num_hidden_layers: 23, + num_attention_heads: 16, + projection_dim: 512, + activation: Activation::Gelu, + } + } +} + +// CLIP Text Model +// https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py +#[derive(Debug)] +struct ClipTextEmbeddings { + token_embedding: candle_nn::Embedding, + position_embedding: candle_nn::Embedding, + position_ids: Tensor, +} + +impl ClipTextEmbeddings { + fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> { + let token_embedding = + candle_nn::embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?; + let position_embedding = candle_nn::embedding( + c.max_position_embeddings, + c.embed_dim, + vs.pp("position_embedding"), + )?; + let position_ids = + Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(1)?; + Ok(ClipTextEmbeddings { + token_embedding, + position_embedding, + position_ids, + }) + } +} + +impl ClipTextEmbeddings { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let token_embedding = self.token_embedding.forward(xs)?; + let position_embedding = self.position_embedding.forward(&self.position_ids)?; + token_embedding + position_embedding + } +} + +#[derive(Debug)] +struct ClipAttention { + k_proj: candle_nn::Linear, + v_proj: candle_nn::Linear, + q_proj: candle_nn::Linear, + out_proj: candle_nn::Linear, + head_dim: usize, + scale: f64, + num_attention_heads: usize, +} + +impl ClipAttention { + fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> { + let embed_dim = c.embed_dim; + let num_attention_heads = c.num_attention_heads; + let k_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("k_proj"))?; + let v_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("v_proj"))?; + let q_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("q_proj"))?; + let out_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?; + let head_dim = embed_dim / num_attention_heads; + let scale = (head_dim as f64).powf(-0.5); + Ok(ClipAttention { 
+ k_proj, + v_proj, + q_proj, + out_proj, + head_dim, + scale, + num_attention_heads, + }) + } + + fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result<Tensor> { + xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))? + .transpose(1, 2)? + .contiguous() + } + + fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> { + let (bsz, seq_len, embed_dim) = xs.dims3()?; + let query_states = (self.q_proj.forward(xs)? * self.scale)?; + let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim); + let query_states = self + .shape(&query_states, seq_len, bsz)? + .reshape(proj_shape)?; + let key_states = self + .shape(&self.k_proj.forward(xs)?, seq_len, bsz)? + .reshape(proj_shape)?; + let value_states = self + .shape(&self.v_proj.forward(xs)?, seq_len, bsz)? + .reshape(proj_shape)?; + let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?; + + let src_len = key_states.dim(1)?; + let attn_weights = + (attn_weights.reshape((bsz, self.num_attention_heads, seq_len, src_len))? + + causal_attention_mask)?; + let attn_weights = + attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?; + let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?; + + let attn_output = attn_weights.matmul(&value_states)?; + let attn_output = attn_output + .reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))? + .transpose(1, 2)? + .reshape((bsz, seq_len, embed_dim))?; + self.out_proj.forward(&attn_output) + } +} + +#[derive(Debug)] +struct ClipMlp { + fc1: candle_nn::Linear, + fc2: candle_nn::Linear, + activation: Activation, +} + +impl ClipMlp { + fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> { + let fc1 = candle_nn::linear(c.embed_dim, c.intermediate_size, vs.pp("fc1"))?; + let fc2 = candle_nn::linear(c.intermediate_size, c.embed_dim, vs.pp("fc2"))?; + Ok(ClipMlp { + fc1, + fc2, + activation: c.activation, + }) + } +} + +impl ClipMlp { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let xs = self.fc1.forward(xs)?; + self.fc2.forward(&self.activation.forward(&xs)?) 
+ } +} + +#[derive(Debug)] +struct ClipEncoderLayer { + self_attn: ClipAttention, + layer_norm1: candle_nn::LayerNorm, + mlp: ClipMlp, + layer_norm2: candle_nn::LayerNorm, +} + +impl ClipEncoderLayer { + fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> { + let self_attn = ClipAttention::new(vs.pp("self_attn"), c)?; + let layer_norm1 = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("layer_norm1"))?; + let mlp = ClipMlp::new(vs.pp("mlp"), c)?; + let layer_norm2 = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("layer_norm2"))?; + Ok(ClipEncoderLayer { + self_attn, + layer_norm1, + mlp, + layer_norm2, + }) + } + + fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> { + let residual = xs; + let xs = self.layer_norm1.forward(xs)?; + let xs = self.self_attn.forward(&xs, causal_attention_mask)?; + let xs = (xs + residual)?; + + let residual = &xs; + let xs = self.layer_norm2.forward(&xs)?; + let xs = self.mlp.forward(&xs)?; + xs + residual + } +} + +#[derive(Debug)] +struct ClipEncoder { + layers: Vec<ClipEncoderLayer>, +} + +impl ClipEncoder { + fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> { + let vs = vs.pp("layers"); + let mut layers: Vec<ClipEncoderLayer> = Vec::new(); + for index in 0..c.num_hidden_layers { + let layer = ClipEncoderLayer::new(vs.pp(&index.to_string()), c)?; + layers.push(layer) + } + Ok(ClipEncoder { layers }) + } + + fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> { + let mut xs = xs.clone(); + for layer in self.layers.iter() { + xs = layer.forward(&xs, causal_attention_mask)? + } + Ok(xs) + } +} + +/// A CLIP transformer based model. +#[derive(Debug)] +pub struct ClipTextTransformer { + embeddings: ClipTextEmbeddings, + encoder: ClipEncoder, + final_layer_norm: candle_nn::LayerNorm, +} + +impl ClipTextTransformer { + pub fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> { + let vs = vs.pp("text_model"); + let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?; + let encoder = ClipEncoder::new(vs.pp("encoder"), c)?; + let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("final_layer_norm"))?; + Ok(ClipTextTransformer { + embeddings, + encoder, + final_layer_norm, + }) + } + + // https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678 + fn build_causal_attention_mask(bsz: usize, seq_len: usize, device: &Device) -> Result<Tensor> { + let mask: Vec<_> = (0..seq_len) + .flat_map(|i| (0..seq_len).map(move |j| u8::from(j > i))) + .collect(); + let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?; + mask.broadcast_as((bsz, seq_len, seq_len)) + } +} + +impl ClipTextTransformer { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let (bsz, seq_len) = xs.dims2()?; + let xs = self.embeddings.forward(xs)?; + let causal_attention_mask = Self::build_causal_attention_mask(bsz, seq_len, xs.device())?; + let xs = self.encoder.forward(&xs, &causal_attention_mask)?; + self.final_layer_norm.forward(&xs) + } +} diff --git a/candle-examples/examples/stable-diffusion/embeddings.rs b/candle-examples/examples/stable-diffusion/embeddings.rs new file mode 100644 index 00000000..f8a4f351 --- /dev/null +++ b/candle-examples/examples/stable-diffusion/embeddings.rs @@ -0,0 +1,65 @@ +#![allow(dead_code)] +use candle::{Result, Tensor, D}; +use candle_nn as nn; + +#[derive(Debug)] +pub struct TimestepEmbedding { + linear_1: nn::Linear, + linear_2: nn::Linear, +} + +impl 
TimestepEmbedding { + // act_fn: "silu" + pub fn new(vs: nn::VarBuilder, channel: usize, time_embed_dim: usize) -> Result<Self> { + let linear_1 = nn::linear(channel, time_embed_dim, vs.pp("linear_1"))?; + let linear_2 = nn::linear(time_embed_dim, time_embed_dim, vs.pp("linear_2"))?; + Ok(Self { linear_1, linear_2 }) + } +} + +impl TimestepEmbedding { + pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let xs = nn::ops::silu(&self.linear_1.forward(xs)?)?; + self.linear_2.forward(&xs) + } +} + +#[derive(Debug)] +pub struct Timesteps { + num_channels: usize, + flip_sin_to_cos: bool, + downscale_freq_shift: f64, +} + +impl Timesteps { + pub fn new(num_channels: usize, flip_sin_to_cos: bool, downscale_freq_shift: f64) -> Self { + Self { + num_channels, + flip_sin_to_cos, + downscale_freq_shift, + } + } +} + +impl Timesteps { + pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let half_dim = (self.num_channels / 2) as u32; + let exponent = + (Tensor::arange(0, half_dim, xs.device())?.to_dtype(xs.dtype())? * -f64::ln(10000.))?; + let exponent = (exponent / (half_dim as f64 - self.downscale_freq_shift))?; + let emb = exponent.exp()?; + // emb = timesteps[:, None].float() * emb[None, :] + let emb = (xs.unsqueeze(D::Minus1)? * emb.unsqueeze(0)?)?; + let (cos, sin) = (emb.cos()?, emb.sin()?); + let emb = if self.flip_sin_to_cos { + Tensor::cat(&[&cos, &sin], D::Minus1)? + } else { + Tensor::cat(&[&sin, &cos], D::Minus1)? + }; + if self.num_channels % 2 == 1 { + crate::utils::pad(&emb) // ([0, 1, 0, 0], 'constant', None) + } else { + Ok(emb) + } + } +} diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs new file mode 100644 index 00000000..31848f38 --- /dev/null +++ b/candle-examples/examples/stable-diffusion/main.rs @@ -0,0 +1,30 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +mod attention; +mod clip; +mod embeddings; +mod resnet; +mod unet_2d; +mod unet_2d_blocks; +mod utils; +mod vae; + +use anyhow::Result; +use clap::Parser; + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + #[arg(long)] + prompt: String, +} + +fn main() -> Result<()> { + let _args = Args::parse(); + Ok(()) +} diff --git a/candle-examples/examples/stable-diffusion/resnet.rs b/candle-examples/examples/stable-diffusion/resnet.rs new file mode 100644 index 00000000..b6696083 --- /dev/null +++ b/candle-examples/examples/stable-diffusion/resnet.rs @@ -0,0 +1,129 @@ +#![allow(dead_code)] +//! ResNet Building Blocks +//! +//! Some Residual Network blocks used in UNet models. +//! +//! Denoising Diffusion Implicit Models, K. He and al, 2015. +//! https://arxiv.org/abs/1512.03385 +use candle::{Result, Tensor, D}; +use candle_nn as nn; + +/// Configuration for a ResNet block. +#[derive(Debug, Clone, Copy)] +pub struct ResnetBlock2DConfig { + /// The number of output channels, defaults to the number of input channels. + pub out_channels: Option<usize>, + pub temb_channels: Option<usize>, + /// The number of groups to use in group normalization. + pub groups: usize, + pub groups_out: Option<usize>, + /// The epsilon to be used in the group normalization operations. + pub eps: f64, + /// Whether to use a 2D convolution in the skip connection. When using None, + /// such a convolution is used if the number of input channels is different from + /// the number of output channels. 
+ pub use_in_shortcut: Option<bool>, + // non_linearity: silu + /// The final output is scaled by dividing by this value. + pub output_scale_factor: f64, +} + +impl Default for ResnetBlock2DConfig { + fn default() -> Self { + Self { + out_channels: None, + temb_channels: Some(512), + groups: 32, + groups_out: None, + eps: 1e-6, + use_in_shortcut: None, + output_scale_factor: 1., + } + } +} + +#[derive(Debug)] +pub struct ResnetBlock2D { + norm1: nn::GroupNorm, + conv1: nn::Conv2d, + norm2: nn::GroupNorm, + conv2: nn::Conv2d, + time_emb_proj: Option<nn::Linear>, + conv_shortcut: Option<nn::Conv2d>, + config: ResnetBlock2DConfig, +} + +impl ResnetBlock2D { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + config: ResnetBlock2DConfig, + ) -> Result<Self> { + let out_channels = config.out_channels.unwrap_or(in_channels); + let conv_cfg = nn::Conv2dConfig { + stride: 1, + padding: 1, + }; + let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?; + let conv1 = nn::conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?; + let groups_out = config.groups_out.unwrap_or(config.groups); + let norm2 = nn::group_norm(groups_out, out_channels, config.eps, vs.pp("norm2"))?; + let conv2 = nn::conv2d(out_channels, out_channels, 3, conv_cfg, vs.pp("conv2"))?; + let use_in_shortcut = config + .use_in_shortcut + .unwrap_or(in_channels != out_channels); + let conv_shortcut = if use_in_shortcut { + let conv_cfg = nn::Conv2dConfig { + stride: 1, + padding: 0, + }; + Some(nn::conv2d( + in_channels, + out_channels, + 1, + conv_cfg, + vs.pp("conv_shortcut"), + )?) + } else { + None + }; + let time_emb_proj = match config.temb_channels { + None => None, + Some(temb_channels) => Some(nn::linear( + temb_channels, + out_channels, + vs.pp("time_emb_proj"), + )?), + }; + Ok(Self { + norm1, + conv1, + norm2, + conv2, + time_emb_proj, + config, + conv_shortcut, + }) + } + + pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> { + let shortcut_xs = match &self.conv_shortcut { + Some(conv_shortcut) => conv_shortcut.forward(xs)?, + None => xs.clone(), + }; + let xs = self.norm1.forward(xs)?; + let xs = self.conv1.forward(&nn::ops::silu(&xs)?)?; + let xs = match (temb, &self.time_emb_proj) { + (Some(temb), Some(time_emb_proj)) => time_emb_proj + .forward(&nn::ops::silu(temb)?)? + .unsqueeze(D::Minus1)? + .unsqueeze(D::Minus1)? + .add(&xs)?, + _ => xs, + }; + let xs = self + .conv2 + .forward(&nn::ops::silu(&self.norm2.forward(&xs)?)?)?; + (shortcut_xs + xs)? / self.config.output_scale_factor + } +} diff --git a/candle-examples/examples/stable-diffusion/unet_2d.rs b/candle-examples/examples/stable-diffusion/unet_2d.rs new file mode 100644 index 00000000..8ebd1876 --- /dev/null +++ b/candle-examples/examples/stable-diffusion/unet_2d.rs @@ -0,0 +1,383 @@ +#![allow(dead_code)] +//! 2D UNet Denoising Models +//! +//! The 2D Unet models take as input a noisy sample and the current diffusion +//! timestep and return a denoised version of the input. 
+use crate::embeddings::{TimestepEmbedding, Timesteps}; +use crate::unet_2d_blocks::*; +use candle::{DType, Result, Tensor}; +use candle_nn as nn; + +#[derive(Debug, Clone, Copy)] +pub struct BlockConfig { + pub out_channels: usize, + pub use_cross_attn: bool, + pub attention_head_dim: usize, +} + +#[derive(Debug, Clone)] +pub struct UNet2DConditionModelConfig { + pub center_input_sample: bool, + pub flip_sin_to_cos: bool, + pub freq_shift: f64, + pub blocks: Vec<BlockConfig>, + pub layers_per_block: usize, + pub downsample_padding: usize, + pub mid_block_scale_factor: f64, + pub norm_num_groups: usize, + pub norm_eps: f64, + pub cross_attention_dim: usize, + pub sliced_attention_size: Option<usize>, + pub use_linear_projection: bool, +} + +impl Default for UNet2DConditionModelConfig { + fn default() -> Self { + Self { + center_input_sample: false, + flip_sin_to_cos: true, + freq_shift: 0., + blocks: vec![ + BlockConfig { + out_channels: 320, + use_cross_attn: true, + attention_head_dim: 8, + }, + BlockConfig { + out_channels: 640, + use_cross_attn: true, + attention_head_dim: 8, + }, + BlockConfig { + out_channels: 1280, + use_cross_attn: true, + attention_head_dim: 8, + }, + BlockConfig { + out_channels: 1280, + use_cross_attn: false, + attention_head_dim: 8, + }, + ], + layers_per_block: 2, + downsample_padding: 1, + mid_block_scale_factor: 1., + norm_num_groups: 32, + norm_eps: 1e-5, + cross_attention_dim: 1280, + sliced_attention_size: None, + use_linear_projection: false, + } + } +} + +#[derive(Debug)] +pub(crate) enum UNetDownBlock { + Basic(DownBlock2D), + CrossAttn(CrossAttnDownBlock2D), +} + +#[derive(Debug)] +enum UNetUpBlock { + Basic(UpBlock2D), + CrossAttn(CrossAttnUpBlock2D), +} + +#[derive(Debug)] +pub struct UNet2DConditionModel { + conv_in: nn::Conv2d, + time_proj: Timesteps, + time_embedding: TimestepEmbedding, + down_blocks: Vec<UNetDownBlock>, + mid_block: UNetMidBlock2DCrossAttn, + up_blocks: Vec<UNetUpBlock>, + conv_norm_out: nn::GroupNorm, + conv_out: nn::Conv2d, + config: UNet2DConditionModelConfig, +} + +impl UNet2DConditionModel { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + out_channels: usize, + config: UNet2DConditionModelConfig, + ) -> Result<Self> { + let n_blocks = config.blocks.len(); + let b_channels = config.blocks[0].out_channels; + let bl_channels = config.blocks.last().unwrap().out_channels; + let bl_attention_head_dim = config.blocks.last().unwrap().attention_head_dim; + let time_embed_dim = b_channels * 4; + let conv_cfg = nn::Conv2dConfig { + stride: 1, + padding: 1, + }; + let conv_in = nn::conv2d(in_channels, b_channels, 3, conv_cfg, vs.pp("conv_in"))?; + + let time_proj = Timesteps::new(b_channels, config.flip_sin_to_cos, config.freq_shift); + let time_embedding = + TimestepEmbedding::new(vs.pp("time_embedding"), b_channels, time_embed_dim)?; + + let vs_db = vs.pp("down_blocks"); + let down_blocks = (0..n_blocks) + .map(|i| { + let BlockConfig { + out_channels, + use_cross_attn, + attention_head_dim, + } = config.blocks[i]; + + // Enable automatic attention slicing if the config sliced_attention_size is set to 0. 
+ let sliced_attention_size = match config.sliced_attention_size { + Some(0) => Some(attention_head_dim / 2), + _ => config.sliced_attention_size, + }; + + let in_channels = if i > 0 { + config.blocks[i - 1].out_channels + } else { + b_channels + }; + let db_cfg = DownBlock2DConfig { + num_layers: config.layers_per_block, + resnet_eps: config.norm_eps, + resnet_groups: config.norm_num_groups, + add_downsample: i < n_blocks - 1, + downsample_padding: config.downsample_padding, + ..Default::default() + }; + if use_cross_attn { + let config = CrossAttnDownBlock2DConfig { + downblock: db_cfg, + attn_num_head_channels: attention_head_dim, + cross_attention_dim: config.cross_attention_dim, + sliced_attention_size, + use_linear_projection: config.use_linear_projection, + }; + let block = CrossAttnDownBlock2D::new( + vs_db.pp(&i.to_string()), + in_channels, + out_channels, + Some(time_embed_dim), + config, + )?; + Ok(UNetDownBlock::CrossAttn(block)) + } else { + let block = DownBlock2D::new( + vs_db.pp(&i.to_string()), + in_channels, + out_channels, + Some(time_embed_dim), + db_cfg, + )?; + Ok(UNetDownBlock::Basic(block)) + } + }) + .collect::<Result<Vec<_>>>()?; + + let mid_cfg = UNetMidBlock2DCrossAttnConfig { + resnet_eps: config.norm_eps, + output_scale_factor: config.mid_block_scale_factor, + cross_attn_dim: config.cross_attention_dim, + attn_num_head_channels: bl_attention_head_dim, + resnet_groups: Some(config.norm_num_groups), + use_linear_projection: config.use_linear_projection, + ..Default::default() + }; + let mid_block = UNetMidBlock2DCrossAttn::new( + vs.pp("mid_block"), + bl_channels, + Some(time_embed_dim), + mid_cfg, + )?; + + let vs_ub = vs.pp("up_blocks"); + let up_blocks = (0..n_blocks) + .map(|i| { + let BlockConfig { + out_channels, + use_cross_attn, + attention_head_dim, + } = config.blocks[n_blocks - 1 - i]; + + // Enable automatic attention slicing if the config sliced_attention_size is set to 0. 
+ let sliced_attention_size = match config.sliced_attention_size { + Some(0) => Some(attention_head_dim / 2), + _ => config.sliced_attention_size, + }; + + let prev_out_channels = if i > 0 { + config.blocks[n_blocks - i].out_channels + } else { + bl_channels + }; + let in_channels = { + let index = if i == n_blocks - 1 { + 0 + } else { + n_blocks - i - 2 + }; + config.blocks[index].out_channels + }; + let ub_cfg = UpBlock2DConfig { + num_layers: config.layers_per_block + 1, + resnet_eps: config.norm_eps, + resnet_groups: config.norm_num_groups, + add_upsample: i < n_blocks - 1, + ..Default::default() + }; + if use_cross_attn { + let config = CrossAttnUpBlock2DConfig { + upblock: ub_cfg, + attn_num_head_channels: attention_head_dim, + cross_attention_dim: config.cross_attention_dim, + sliced_attention_size, + use_linear_projection: config.use_linear_projection, + }; + let block = CrossAttnUpBlock2D::new( + vs_ub.pp(&i.to_string()), + in_channels, + prev_out_channels, + out_channels, + Some(time_embed_dim), + config, + )?; + Ok(UNetUpBlock::CrossAttn(block)) + } else { + let block = UpBlock2D::new( + vs_ub.pp(&i.to_string()), + in_channels, + prev_out_channels, + out_channels, + Some(time_embed_dim), + ub_cfg, + )?; + Ok(UNetUpBlock::Basic(block)) + } + }) + .collect::<Result<Vec<_>>>()?; + + let conv_norm_out = nn::group_norm( + config.norm_num_groups, + b_channels, + config.norm_eps, + vs.pp("conv_norm_out"), + )?; + let conv_out = nn::conv2d(b_channels, out_channels, 3, conv_cfg, vs.pp("conv_out"))?; + Ok(Self { + conv_in, + time_proj, + time_embedding, + down_blocks, + mid_block, + up_blocks, + conv_norm_out, + conv_out, + config, + }) + } +} + +impl UNet2DConditionModel { + pub fn forward( + &self, + xs: &Tensor, + timestep: f64, + encoder_hidden_states: &Tensor, + ) -> Result<Tensor> { + self.forward_with_additional_residuals(xs, timestep, encoder_hidden_states, None, None) + } + + pub fn forward_with_additional_residuals( + &self, + xs: &Tensor, + timestep: f64, + encoder_hidden_states: &Tensor, + down_block_additional_residuals: Option<&[Tensor]>, + mid_block_additional_residual: Option<&Tensor>, + ) -> Result<Tensor> { + let (bsize, _channels, height, width) = xs.dims4()?; + let device = xs.device(); + let n_blocks = self.config.blocks.len(); + let num_upsamplers = n_blocks - 1; + let default_overall_up_factor = 2usize.pow(num_upsamplers as u32); + let forward_upsample_size = + height % default_overall_up_factor != 0 || width % default_overall_up_factor != 0; + // 0. center input if necessary + let xs = if self.config.center_input_sample { + ((xs * 2.0)? - 1.0)? + } else { + xs.clone() + }; + // 1. time + let emb = (Tensor::ones(bsize, DType::F32, device)? * timestep)?; + let emb = self.time_proj.forward(&emb)?; + let emb = self.time_embedding.forward(&emb)?; + // 2. pre-process + let xs = self.conv_in.forward(&xs)?; + // 3. down + let mut down_block_res_xs = vec![xs.clone()]; + let mut xs = xs; + for down_block in self.down_blocks.iter() { + let (_xs, res_xs) = match down_block { + UNetDownBlock::Basic(b) => b.forward(&xs, Some(&emb))?, + UNetDownBlock::CrossAttn(b) => { + b.forward(&xs, Some(&emb), Some(encoder_hidden_states))? + } + }; + down_block_res_xs.extend(res_xs); + xs = _xs; + } + + let new_down_block_res_xs = + if let Some(down_block_additional_residuals) = down_block_additional_residuals { + let mut v = vec![]; + // A previous version of this code had a bug because of the addition being made + // in place via += hence modifying the input of the mid block. 
+ for (i, residuals) in down_block_additional_residuals.iter().enumerate() { + v.push((&down_block_res_xs[i] + residuals)?) + } + v + } else { + down_block_res_xs + }; + let mut down_block_res_xs = new_down_block_res_xs; + + // 4. mid + let xs = self + .mid_block + .forward(&xs, Some(&emb), Some(encoder_hidden_states))?; + let xs = match mid_block_additional_residual { + None => xs, + Some(m) => (m + xs)?, + }; + // 5. up + let mut xs = xs; + let mut upsample_size = None; + for (i, up_block) in self.up_blocks.iter().enumerate() { + let n_resnets = match up_block { + UNetUpBlock::Basic(b) => b.resnets.len(), + UNetUpBlock::CrossAttn(b) => b.upblock.resnets.len(), + }; + let res_xs = down_block_res_xs.split_off(down_block_res_xs.len() - n_resnets); + if i < n_blocks - 1 && forward_upsample_size { + let (_, _, h, w) = down_block_res_xs.last().unwrap().dims4()?; + upsample_size = Some((h, w)) + } + xs = match up_block { + UNetUpBlock::Basic(b) => b.forward(&xs, &res_xs, Some(&emb), upsample_size)?, + UNetUpBlock::CrossAttn(b) => b.forward( + &xs, + &res_xs, + Some(&emb), + upsample_size, + Some(encoder_hidden_states), + )?, + }; + } + // 6. post-process + let xs = self.conv_norm_out.forward(&xs)?; + let xs = nn::ops::silu(&xs)?; + self.conv_out.forward(&xs) + } +} diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs new file mode 100644 index 00000000..8dd6cf26 --- /dev/null +++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs @@ -0,0 +1,809 @@ +#![allow(dead_code)] +//! 2D UNet Building Blocks +//! +use crate::attention::{ + AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig, +}; +use crate::resnet::{ResnetBlock2D, ResnetBlock2DConfig}; +use candle::{Result, Tensor}; +use candle_nn as nn; + +#[derive(Debug)] +struct Downsample2D { + conv: Option<nn::Conv2d>, + padding: usize, +} + +impl Downsample2D { + fn new( + vs: nn::VarBuilder, + in_channels: usize, + use_conv: bool, + out_channels: usize, + padding: usize, + ) -> Result<Self> { + let conv = if use_conv { + let config = nn::Conv2dConfig { stride: 2, padding }; + let conv = nn::conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?; + Some(conv) + } else { + None + }; + Ok(Downsample2D { conv, padding }) + } +} + +impl Downsample2D { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + match &self.conv { + None => crate::utils::avg_pool2d(xs), // [2, 2], [2, 2], [0, 0], false, true, None), + Some(conv) => { + if self.padding == 0 { + let xs = crate::utils::pad(xs)?; // [0, 1, 0, 1], "constant", Some(0.))?; + conv.forward(&xs) + } else { + conv.forward(xs) + } + } + } + } +} + +// This does not support the conv-transpose mode. +#[derive(Debug)] +struct Upsample2D { + conv: nn::Conv2d, +} + +impl Upsample2D { + fn new(vs: nn::VarBuilder, in_channels: usize, out_channels: usize) -> Result<Self> { + let config = nn::Conv2dConfig { + padding: 1, + ..Default::default() + }; + let conv = nn::conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?; + Ok(Self { conv }) + } +} + +impl Upsample2D { + fn forward(&self, xs: &Tensor, size: Option<(usize, usize)>) -> Result<Tensor> { + let xs = match size { + None => { + // The following does not work and it's tricky to pass no fixed + // dimensions so hack our way around this. + // xs.upsample_nearest2d(&[], Some(2.), Some(2.) + let (_bsize, _channels, _h, _w) = xs.dims4()?; + crate::utils::upsample_nearest2d(xs)? 
// [2 * h, 2 * w], Some(2.), Some(2.)) + } + Some((_h, _w)) => crate::utils::upsample_nearest2d(xs)?, // [h, w], None, None), + }; + self.conv.forward(&xs) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct DownEncoderBlock2DConfig { + pub num_layers: usize, + pub resnet_eps: f64, + pub resnet_groups: usize, + pub output_scale_factor: f64, + pub add_downsample: bool, + pub downsample_padding: usize, +} + +impl Default for DownEncoderBlock2DConfig { + fn default() -> Self { + Self { + num_layers: 1, + resnet_eps: 1e-6, + resnet_groups: 32, + output_scale_factor: 1., + add_downsample: true, + downsample_padding: 1, + } + } +} + +#[derive(Debug)] +pub struct DownEncoderBlock2D { + resnets: Vec<ResnetBlock2D>, + downsampler: Option<Downsample2D>, + pub config: DownEncoderBlock2DConfig, +} + +impl DownEncoderBlock2D { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + out_channels: usize, + config: DownEncoderBlock2DConfig, + ) -> Result<Self> { + let resnets: Vec<_> = { + let vs = vs.pp("resnets"); + let conv_cfg = ResnetBlock2DConfig { + eps: config.resnet_eps, + out_channels: Some(out_channels), + groups: config.resnet_groups, + output_scale_factor: config.output_scale_factor, + temb_channels: None, + ..Default::default() + }; + (0..(config.num_layers)) + .map(|i| { + let in_channels = if i == 0 { in_channels } else { out_channels }; + ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg) + }) + .collect::<Result<Vec<_>>>()? + }; + let downsampler = if config.add_downsample { + let downsample = Downsample2D::new( + vs.pp("downsamplers").pp("0"), + out_channels, + true, + out_channels, + config.downsample_padding, + )?; + Some(downsample) + } else { + None + }; + Ok(Self { + resnets, + downsampler, + config, + }) + } +} + +impl DownEncoderBlock2D { + pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let mut xs = xs.clone(); + for resnet in self.resnets.iter() { + xs = resnet.forward(&xs, None)? + } + match &self.downsampler { + Some(downsampler) => downsampler.forward(&xs), + None => Ok(xs), + } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct UpDecoderBlock2DConfig { + pub num_layers: usize, + pub resnet_eps: f64, + pub resnet_groups: usize, + pub output_scale_factor: f64, + pub add_upsample: bool, +} + +impl Default for UpDecoderBlock2DConfig { + fn default() -> Self { + Self { + num_layers: 1, + resnet_eps: 1e-6, + resnet_groups: 32, + output_scale_factor: 1., + add_upsample: true, + } + } +} + +#[derive(Debug)] +pub struct UpDecoderBlock2D { + resnets: Vec<ResnetBlock2D>, + upsampler: Option<Upsample2D>, + pub config: UpDecoderBlock2DConfig, +} + +impl UpDecoderBlock2D { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + out_channels: usize, + config: UpDecoderBlock2DConfig, + ) -> Result<Self> { + let resnets: Vec<_> = { + let vs = vs.pp("resnets"); + let conv_cfg = ResnetBlock2DConfig { + out_channels: Some(out_channels), + eps: config.resnet_eps, + groups: config.resnet_groups, + output_scale_factor: config.output_scale_factor, + temb_channels: None, + ..Default::default() + }; + (0..(config.num_layers)) + .map(|i| { + let in_channels = if i == 0 { in_channels } else { out_channels }; + ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg) + }) + .collect::<Result<Vec<_>>>()? 
+ }; + let upsampler = if config.add_upsample { + let upsample = + Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?; + Some(upsample) + } else { + None + }; + Ok(Self { + resnets, + upsampler, + config, + }) + } +} + +impl UpDecoderBlock2D { + pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let mut xs = xs.clone(); + for resnet in self.resnets.iter() { + xs = resnet.forward(&xs, None)? + } + match &self.upsampler { + Some(upsampler) => upsampler.forward(&xs, None), + None => Ok(xs), + } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct UNetMidBlock2DConfig { + pub num_layers: usize, + pub resnet_eps: f64, + pub resnet_groups: Option<usize>, + pub attn_num_head_channels: Option<usize>, + // attention_type "default" + pub output_scale_factor: f64, +} + +impl Default for UNetMidBlock2DConfig { + fn default() -> Self { + Self { + num_layers: 1, + resnet_eps: 1e-6, + resnet_groups: Some(32), + attn_num_head_channels: Some(1), + output_scale_factor: 1., + } + } +} + +#[derive(Debug)] +pub struct UNetMidBlock2D { + resnet: ResnetBlock2D, + attn_resnets: Vec<(AttentionBlock, ResnetBlock2D)>, + pub config: UNetMidBlock2DConfig, +} + +impl UNetMidBlock2D { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + temb_channels: Option<usize>, + config: UNetMidBlock2DConfig, + ) -> Result<Self> { + let vs_resnets = vs.pp("resnets"); + let vs_attns = vs.pp("attentions"); + let resnet_groups = config + .resnet_groups + .unwrap_or_else(|| usize::min(in_channels / 4, 32)); + let resnet_cfg = ResnetBlock2DConfig { + eps: config.resnet_eps, + groups: resnet_groups, + output_scale_factor: config.output_scale_factor, + temb_channels, + ..Default::default() + }; + let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?; + let attn_cfg = AttentionBlockConfig { + num_head_channels: config.attn_num_head_channels, + num_groups: resnet_groups, + rescale_output_factor: config.output_scale_factor, + eps: config.resnet_eps, + }; + let mut attn_resnets = vec![]; + for index in 0..config.num_layers { + let attn = AttentionBlock::new(vs_attns.pp(&index.to_string()), in_channels, attn_cfg)?; + let resnet = ResnetBlock2D::new( + vs_resnets.pp(&(index + 1).to_string()), + in_channels, + resnet_cfg, + )?; + attn_resnets.push((attn, resnet)) + } + Ok(Self { + resnet, + attn_resnets, + config, + }) + } + + pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> { + let mut xs = self.resnet.forward(xs, temb)?; + for (attn, resnet) in self.attn_resnets.iter() { + xs = resnet.forward(&attn.forward(&xs)?, temb)? 
+ } + Ok(xs) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct UNetMidBlock2DCrossAttnConfig { + pub num_layers: usize, + pub resnet_eps: f64, + pub resnet_groups: Option<usize>, + pub attn_num_head_channels: usize, + // attention_type "default" + pub output_scale_factor: f64, + pub cross_attn_dim: usize, + pub sliced_attention_size: Option<usize>, + pub use_linear_projection: bool, +} + +impl Default for UNetMidBlock2DCrossAttnConfig { + fn default() -> Self { + Self { + num_layers: 1, + resnet_eps: 1e-6, + resnet_groups: Some(32), + attn_num_head_channels: 1, + output_scale_factor: 1., + cross_attn_dim: 1280, + sliced_attention_size: None, // Sliced attention disabled + use_linear_projection: false, + } + } +} + +#[derive(Debug)] +pub struct UNetMidBlock2DCrossAttn { + resnet: ResnetBlock2D, + attn_resnets: Vec<(SpatialTransformer, ResnetBlock2D)>, + pub config: UNetMidBlock2DCrossAttnConfig, +} + +impl UNetMidBlock2DCrossAttn { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + temb_channels: Option<usize>, + config: UNetMidBlock2DCrossAttnConfig, + ) -> Result<Self> { + let vs_resnets = vs.pp("resnets"); + let vs_attns = vs.pp("attentions"); + let resnet_groups = config + .resnet_groups + .unwrap_or_else(|| usize::min(in_channels / 4, 32)); + let resnet_cfg = ResnetBlock2DConfig { + eps: config.resnet_eps, + groups: resnet_groups, + output_scale_factor: config.output_scale_factor, + temb_channels, + ..Default::default() + }; + let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?; + let n_heads = config.attn_num_head_channels; + let attn_cfg = SpatialTransformerConfig { + depth: 1, + num_groups: resnet_groups, + context_dim: Some(config.cross_attn_dim), + sliced_attention_size: config.sliced_attention_size, + use_linear_projection: config.use_linear_projection, + }; + let mut attn_resnets = vec![]; + for index in 0..config.num_layers { + let attn = SpatialTransformer::new( + vs_attns.pp(&index.to_string()), + in_channels, + n_heads, + in_channels / n_heads, + attn_cfg, + )?; + let resnet = ResnetBlock2D::new( + vs_resnets.pp(&(index + 1).to_string()), + in_channels, + resnet_cfg, + )?; + attn_resnets.push((attn, resnet)) + } + Ok(Self { + resnet, + attn_resnets, + config, + }) + } + + pub fn forward( + &self, + xs: &Tensor, + temb: Option<&Tensor>, + encoder_hidden_states: Option<&Tensor>, + ) -> Result<Tensor> { + let mut xs = self.resnet.forward(xs, temb)?; + for (attn, resnet) in self.attn_resnets.iter() { + xs = resnet.forward(&attn.forward(&xs, encoder_hidden_states)?, temb)? 
+ } + Ok(xs) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct DownBlock2DConfig { + pub num_layers: usize, + pub resnet_eps: f64, + // resnet_time_scale_shift: "default" + // resnet_act_fn: "swish" + pub resnet_groups: usize, + pub output_scale_factor: f64, + pub add_downsample: bool, + pub downsample_padding: usize, +} + +impl Default for DownBlock2DConfig { + fn default() -> Self { + Self { + num_layers: 1, + resnet_eps: 1e-6, + resnet_groups: 32, + output_scale_factor: 1., + add_downsample: true, + downsample_padding: 1, + } + } +} + +#[derive(Debug)] +pub struct DownBlock2D { + resnets: Vec<ResnetBlock2D>, + downsampler: Option<Downsample2D>, + pub config: DownBlock2DConfig, +} + +impl DownBlock2D { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + out_channels: usize, + temb_channels: Option<usize>, + config: DownBlock2DConfig, + ) -> Result<Self> { + let vs_resnets = vs.pp("resnets"); + let resnet_cfg = ResnetBlock2DConfig { + out_channels: Some(out_channels), + eps: config.resnet_eps, + output_scale_factor: config.output_scale_factor, + temb_channels, + ..Default::default() + }; + let resnets = (0..config.num_layers) + .map(|i| { + let in_channels = if i == 0 { in_channels } else { out_channels }; + ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg) + }) + .collect::<Result<Vec<_>>>()?; + let downsampler = if config.add_downsample { + let downsampler = Downsample2D::new( + vs.pp("downsamplers").pp("0"), + out_channels, + true, + out_channels, + config.downsample_padding, + )?; + Some(downsampler) + } else { + None + }; + Ok(Self { + resnets, + downsampler, + config, + }) + } + + pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<(Tensor, Vec<Tensor>)> { + let mut xs = xs.clone(); + let mut output_states = vec![]; + for resnet in self.resnets.iter() { + xs = resnet.forward(&xs, temb)?; + output_states.push(xs.clone()); + } + let xs = match &self.downsampler { + Some(downsampler) => { + let xs = downsampler.forward(&xs)?; + output_states.push(xs.clone()); + xs + } + None => xs, + }; + Ok((xs, output_states)) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct CrossAttnDownBlock2DConfig { + pub downblock: DownBlock2DConfig, + pub attn_num_head_channels: usize, + pub cross_attention_dim: usize, + // attention_type: "default" + pub sliced_attention_size: Option<usize>, + pub use_linear_projection: bool, +} + +impl Default for CrossAttnDownBlock2DConfig { + fn default() -> Self { + Self { + downblock: Default::default(), + attn_num_head_channels: 1, + cross_attention_dim: 1280, + sliced_attention_size: None, + use_linear_projection: false, + } + } +} + +#[derive(Debug)] +pub struct CrossAttnDownBlock2D { + downblock: DownBlock2D, + attentions: Vec<SpatialTransformer>, + pub config: CrossAttnDownBlock2DConfig, +} + +impl CrossAttnDownBlock2D { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + out_channels: usize, + temb_channels: Option<usize>, + config: CrossAttnDownBlock2DConfig, + ) -> Result<Self> { + let downblock = DownBlock2D::new( + vs.clone(), + in_channels, + out_channels, + temb_channels, + config.downblock, + )?; + let n_heads = config.attn_num_head_channels; + let cfg = SpatialTransformerConfig { + depth: 1, + context_dim: Some(config.cross_attention_dim), + num_groups: config.downblock.resnet_groups, + sliced_attention_size: config.sliced_attention_size, + use_linear_projection: config.use_linear_projection, + }; + let vs_attn = vs.pp("attentions"); + let attentions = (0..config.downblock.num_layers) + .map(|i| { 
+ SpatialTransformer::new( + vs_attn.pp(&i.to_string()), + out_channels, + n_heads, + out_channels / n_heads, + cfg, + ) + }) + .collect::<Result<Vec<_>>>()?; + Ok(Self { + downblock, + attentions, + config, + }) + } + + pub fn forward( + &self, + xs: &Tensor, + temb: Option<&Tensor>, + encoder_hidden_states: Option<&Tensor>, + ) -> Result<(Tensor, Vec<Tensor>)> { + let mut output_states = vec![]; + let mut xs = xs.clone(); + for (resnet, attn) in self.downblock.resnets.iter().zip(self.attentions.iter()) { + xs = resnet.forward(&xs, temb)?; + xs = attn.forward(&xs, encoder_hidden_states)?; + output_states.push(xs.clone()); + } + let xs = match &self.downblock.downsampler { + Some(downsampler) => { + let xs = downsampler.forward(&xs)?; + output_states.push(xs.clone()); + xs + } + None => xs, + }; + Ok((xs, output_states)) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct UpBlock2DConfig { + pub num_layers: usize, + pub resnet_eps: f64, + // resnet_time_scale_shift: "default" + // resnet_act_fn: "swish" + pub resnet_groups: usize, + pub output_scale_factor: f64, + pub add_upsample: bool, +} + +impl Default for UpBlock2DConfig { + fn default() -> Self { + Self { + num_layers: 1, + resnet_eps: 1e-6, + resnet_groups: 32, + output_scale_factor: 1., + add_upsample: true, + } + } +} + +#[derive(Debug)] +pub struct UpBlock2D { + pub resnets: Vec<ResnetBlock2D>, + upsampler: Option<Upsample2D>, + pub config: UpBlock2DConfig, +} + +impl UpBlock2D { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + prev_output_channels: usize, + out_channels: usize, + temb_channels: Option<usize>, + config: UpBlock2DConfig, + ) -> Result<Self> { + let vs_resnets = vs.pp("resnets"); + let resnet_cfg = ResnetBlock2DConfig { + out_channels: Some(out_channels), + temb_channels, + eps: config.resnet_eps, + output_scale_factor: config.output_scale_factor, + ..Default::default() + }; + let resnets = (0..config.num_layers) + .map(|i| { + let res_skip_channels = if i == config.num_layers - 1 { + in_channels + } else { + out_channels + }; + let resnet_in_channels = if i == 0 { + prev_output_channels + } else { + out_channels + }; + let in_channels = resnet_in_channels + res_skip_channels; + ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg) + }) + .collect::<Result<Vec<_>>>()?; + let upsampler = if config.add_upsample { + let upsampler = + Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?; + Some(upsampler) + } else { + None + }; + Ok(Self { + resnets, + upsampler, + config, + }) + } + + pub fn forward( + &self, + xs: &Tensor, + res_xs: &[Tensor], + temb: Option<&Tensor>, + upsample_size: Option<(usize, usize)>, + ) -> Result<Tensor> { + let mut xs = xs.clone(); + for (index, resnet) in self.resnets.iter().enumerate() { + xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?; + xs = resnet.forward(&xs, temb)?; + } + match &self.upsampler { + Some(upsampler) => upsampler.forward(&xs, upsample_size), + None => Ok(xs), + } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct CrossAttnUpBlock2DConfig { + pub upblock: UpBlock2DConfig, + pub attn_num_head_channels: usize, + pub cross_attention_dim: usize, + // attention_type: "default" + pub sliced_attention_size: Option<usize>, + pub use_linear_projection: bool, +} + +impl Default for CrossAttnUpBlock2DConfig { + fn default() -> Self { + Self { + upblock: Default::default(), + attn_num_head_channels: 1, + cross_attention_dim: 1280, + sliced_attention_size: None, + use_linear_projection: false, + } + } +} + 
+#[derive(Debug)]
+pub struct CrossAttnUpBlock2D {
+    pub upblock: UpBlock2D,
+    pub attentions: Vec<SpatialTransformer>,
+    pub config: CrossAttnUpBlock2DConfig,
+}
+
+impl CrossAttnUpBlock2D {
+    pub fn new(
+        vs: nn::VarBuilder,
+        in_channels: usize,
+        prev_output_channels: usize,
+        out_channels: usize,
+        temb_channels: Option<usize>,
+        config: CrossAttnUpBlock2DConfig,
+    ) -> Result<Self> {
+        let upblock = UpBlock2D::new(
+            vs.clone(),
+            in_channels,
+            prev_output_channels,
+            out_channels,
+            temb_channels,
+            config.upblock,
+        )?;
+        let n_heads = config.attn_num_head_channels;
+        let cfg = SpatialTransformerConfig {
+            depth: 1,
+            context_dim: Some(config.cross_attention_dim),
+            num_groups: config.upblock.resnet_groups,
+            sliced_attention_size: config.sliced_attention_size,
+            use_linear_projection: config.use_linear_projection,
+        };
+        let vs_attn = vs.pp("attentions");
+        let attentions = (0..config.upblock.num_layers)
+            .map(|i| {
+                SpatialTransformer::new(
+                    vs_attn.pp(&i.to_string()),
+                    out_channels,
+                    n_heads,
+                    out_channels / n_heads,
+                    cfg,
+                )
+            })
+            .collect::<Result<Vec<_>>>()?;
+        Ok(Self {
+            upblock,
+            attentions,
+            config,
+        })
+    }
+
+    pub fn forward(
+        &self,
+        xs: &Tensor,
+        res_xs: &[Tensor],
+        temb: Option<&Tensor>,
+        upsample_size: Option<(usize, usize)>,
+        encoder_hidden_states: Option<&Tensor>,
+    ) -> Result<Tensor> {
+        let mut xs = xs.clone();
+        for (index, resnet) in self.upblock.resnets.iter().enumerate() {
+            xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
+            xs = resnet.forward(&xs, temb)?;
+            xs = self.attentions[index].forward(&xs, encoder_hidden_states)?;
+        }
+        match &self.upblock.upsampler {
+            Some(upsampler) => upsampler.forward(&xs, upsample_size),
+            None => Ok(xs),
+        }
+    }
+}
diff --git a/candle-examples/examples/stable-diffusion/utils.rs b/candle-examples/examples/stable-diffusion/utils.rs
new file mode 100644
index 00000000..aa49e560
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/utils.rs
@@ -0,0 +1,17 @@
+use candle::{Result, Tensor};
+
+pub fn sigmoid(_: &Tensor) -> Result<Tensor> {
+    todo!()
+}
+
+pub fn avg_pool2d(_: &Tensor) -> Result<Tensor> {
+    todo!()
+}
+
+pub fn pad(_: &Tensor) -> Result<Tensor> {
+    todo!()
+}
+
+pub fn upsample_nearest2d(_: &Tensor) -> Result<Tensor> {
+    todo!()
+}
diff --git a/candle-examples/examples/stable-diffusion/vae.rs b/candle-examples/examples/stable-diffusion/vae.rs
new file mode 100644
index 00000000..7a10d932
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/vae.rs
@@ -0,0 +1,378 @@
+#![allow(dead_code)]
+//! # Variational Auto-Encoder (VAE) Models.
+//!
+//! Auto-encoder models compress their input into a (usually smaller) latent
+//! space before expanding it back to its original shape; the latent values
+//! therefore hold a compressed representation of the original input.
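+//!
+//! `AutoEncoderKL` below follows the diffusers `vae.py` layout: the `Encoder`
+//! predicts the mean and log-variance of a diagonal Gaussian over the latent
+//! space, and the `Decoder` maps sampled latents back to image space.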
+use crate::unet_2d_blocks::{ + DownEncoderBlock2D, DownEncoderBlock2DConfig, UNetMidBlock2D, UNetMidBlock2DConfig, + UpDecoderBlock2D, UpDecoderBlock2DConfig, +}; +use candle::{Result, Tensor}; +use candle_nn as nn; + +#[derive(Debug, Clone)] +struct EncoderConfig { + // down_block_types: DownEncoderBlock2D + block_out_channels: Vec<usize>, + layers_per_block: usize, + norm_num_groups: usize, + double_z: bool, +} + +impl Default for EncoderConfig { + fn default() -> Self { + Self { + block_out_channels: vec![64], + layers_per_block: 2, + norm_num_groups: 32, + double_z: true, + } + } +} + +#[derive(Debug)] +struct Encoder { + conv_in: nn::Conv2d, + down_blocks: Vec<DownEncoderBlock2D>, + mid_block: UNetMidBlock2D, + conv_norm_out: nn::GroupNorm, + conv_out: nn::Conv2d, + #[allow(dead_code)] + config: EncoderConfig, +} + +impl Encoder { + fn new( + vs: nn::VarBuilder, + in_channels: usize, + out_channels: usize, + config: EncoderConfig, + ) -> Result<Self> { + let conv_cfg = nn::Conv2dConfig { + stride: 1, + padding: 1, + }; + let conv_in = nn::conv2d( + in_channels, + config.block_out_channels[0], + 3, + conv_cfg, + vs.pp("conv_in"), + )?; + let mut down_blocks = vec![]; + let vs_down_blocks = vs.pp("down_blocks"); + for index in 0..config.block_out_channels.len() { + let out_channels = config.block_out_channels[index]; + let in_channels = if index > 0 { + config.block_out_channels[index - 1] + } else { + config.block_out_channels[0] + }; + let is_final = index + 1 == config.block_out_channels.len(); + let cfg = DownEncoderBlock2DConfig { + num_layers: config.layers_per_block, + resnet_eps: 1e-6, + resnet_groups: config.norm_num_groups, + add_downsample: !is_final, + downsample_padding: 0, + ..Default::default() + }; + let down_block = DownEncoderBlock2D::new( + vs_down_blocks.pp(&index.to_string()), + in_channels, + out_channels, + cfg, + )?; + down_blocks.push(down_block) + } + let last_block_out_channels = *config.block_out_channels.last().unwrap(); + let mid_cfg = UNetMidBlock2DConfig { + resnet_eps: 1e-6, + output_scale_factor: 1., + attn_num_head_channels: None, + resnet_groups: Some(config.norm_num_groups), + ..Default::default() + }; + let mid_block = + UNetMidBlock2D::new(vs.pp("mid_block"), last_block_out_channels, None, mid_cfg)?; + let conv_norm_out = nn::group_norm( + config.norm_num_groups, + last_block_out_channels, + 1e-6, + vs.pp("conv_norm_out"), + )?; + let conv_out_channels = if config.double_z { + 2 * out_channels + } else { + out_channels + }; + let conv_cfg = nn::Conv2dConfig { + padding: 1, + ..Default::default() + }; + let conv_out = nn::conv2d( + last_block_out_channels, + conv_out_channels, + 3, + conv_cfg, + vs.pp("conv_out"), + )?; + Ok(Self { + conv_in, + down_blocks, + mid_block, + conv_norm_out, + conv_out, + config, + }) + } +} + +impl Encoder { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let mut xs = self.conv_in.forward(xs)?; + for down_block in self.down_blocks.iter() { + xs = down_block.forward(&xs)? 
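+            // Only the final down block skips its 2x downsample, so n blocks
+            // shrink the spatial dimensions by 2^(n - 1) overall.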
+ } + let xs = self.mid_block.forward(&xs, None)?; + let xs = self.conv_norm_out.forward(&xs)?; + let xs = nn::ops::silu(&xs)?; + self.conv_out.forward(&xs) + } +} + +#[derive(Debug, Clone)] +struct DecoderConfig { + // up_block_types: UpDecoderBlock2D + block_out_channels: Vec<usize>, + layers_per_block: usize, + norm_num_groups: usize, +} + +impl Default for DecoderConfig { + fn default() -> Self { + Self { + block_out_channels: vec![64], + layers_per_block: 2, + norm_num_groups: 32, + } + } +} + +#[derive(Debug)] +struct Decoder { + conv_in: nn::Conv2d, + up_blocks: Vec<UpDecoderBlock2D>, + mid_block: UNetMidBlock2D, + conv_norm_out: nn::GroupNorm, + conv_out: nn::Conv2d, + #[allow(dead_code)] + config: DecoderConfig, +} + +impl Decoder { + fn new( + vs: nn::VarBuilder, + in_channels: usize, + out_channels: usize, + config: DecoderConfig, + ) -> Result<Self> { + let n_block_out_channels = config.block_out_channels.len(); + let last_block_out_channels = *config.block_out_channels.last().unwrap(); + let conv_cfg = nn::Conv2dConfig { + stride: 1, + padding: 1, + }; + let conv_in = nn::conv2d( + in_channels, + last_block_out_channels, + 3, + conv_cfg, + vs.pp("conv_in"), + )?; + let mid_cfg = UNetMidBlock2DConfig { + resnet_eps: 1e-6, + output_scale_factor: 1., + attn_num_head_channels: None, + resnet_groups: Some(config.norm_num_groups), + ..Default::default() + }; + let mid_block = + UNetMidBlock2D::new(vs.pp("mid_block"), last_block_out_channels, None, mid_cfg)?; + let mut up_blocks = vec![]; + let vs_up_blocks = vs.pp("up_blocks"); + let reversed_block_out_channels: Vec<_> = + config.block_out_channels.iter().copied().rev().collect(); + for index in 0..n_block_out_channels { + let out_channels = reversed_block_out_channels[index]; + let in_channels = if index > 0 { + reversed_block_out_channels[index - 1] + } else { + reversed_block_out_channels[0] + }; + let is_final = index + 1 == n_block_out_channels; + let cfg = UpDecoderBlock2DConfig { + num_layers: config.layers_per_block + 1, + resnet_eps: 1e-6, + resnet_groups: config.norm_num_groups, + add_upsample: !is_final, + ..Default::default() + }; + let up_block = UpDecoderBlock2D::new( + vs_up_blocks.pp(&index.to_string()), + in_channels, + out_channels, + cfg, + )?; + up_blocks.push(up_block) + } + let conv_norm_out = nn::group_norm( + config.norm_num_groups, + config.block_out_channels[0], + 1e-6, + vs.pp("conv_norm_out"), + )?; + let conv_cfg = nn::Conv2dConfig { + padding: 1, + ..Default::default() + }; + let conv_out = nn::conv2d( + config.block_out_channels[0], + out_channels, + 3, + conv_cfg, + vs.pp("conv_out"), + )?; + Ok(Self { + conv_in, + up_blocks, + mid_block, + conv_norm_out, + conv_out, + config, + }) + } +} + +impl Decoder { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let mut xs = self.mid_block.forward(&self.conv_in.forward(xs)?, None)?; + for up_block in self.up_blocks.iter() { + xs = up_block.forward(&xs)? 
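+            // Mirrors the encoder: every up block except the final one ends
+            // with a 2x upsample back towards the input resolution.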
+        }
+        let xs = self.conv_norm_out.forward(&xs)?;
+        let xs = nn::ops::silu(&xs)?;
+        self.conv_out.forward(&xs)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct AutoEncoderKLConfig {
+    pub block_out_channels: Vec<usize>,
+    pub layers_per_block: usize,
+    pub latent_channels: usize,
+    pub norm_num_groups: usize,
+}
+
+impl Default for AutoEncoderKLConfig {
+    fn default() -> Self {
+        Self {
+            block_out_channels: vec![64],
+            layers_per_block: 1,
+            latent_channels: 4,
+            norm_num_groups: 32,
+        }
+    }
+}
+
+pub struct DiagonalGaussianDistribution {
+    mean: Tensor,
+    std: Tensor,
+}
+
+impl DiagonalGaussianDistribution {
+    pub fn new(parameters: &Tensor) -> Result<Self> {
+        let mut parameters = parameters.chunk(2, 1)?.into_iter();
+        let mean = parameters.next().unwrap();
+        let logvar = parameters.next().unwrap();
+        let std = (logvar * 0.5)?.exp()?;
+        Ok(DiagonalGaussianDistribution { mean, std })
+    }
+
+    pub fn sample(&self) -> Result<Tensor> {
+        let sample = Tensor::randn(0., 1f32, self.mean.shape(), self.mean.device());
+        &self.mean + &self.std * sample
+    }
+}
+
+// https://github.com/huggingface/diffusers/blob/970e30606c2944e3286f56e8eb6d3dc6d1eb85f7/src/diffusers/models/vae.py#L485
+// This implementation is specific to the config used in stable-diffusion-v1-5
+// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
+#[derive(Debug)]
+pub struct AutoEncoderKL {
+    encoder: Encoder,
+    decoder: Decoder,
+    quant_conv: nn::Conv2d,
+    post_quant_conv: nn::Conv2d,
+    pub config: AutoEncoderKLConfig,
+}
+
+impl AutoEncoderKL {
+    pub fn new(
+        vs: nn::VarBuilder,
+        in_channels: usize,
+        out_channels: usize,
+        config: AutoEncoderKLConfig,
+    ) -> Result<Self> {
+        let latent_channels = config.latent_channels;
+        let encoder_cfg = EncoderConfig {
+            block_out_channels: config.block_out_channels.clone(),
+            layers_per_block: config.layers_per_block,
+            norm_num_groups: config.norm_num_groups,
+            double_z: true,
+        };
+        let encoder = Encoder::new(vs.pp("encoder"), in_channels, latent_channels, encoder_cfg)?;
+        let decoder_cfg = DecoderConfig {
+            block_out_channels: config.block_out_channels.clone(),
+            layers_per_block: config.layers_per_block,
+            norm_num_groups: config.norm_num_groups,
+        };
+        let decoder = Decoder::new(vs.pp("decoder"), latent_channels, out_channels, decoder_cfg)?;
+        let conv_cfg = Default::default();
+        let quant_conv = nn::conv2d(
+            2 * latent_channels,
+            2 * latent_channels,
+            1,
+            conv_cfg,
+            vs.pp("quant_conv"),
+        )?;
+        let post_quant_conv = nn::conv2d(
+            latent_channels,
+            latent_channels,
+            1,
+            conv_cfg,
+            vs.pp("post_quant_conv"),
+        )?;
+        Ok(Self {
+            encoder,
+            decoder,
+            quant_conv,
+            post_quant_conv,
+            config,
+        })
+    }
+
+    /// Returns the distribution in the latent space.
+    pub fn encode(&self, xs: &Tensor) -> Result<DiagonalGaussianDistribution> {
+        let xs = self.encoder.forward(xs)?;
+        let parameters = self.quant_conv.forward(&xs)?;
+        DiagonalGaussianDistribution::new(&parameters)
+    }
+
+    /// Takes as input some sampled values.
+    pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
+        let xs = self.post_quant_conv.forward(xs)?;
+        self.decoder.forward(&xs)
+    }
+}
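As a closing note, here is a minimal usage sketch for the auto-encoder above, not part of the diff. It assumes candle-nn's `VarBuilder::zeros` as a stand-in for real weights, takes its block/channel sizes from the stable-diffusion-v1-5 `vae/config.json` linked in the code comment, and the `vae_roundtrip` name is made up for illustration. At this commit the `todo!()` helpers in utils.rs would still panic at runtime, so treat this as structural only:

use candle::{DType, Device, Result, Tensor};
use candle_nn as nn;

fn vae_roundtrip() -> Result<Tensor> {
    let device = Device::Cpu;
    // Zero-initialized placeholder weights; real usage would load safetensors.
    let vs = nn::VarBuilder::zeros(DType::F32, &device);
    // Channel layout from the stable-diffusion-v1-5 vae/config.json.
    let config = AutoEncoderKLConfig {
        block_out_channels: vec![128, 256, 512, 512],
        layers_per_block: 2,
        latent_channels: 4,
        norm_num_groups: 32,
    };
    // 3 channels in and 3 out (RGB images).
    let vae = AutoEncoderKL::new(vs, 3, 3, config)?;
    let image = Tensor::zeros((1, 3, 512, 512), DType::F32, &device)?;
    // encode() returns the latent distribution; sample() draws from it.
    // Three downsampling blocks map 512x512 images to (1, 4, 64, 64) latents.
    let latents = vae.encode(&image)?.sample()?;
    vae.decode(&latents)
}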