Diffstat (limited to 'candle-examples/examples/stable-diffusion')
-rw-r--r-- | candle-examples/examples/stable-diffusion/attention.rs | 445
-rw-r--r-- | candle-examples/examples/stable-diffusion/clip.rs | 305
-rw-r--r-- | candle-examples/examples/stable-diffusion/ddim.rs | 181
-rw-r--r-- | candle-examples/examples/stable-diffusion/embeddings.rs | 65
-rw-r--r-- | candle-examples/examples/stable-diffusion/main.rs | 273
-rw-r--r-- | candle-examples/examples/stable-diffusion/resnet.rs | 129
-rw-r--r-- | candle-examples/examples/stable-diffusion/schedulers.rs | 45
-rw-r--r-- | candle-examples/examples/stable-diffusion/stable_diffusion.rs | 212
-rw-r--r-- | candle-examples/examples/stable-diffusion/unet_2d.rs | 383
-rw-r--r-- | candle-examples/examples/stable-diffusion/unet_2d_blocks.rs | 808
-rw-r--r-- | candle-examples/examples/stable-diffusion/utils.rs | 31
-rw-r--r-- | candle-examples/examples/stable-diffusion/vae.rs | 378
12 files changed, 3255 insertions, 0 deletions
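For orientation: the sampling loop in main.rs below runs the UNet on the latents duplicated along the batch dimension and merges the unconditional and text-conditioned predictions with classifier-free guidance before each scheduler step. A minimal sketch of that combination, using the same candle tensor ops as the diff (the helper name is illustrative, not part of this commit):

use candle::{Result, Tensor};

/// Classifier-free guidance: move the prediction away from the unconditional
/// output and towards the text-conditioned one by `guidance_scale` (7.5 in main.rs).
fn apply_guidance(noise_pred: &Tensor, guidance_scale: f64) -> Result<Tensor> {
    // `noise_pred` stacks the UNet outputs for [uncond, cond] along dim 0.
    let chunks = noise_pred.chunk(2, 0)?;
    let (uncond, cond) = (&chunks[0], &chunks[1]);
    uncond + ((cond - uncond)? * guidance_scale)?
}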
diff --git a/candle-examples/examples/stable-diffusion/attention.rs b/candle-examples/examples/stable-diffusion/attention.rs new file mode 100644 index 00000000..83e7ef34 --- /dev/null +++ b/candle-examples/examples/stable-diffusion/attention.rs @@ -0,0 +1,445 @@ +#![allow(dead_code)] +//! Attention Based Building Blocks +use candle::{IndexOp, Result, Tensor, D}; +use candle_nn as nn; + +#[derive(Debug)] +struct GeGlu { + proj: nn::Linear, +} + +impl GeGlu { + fn new(vs: nn::VarBuilder, dim_in: usize, dim_out: usize) -> Result<Self> { + let proj = nn::linear(dim_in, dim_out * 2, vs.pp("proj"))?; + Ok(Self { proj }) + } +} + +impl GeGlu { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?; + &hidden_states_and_gate[0] * hidden_states_and_gate[1].gelu()? + } +} + +/// A feed-forward layer. +#[derive(Debug)] +struct FeedForward { + project_in: GeGlu, + linear: nn::Linear, +} + +impl FeedForward { + // The glu parameter in the python code is unused? + // https://github.com/huggingface/diffusers/blob/d3d22ce5a894becb951eec03e663951b28d45135/src/diffusers/models/attention.py#L347 + /// Creates a new feed-forward layer based on some given input dimension, some + /// output dimension, and a multiplier to be used for the intermediary layer. + fn new(vs: nn::VarBuilder, dim: usize, dim_out: Option<usize>, mult: usize) -> Result<Self> { + let inner_dim = dim * mult; + let dim_out = dim_out.unwrap_or(dim); + let vs = vs.pp("net"); + let project_in = GeGlu::new(vs.pp("0"), dim, inner_dim)?; + let linear = nn::linear(inner_dim, dim_out, vs.pp("2"))?; + Ok(Self { project_in, linear }) + } +} + +impl FeedForward { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let xs = self.project_in.forward(xs)?; + self.linear.forward(&xs) + } +} + +#[derive(Debug)] +struct CrossAttention { + to_q: nn::Linear, + to_k: nn::Linear, + to_v: nn::Linear, + to_out: nn::Linear, + heads: usize, + scale: f64, + slice_size: Option<usize>, +} + +impl CrossAttention { + // Defaults should be heads = 8, dim_head = 64, context_dim = None + fn new( + vs: nn::VarBuilder, + query_dim: usize, + context_dim: Option<usize>, + heads: usize, + dim_head: usize, + slice_size: Option<usize>, + ) -> Result<Self> { + let inner_dim = dim_head * heads; + let context_dim = context_dim.unwrap_or(query_dim); + let scale = 1.0 / f64::sqrt(dim_head as f64); + let to_q = nn::linear_no_bias(query_dim, inner_dim, vs.pp("to_q"))?; + let to_k = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_k"))?; + let to_v = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_v"))?; + let to_out = nn::linear(inner_dim, query_dim, vs.pp("to_out.0"))?; + Ok(Self { + to_q, + to_k, + to_v, + to_out, + heads, + scale, + slice_size, + }) + } + + fn reshape_heads_to_batch_dim(&self, xs: &Tensor) -> Result<Tensor> { + let (batch_size, seq_len, dim) = xs.dims3()?; + xs.reshape((batch_size, seq_len, self.heads, dim / self.heads))? + .transpose(1, 2)? + .reshape((batch_size * self.heads, seq_len, dim / self.heads)) + } + + fn reshape_batch_dim_to_heads(&self, xs: &Tensor) -> Result<Tensor> { + let (batch_size, seq_len, dim) = xs.dims3()?; + xs.reshape((batch_size / self.heads, self.heads, seq_len, dim))? + .transpose(1, 2)? 
+ .reshape((batch_size / self.heads, seq_len, dim * self.heads)) + } + + fn sliced_attention( + &self, + query: &Tensor, + key: &Tensor, + value: &Tensor, + slice_size: usize, + ) -> Result<Tensor> { + let batch_size_attention = query.dim(0)?; + let mut hidden_states = Vec::with_capacity(batch_size_attention / slice_size); + + for i in 0..batch_size_attention / slice_size { + let start_idx = i * slice_size; + let end_idx = (i + 1) * slice_size; + + let xs = query + .i(start_idx..end_idx)? + .matmul(&(key.i(start_idx..end_idx)?.t()? * self.scale)?)?; + let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(&value.i(start_idx..end_idx)?)?; + hidden_states.push(xs) + } + let hidden_states = Tensor::stack(&hidden_states, 0)?; + self.reshape_batch_dim_to_heads(&hidden_states) + } + + fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> { + let xs = query.matmul(&(key.transpose(D::Minus1, D::Minus2)? * self.scale)?)?; + let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(value)?; + self.reshape_batch_dim_to_heads(&xs) + } + + fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> { + let query = self.to_q.forward(xs)?; + let context = context.unwrap_or(xs); + let key = self.to_k.forward(context)?; + let value = self.to_v.forward(context)?; + let query = self.reshape_heads_to_batch_dim(&query)?; + let key = self.reshape_heads_to_batch_dim(&key)?; + let value = self.reshape_heads_to_batch_dim(&value)?; + let xs = match self.slice_size { + None => self.attention(&query, &key, &value)?, + Some(slice_size) => { + if query.dim(0)? / slice_size <= 1 { + self.attention(&query, &key, &value)? + } else { + self.sliced_attention(&query, &key, &value, slice_size)? + } + } + }; + self.to_out.forward(&xs) + } +} + +/// A basic Transformer block. +#[derive(Debug)] +struct BasicTransformerBlock { + attn1: CrossAttention, + ff: FeedForward, + attn2: CrossAttention, + norm1: nn::LayerNorm, + norm2: nn::LayerNorm, + norm3: nn::LayerNorm, +} + +impl BasicTransformerBlock { + fn new( + vs: nn::VarBuilder, + dim: usize, + n_heads: usize, + d_head: usize, + context_dim: Option<usize>, + sliced_attention_size: Option<usize>, + ) -> Result<Self> { + let attn1 = CrossAttention::new( + vs.pp("attn1"), + dim, + None, + n_heads, + d_head, + sliced_attention_size, + )?; + let ff = FeedForward::new(vs.pp("ff"), dim, None, 4)?; + let attn2 = CrossAttention::new( + vs.pp("attn2"), + dim, + context_dim, + n_heads, + d_head, + sliced_attention_size, + )?; + let norm1 = nn::layer_norm(dim, 1e-5, vs.pp("norm1"))?; + let norm2 = nn::layer_norm(dim, 1e-5, vs.pp("norm2"))?; + let norm3 = nn::layer_norm(dim, 1e-5, vs.pp("norm3"))?; + Ok(Self { + attn1, + ff, + attn2, + norm1, + norm2, + norm3, + }) + } + + fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> { + let xs = (self.attn1.forward(&self.norm1.forward(xs)?, None)? + xs)?; + let xs = (self.attn2.forward(&self.norm2.forward(&xs)?, context)? + xs)?; + self.ff.forward(&self.norm3.forward(&xs)?)? 
+ xs + } +} + +#[derive(Debug, Clone, Copy)] +pub struct SpatialTransformerConfig { + pub depth: usize, + pub num_groups: usize, + pub context_dim: Option<usize>, + pub sliced_attention_size: Option<usize>, + pub use_linear_projection: bool, +} + +impl Default for SpatialTransformerConfig { + fn default() -> Self { + Self { + depth: 1, + num_groups: 32, + context_dim: None, + sliced_attention_size: None, + use_linear_projection: false, + } + } +} + +#[derive(Debug)] +enum Proj { + Conv2d(nn::Conv2d), + Linear(nn::Linear), +} + +// Aka Transformer2DModel +#[derive(Debug)] +pub struct SpatialTransformer { + norm: nn::GroupNorm, + proj_in: Proj, + transformer_blocks: Vec<BasicTransformerBlock>, + proj_out: Proj, + pub config: SpatialTransformerConfig, +} + +impl SpatialTransformer { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + n_heads: usize, + d_head: usize, + config: SpatialTransformerConfig, + ) -> Result<Self> { + let inner_dim = n_heads * d_head; + let norm = nn::group_norm(config.num_groups, in_channels, 1e-6, vs.pp("norm"))?; + let proj_in = if config.use_linear_projection { + Proj::Linear(nn::linear(in_channels, inner_dim, vs.pp("proj_in"))?) + } else { + Proj::Conv2d(nn::conv2d( + in_channels, + inner_dim, + 1, + Default::default(), + vs.pp("proj_in"), + )?) + }; + let mut transformer_blocks = vec![]; + let vs_tb = vs.pp("transformer_blocks"); + for index in 0..config.depth { + let tb = BasicTransformerBlock::new( + vs_tb.pp(&index.to_string()), + inner_dim, + n_heads, + d_head, + config.context_dim, + config.sliced_attention_size, + )?; + transformer_blocks.push(tb) + } + let proj_out = if config.use_linear_projection { + Proj::Linear(nn::linear(in_channels, inner_dim, vs.pp("proj_out"))?) + } else { + Proj::Conv2d(nn::conv2d( + inner_dim, + in_channels, + 1, + Default::default(), + vs.pp("proj_out"), + )?) + }; + Ok(Self { + norm, + proj_in, + transformer_blocks, + proj_out, + config, + }) + } + + pub fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> { + let (batch, _channel, height, weight) = xs.dims4()?; + let residual = xs; + let xs = self.norm.forward(xs)?; + let (inner_dim, xs) = match &self.proj_in { + Proj::Conv2d(p) => { + let xs = p.forward(&xs)?; + let inner_dim = xs.dim(1)?; + let xs = xs + .transpose(1, 2)? + .t()? + .reshape((batch, height * weight, inner_dim))?; + (inner_dim, xs) + } + Proj::Linear(p) => { + let inner_dim = xs.dim(1)?; + let xs = xs + .transpose(1, 2)? + .t()? + .reshape((batch, height * weight, inner_dim))?; + (inner_dim, p.forward(&xs)?) + } + }; + let mut xs = xs; + for block in self.transformer_blocks.iter() { + xs = block.forward(&xs, context)? + } + let xs = match &self.proj_out { + Proj::Conv2d(p) => p.forward( + &xs.reshape((batch, height, weight, inner_dim))? + .t()? + .transpose(1, 2)?, + )?, + Proj::Linear(p) => p + .forward(&xs)? + .reshape((batch, height, weight, inner_dim))? + .t()? + .transpose(1, 2)?, + }; + xs + residual + } +} + +/// Configuration for an attention block. 
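+/// When `num_head_channels` is `None`, it defaults to the full channel count, i.e. a single attention head.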
+#[derive(Debug, Clone, Copy)] +pub struct AttentionBlockConfig { + pub num_head_channels: Option<usize>, + pub num_groups: usize, + pub rescale_output_factor: f64, + pub eps: f64, +} + +impl Default for AttentionBlockConfig { + fn default() -> Self { + Self { + num_head_channels: None, + num_groups: 32, + rescale_output_factor: 1., + eps: 1e-5, + } + } +} + +#[derive(Debug)] +pub struct AttentionBlock { + group_norm: nn::GroupNorm, + query: nn::Linear, + key: nn::Linear, + value: nn::Linear, + proj_attn: nn::Linear, + channels: usize, + num_heads: usize, + config: AttentionBlockConfig, +} + +impl AttentionBlock { + pub fn new(vs: nn::VarBuilder, channels: usize, config: AttentionBlockConfig) -> Result<Self> { + let num_head_channels = config.num_head_channels.unwrap_or(channels); + let num_heads = channels / num_head_channels; + let group_norm = + nn::group_norm(config.num_groups, channels, config.eps, vs.pp("group_norm"))?; + let query = nn::linear(channels, channels, vs.pp("query"))?; + let key = nn::linear(channels, channels, vs.pp("key"))?; + let value = nn::linear(channels, channels, vs.pp("value"))?; + let proj_attn = nn::linear(channels, channels, vs.pp("proj_attn"))?; + Ok(Self { + group_norm, + query, + key, + value, + proj_attn, + channels, + num_heads, + config, + }) + } + + fn transpose_for_scores(&self, xs: Tensor) -> Result<Tensor> { + let (batch, t, h_times_d) = xs.dims3()?; + xs.reshape((batch, t, self.num_heads, h_times_d / self.num_heads))? + .transpose(1, 2) + } +} + +impl AttentionBlock { + pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let residual = xs; + let (batch, channel, height, width) = xs.dims4()?; + let xs = self + .group_norm + .forward(xs)? + .reshape((batch, channel, height * width))? + .transpose(1, 2)?; + + let query_proj = self.query.forward(&xs)?; + let key_proj = self.key.forward(&xs)?; + let value_proj = self.value.forward(&xs)?; + + let query_states = self.transpose_for_scores(query_proj)?; + let key_states = self.transpose_for_scores(key_proj)?; + let value_states = self.transpose_for_scores(value_proj)?; + + let scale = f64::powf((self.channels as f64) / (self.num_heads as f64), -0.25); + let attention_scores = + // TODO: Check that this needs two multiplication by `scale`. + (query_states * scale)?.matmul(&(key_states.t()? * scale)?)?; + let attention_probs = nn::ops::softmax(&attention_scores, D::Minus1)?; + + let xs = attention_probs.matmul(&value_states)?; + let xs = xs.transpose(1, 2)?.contiguous()?; + let xs = xs.flatten_from(D::Minus2)?; + let xs = self + .proj_attn + .forward(&xs)? + .t()? + .reshape((batch, channel, height, width))?; + (xs + residual)? / self.config.rescale_output_factor + } +} diff --git a/candle-examples/examples/stable-diffusion/clip.rs b/candle-examples/examples/stable-diffusion/clip.rs new file mode 100644 index 00000000..ca00b417 --- /dev/null +++ b/candle-examples/examples/stable-diffusion/clip.rs @@ -0,0 +1,305 @@ +#![allow(dead_code)] +//! Contrastive Language-Image Pre-Training +//! +//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on +//! pairs of images with related texts. +//! +//! 
https://github.com/openai/CLIP +use candle::{Device, Result, Tensor, D}; +use candle_nn as nn; + +#[derive(Debug, Clone, Copy)] +pub enum Activation { + QuickGelu, + Gelu, +} + +impl Activation { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + match self { + Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?, + Activation::Gelu => xs.gelu(), + } + } +} + +#[derive(Debug, Clone)] +pub struct Config { + vocab_size: usize, + embed_dim: usize, // aka config.hidden_size + activation: Activation, // aka config.hidden_act + intermediate_size: usize, + pub max_position_embeddings: usize, + // The character to use for padding, use EOS when not set. + pub pad_with: Option<String>, + num_hidden_layers: usize, + num_attention_heads: usize, + #[allow(dead_code)] + projection_dim: usize, +} + +impl Config { + // The config details can be found in the "text_config" section of this json file: + // https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json + pub fn v1_5() -> Self { + Self { + vocab_size: 49408, + embed_dim: 768, + intermediate_size: 3072, + max_position_embeddings: 77, + pad_with: None, + num_hidden_layers: 12, + num_attention_heads: 12, + projection_dim: 768, + activation: Activation::QuickGelu, + } + } + + // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/text_encoder/config.json + pub fn v2_1() -> Self { + Self { + vocab_size: 49408, + embed_dim: 1024, + intermediate_size: 4096, + max_position_embeddings: 77, + pad_with: Some("!".to_string()), + num_hidden_layers: 23, + num_attention_heads: 16, + projection_dim: 512, + activation: Activation::Gelu, + } + } +} + +// CLIP Text Model +// https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py +#[derive(Debug)] +struct ClipTextEmbeddings { + token_embedding: candle_nn::Embedding, + position_embedding: candle_nn::Embedding, + position_ids: Tensor, +} + +impl ClipTextEmbeddings { + fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> { + let token_embedding = + candle_nn::embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?; + let position_embedding = candle_nn::embedding( + c.max_position_embeddings, + c.embed_dim, + vs.pp("position_embedding"), + )?; + let position_ids = + Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?; + Ok(ClipTextEmbeddings { + token_embedding, + position_embedding, + position_ids, + }) + } +} + +impl ClipTextEmbeddings { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let token_embedding = self.token_embedding.forward(xs)?; + let position_embedding = self.position_embedding.forward(&self.position_ids)?; + token_embedding.broadcast_add(&position_embedding) + } +} + +#[derive(Debug)] +struct ClipAttention { + k_proj: candle_nn::Linear, + v_proj: candle_nn::Linear, + q_proj: candle_nn::Linear, + out_proj: candle_nn::Linear, + head_dim: usize, + scale: f64, + num_attention_heads: usize, +} + +impl ClipAttention { + fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> { + let embed_dim = c.embed_dim; + let num_attention_heads = c.num_attention_heads; + let k_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("k_proj"))?; + let v_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("v_proj"))?; + let q_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("q_proj"))?; + let out_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?; + let head_dim = embed_dim / num_attention_heads; + let scale = (head_dim 
as f64).powf(-0.5); + Ok(ClipAttention { + k_proj, + v_proj, + q_proj, + out_proj, + head_dim, + scale, + num_attention_heads, + }) + } + + fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result<Tensor> { + xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))? + .transpose(1, 2)? + .contiguous() + } + + fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> { + let (bsz, seq_len, embed_dim) = xs.dims3()?; + let query_states = (self.q_proj.forward(xs)? * self.scale)?; + let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim); + let query_states = self + .shape(&query_states, seq_len, bsz)? + .reshape(proj_shape)?; + let key_states = self + .shape(&self.k_proj.forward(xs)?, seq_len, bsz)? + .reshape(proj_shape)?; + let value_states = self + .shape(&self.v_proj.forward(xs)?, seq_len, bsz)? + .reshape(proj_shape)?; + let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?; + + let src_len = key_states.dim(1)?; + let attn_weights = attn_weights + .reshape((bsz, self.num_attention_heads, seq_len, src_len))? + .broadcast_add(causal_attention_mask)?; + let attn_weights = + attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?; + let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?; + + let attn_output = attn_weights.matmul(&value_states)?; + let attn_output = attn_output + .reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))? + .transpose(1, 2)? + .reshape((bsz, seq_len, embed_dim))?; + self.out_proj.forward(&attn_output) + } +} + +#[derive(Debug)] +struct ClipMlp { + fc1: candle_nn::Linear, + fc2: candle_nn::Linear, + activation: Activation, +} + +impl ClipMlp { + fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> { + let fc1 = candle_nn::linear(c.embed_dim, c.intermediate_size, vs.pp("fc1"))?; + let fc2 = candle_nn::linear(c.intermediate_size, c.embed_dim, vs.pp("fc2"))?; + Ok(ClipMlp { + fc1, + fc2, + activation: c.activation, + }) + } +} + +impl ClipMlp { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let xs = self.fc1.forward(xs)?; + self.fc2.forward(&self.activation.forward(&xs)?) 
+ } +} + +#[derive(Debug)] +struct ClipEncoderLayer { + self_attn: ClipAttention, + layer_norm1: candle_nn::LayerNorm, + mlp: ClipMlp, + layer_norm2: candle_nn::LayerNorm, +} + +impl ClipEncoderLayer { + fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> { + let self_attn = ClipAttention::new(vs.pp("self_attn"), c)?; + let layer_norm1 = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("layer_norm1"))?; + let mlp = ClipMlp::new(vs.pp("mlp"), c)?; + let layer_norm2 = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("layer_norm2"))?; + Ok(ClipEncoderLayer { + self_attn, + layer_norm1, + mlp, + layer_norm2, + }) + } + + fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> { + let residual = xs; + let xs = self.layer_norm1.forward(xs)?; + let xs = self.self_attn.forward(&xs, causal_attention_mask)?; + let xs = (xs + residual)?; + + let residual = &xs; + let xs = self.layer_norm2.forward(&xs)?; + let xs = self.mlp.forward(&xs)?; + xs + residual + } +} + +#[derive(Debug)] +struct ClipEncoder { + layers: Vec<ClipEncoderLayer>, +} + +impl ClipEncoder { + fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> { + let vs = vs.pp("layers"); + let mut layers: Vec<ClipEncoderLayer> = Vec::new(); + for index in 0..c.num_hidden_layers { + let layer = ClipEncoderLayer::new(vs.pp(&index.to_string()), c)?; + layers.push(layer) + } + Ok(ClipEncoder { layers }) + } + + fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> { + let mut xs = xs.clone(); + for layer in self.layers.iter() { + xs = layer.forward(&xs, causal_attention_mask)?; + } + Ok(xs) + } +} + +/// A CLIP transformer based model. +#[derive(Debug)] +pub struct ClipTextTransformer { + embeddings: ClipTextEmbeddings, + encoder: ClipEncoder, + final_layer_norm: candle_nn::LayerNorm, +} + +impl ClipTextTransformer { + pub fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> { + let vs = vs.pp("text_model"); + let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?; + let encoder = ClipEncoder::new(vs.pp("encoder"), c)?; + let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("final_layer_norm"))?; + Ok(ClipTextTransformer { + embeddings, + encoder, + final_layer_norm, + }) + } + + // https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678 + fn build_causal_attention_mask(bsz: usize, seq_len: usize, device: &Device) -> Result<Tensor> { + let mask: Vec<_> = (0..seq_len) + .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::MIN } else { 0. })) + .collect(); + let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?; + mask.broadcast_as((bsz, seq_len, seq_len)) + } +} + +impl ClipTextTransformer { + pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let (bsz, seq_len) = xs.dims2()?; + let xs = self.embeddings.forward(xs)?; + let causal_attention_mask = Self::build_causal_attention_mask(bsz, seq_len, xs.device())?; + let xs = self.encoder.forward(&xs, &causal_attention_mask)?; + self.final_layer_norm.forward(&xs) + } +} diff --git a/candle-examples/examples/stable-diffusion/ddim.rs b/candle-examples/examples/stable-diffusion/ddim.rs new file mode 100644 index 00000000..6eb6df44 --- /dev/null +++ b/candle-examples/examples/stable-diffusion/ddim.rs @@ -0,0 +1,181 @@ +#![allow(dead_code)] +//! # Denoising Diffusion Implicit Models +//! +//! The Denoising Diffusion Implicit Models (DDIM) is a simple scheduler +//! 
similar to Denoising Diffusion Probabilistic Models (DDPM). The DDPM +//! generative process is the reverse of a Markovian process, DDIM generalizes +//! this to non-Markovian guidance. +//! +//! Denoising Diffusion Implicit Models, J. Song et al, 2020. +//! https://arxiv.org/abs/2010.02502 +use crate::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType}; +use candle::{Result, Tensor}; + +/// The configuration for the DDIM scheduler. +#[derive(Debug, Clone, Copy)] +pub struct DDIMSchedulerConfig { + /// The value of beta at the beginning of training. + pub beta_start: f64, + /// The value of beta at the end of training. + pub beta_end: f64, + /// How beta evolved during training. + pub beta_schedule: BetaSchedule, + /// The amount of noise to be added at each step. + pub eta: f64, + /// Adjust the indexes of the inference schedule by this value. + pub steps_offset: usize, + /// prediction type of the scheduler function, one of `epsilon` (predicting + /// the noise of the diffusion process), `sample` (directly predicting the noisy sample`) + /// or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) + pub prediction_type: PredictionType, + /// number of diffusion steps used to train the model + pub train_timesteps: usize, +} + +impl Default for DDIMSchedulerConfig { + fn default() -> Self { + Self { + beta_start: 0.00085f64, + beta_end: 0.012f64, + beta_schedule: BetaSchedule::ScaledLinear, + eta: 0., + steps_offset: 1, + prediction_type: PredictionType::Epsilon, + train_timesteps: 1000, + } + } +} + +/// The DDIM scheduler. +#[derive(Debug, Clone)] +pub struct DDIMScheduler { + timesteps: Vec<usize>, + alphas_cumprod: Vec<f64>, + step_ratio: usize, + init_noise_sigma: f64, + pub config: DDIMSchedulerConfig, +} + +// clip_sample: False, set_alpha_to_one: False +impl DDIMScheduler { + /// Creates a new DDIM scheduler given the number of steps to be + /// used for inference as well as the number of steps that was used + /// during training. + pub fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result<Self> { + let step_ratio = config.train_timesteps / inference_steps; + let timesteps: Vec<usize> = (0..(inference_steps)) + .map(|s| s * step_ratio + config.steps_offset) + .rev() + .collect(); + let betas = match config.beta_schedule { + BetaSchedule::ScaledLinear => crate::utils::linspace( + config.beta_start.sqrt(), + config.beta_end.sqrt(), + config.train_timesteps, + )? + .sqr()?, + BetaSchedule::Linear => { + crate::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)? + } + BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?, + }; + let betas = betas.to_vec1::<f64>()?; + let mut alphas_cumprod = Vec::with_capacity(betas.len()); + for &beta in betas.iter() { + let alpha = 1.0 - beta; + alphas_cumprod.push(alpha * *alphas_cumprod.last().unwrap_or(&1f64)) + } + Ok(Self { + alphas_cumprod, + timesteps, + step_ratio, + init_noise_sigma: 1., + config, + }) + } + + pub fn timesteps(&self) -> &[usize] { + self.timesteps.as_slice() + } + + /// Ensures interchangeability with schedulers that need to scale the denoising model input + /// depending on the current timestep. + pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> { + Ok(sample) + } + + /// Performs a backward step during inference. 
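+    ///
+    /// With `sigma_t = eta * sqrt(variance_t)`, the previous sample is computed below as
+    /// `sqrt(alpha_prod_t_prev) * pred_original_sample + sqrt(1 - alpha_prod_t_prev - sigma_t^2) * pred_epsilon`,
+    /// with additional Gaussian noise of standard deviation `sigma_t` added when `eta > 0`.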
+ pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> { + let timestep = if timestep >= self.alphas_cumprod.len() { + timestep - 1 + } else { + timestep + }; + // https://github.com/huggingface/diffusers/blob/6e099e2c8ce4c4f5c7318e970a8c093dc5c7046e/src/diffusers/schedulers/scheduling_ddim.py#L195 + let prev_timestep = if timestep > self.step_ratio { + timestep - self.step_ratio + } else { + 0 + }; + + let alpha_prod_t = self.alphas_cumprod[timestep]; + let alpha_prod_t_prev = self.alphas_cumprod[prev_timestep]; + let beta_prod_t = 1. - alpha_prod_t; + let beta_prod_t_prev = 1. - alpha_prod_t_prev; + + let (pred_original_sample, pred_epsilon) = match self.config.prediction_type { + PredictionType::Epsilon => { + let pred_original_sample = ((sample - (model_output * beta_prod_t.sqrt())?)? + * (1. / alpha_prod_t.sqrt()))?; + (pred_original_sample, model_output.clone()) + } + PredictionType::VPrediction => { + let pred_original_sample = + ((sample * alpha_prod_t.sqrt())? - (model_output * beta_prod_t.sqrt())?)?; + let pred_epsilon = + ((model_output * alpha_prod_t.sqrt())? + (sample * beta_prod_t.sqrt())?)?; + (pred_original_sample, pred_epsilon) + } + PredictionType::Sample => { + let pred_original_sample = model_output.clone(); + let pred_epsilon = ((sample - &pred_original_sample * alpha_prod_t.sqrt())? + * (1. / beta_prod_t.sqrt()))?; + (pred_original_sample, pred_epsilon) + } + }; + + let variance = (beta_prod_t_prev / beta_prod_t) * (1. - alpha_prod_t / alpha_prod_t_prev); + let std_dev_t = self.config.eta * variance.sqrt(); + + let pred_sample_direction = + (pred_epsilon * (1. - alpha_prod_t_prev - std_dev_t * std_dev_t).sqrt())?; + let prev_sample = + ((pred_original_sample * alpha_prod_t_prev.sqrt())? + pred_sample_direction)?; + if self.config.eta > 0. { + &prev_sample + + Tensor::randn( + 0f32, + std_dev_t as f32, + prev_sample.shape(), + prev_sample.device(), + )? + } else { + Ok(prev_sample) + } + } + + pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> { + let timestep = if timestep >= self.alphas_cumprod.len() { + timestep - 1 + } else { + timestep + }; + let sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt(); + let sqrt_one_minus_alpha_prod = (1.0 - self.alphas_cumprod[timestep]).sqrt(); + (original * sqrt_alpha_prod)? + (noise * sqrt_one_minus_alpha_prod)? 
+ } + + pub fn init_noise_sigma(&self) -> f64 { + self.init_noise_sigma + } +} diff --git a/candle-examples/examples/stable-diffusion/embeddings.rs b/candle-examples/examples/stable-diffusion/embeddings.rs new file mode 100644 index 00000000..e3a339f5 --- /dev/null +++ b/candle-examples/examples/stable-diffusion/embeddings.rs @@ -0,0 +1,65 @@ +#![allow(dead_code)] +use candle::{Result, Tensor, D}; +use candle_nn as nn; + +#[derive(Debug)] +pub struct TimestepEmbedding { + linear_1: nn::Linear, + linear_2: nn::Linear, +} + +impl TimestepEmbedding { + // act_fn: "silu" + pub fn new(vs: nn::VarBuilder, channel: usize, time_embed_dim: usize) -> Result<Self> { + let linear_1 = nn::linear(channel, time_embed_dim, vs.pp("linear_1"))?; + let linear_2 = nn::linear(time_embed_dim, time_embed_dim, vs.pp("linear_2"))?; + Ok(Self { linear_1, linear_2 }) + } +} + +impl TimestepEmbedding { + pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let xs = nn::ops::silu(&self.linear_1.forward(xs)?)?; + self.linear_2.forward(&xs) + } +} + +#[derive(Debug)] +pub struct Timesteps { + num_channels: usize, + flip_sin_to_cos: bool, + downscale_freq_shift: f64, +} + +impl Timesteps { + pub fn new(num_channels: usize, flip_sin_to_cos: bool, downscale_freq_shift: f64) -> Self { + Self { + num_channels, + flip_sin_to_cos, + downscale_freq_shift, + } + } +} + +impl Timesteps { + pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let half_dim = (self.num_channels / 2) as u32; + let exponent = + (Tensor::arange(0, half_dim, xs.device())?.to_dtype(xs.dtype())? * -f64::ln(10000.))?; + let exponent = (exponent / (half_dim as f64 - self.downscale_freq_shift))?; + let emb = exponent.exp()?; + // emb = timesteps[:, None].float() * emb[None, :] + let emb = xs.unsqueeze(D::Minus1)?.broadcast_mul(&emb.unsqueeze(0)?)?; + let (cos, sin) = (emb.cos()?, emb.sin()?); + let emb = if self.flip_sin_to_cos { + Tensor::cat(&[&cos, &sin], D::Minus1)? + } else { + Tensor::cat(&[&sin, &cos], D::Minus1)? + }; + if self.num_channels % 2 == 1 { + emb.pad_with_zeros(D::Minus2, 0, 1) + } else { + Ok(emb) + } + } +} diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs new file mode 100644 index 00000000..8ce0c234 --- /dev/null +++ b/candle-examples/examples/stable-diffusion/main.rs @@ -0,0 +1,273 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +mod attention; +mod clip; +mod ddim; +mod embeddings; +mod resnet; +mod schedulers; +mod stable_diffusion; +mod unet_2d; +mod unet_2d_blocks; +mod utils; +mod vae; + +use anyhow::{Error as E, Result}; +use candle::{DType, Device, Tensor}; +use clap::Parser; +use tokenizers::Tokenizer; + +const GUIDANCE_SCALE: f64 = 7.5; + +#[derive(Parser)] +#[command(author, version, about, long_about = None)] +struct Args { + /// The prompt to be used for image generation. + #[arg( + long, + default_value = "A very realistic photo of a rusty robot walking on a sandy beach" + )] + prompt: String, + + #[arg(long, default_value = "")] + uncond_prompt: String, + + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// The height in pixels of the generated image. + #[arg(long)] + height: Option<usize>, + + /// The width in pixels of the generated image. + #[arg(long)] + width: Option<usize>, + + /// The UNet weight file, in .ot or .safetensors format. + #[arg(long, value_name = "FILE")] + unet_weights: Option<String>, + + /// The CLIP weight file, in .ot or .safetensors format. 
+ #[arg(long, value_name = "FILE")] + clip_weights: Option<String>, + + /// The VAE weight file, in .ot or .safetensors format. + #[arg(long, value_name = "FILE")] + vae_weights: Option<String>, + + #[arg(long, value_name = "FILE")] + /// The file specifying the tokenizer to used for tokenization. + tokenizer: String, + + /// The size of the sliced attention or 0 for automatic slicing (disabled by default) + #[arg(long)] + sliced_attention_size: Option<usize>, + + /// The number of steps to run the diffusion for. + #[arg(long, default_value_t = 30)] + n_steps: usize, + + /// The number of samples to generate. + #[arg(long, default_value_t = 1)] + num_samples: i64, + + /// The name of the final image to generate. + #[arg(long, value_name = "FILE", default_value = "sd_final.png")] + final_image: String, + + #[arg(long, value_enum, default_value = "v2-1")] + sd_version: StableDiffusionVersion, + + /// Generate intermediary images at each step. + #[arg(long, action)] + intermediary_images: bool, +} + +#[derive(Debug, Clone, Copy, clap::ValueEnum)] +enum StableDiffusionVersion { + V1_5, + V2_1, +} + +impl Args { + fn clip_weights(&self) -> String { + match &self.clip_weights { + Some(w) => w.clone(), + None => match self.sd_version { + StableDiffusionVersion::V1_5 => "data/pytorch_model.safetensors".to_string(), + StableDiffusionVersion::V2_1 => "data/clip_v2.1.safetensors".to_string(), + }, + } + } + + fn vae_weights(&self) -> String { + match &self.vae_weights { + Some(w) => w.clone(), + None => match self.sd_version { + StableDiffusionVersion::V1_5 => "data/vae.safetensors".to_string(), + StableDiffusionVersion::V2_1 => "data/vae_v2.1.safetensors".to_string(), + }, + } + } + + fn unet_weights(&self) -> String { + match &self.unet_weights { + Some(w) => w.clone(), + None => match self.sd_version { + StableDiffusionVersion::V1_5 => "data/unet.safetensors".to_string(), + StableDiffusionVersion::V2_1 => "data/unet_v2.1.safetensors".to_string(), + }, + } + } +} + +fn output_filename( + basename: &str, + sample_idx: i64, + num_samples: i64, + timestep_idx: Option<usize>, +) -> String { + let filename = if num_samples > 1 { + match basename.rsplit_once('.') { + None => format!("{basename}.{sample_idx}.png"), + Some((filename_no_extension, extension)) => { + format!("{filename_no_extension}.{sample_idx}.{extension}") + } + } + } else { + basename.to_string() + }; + match timestep_idx { + None => filename, + Some(timestep_idx) => match filename.rsplit_once('.') { + None => format!("{filename}-{timestep_idx}.png"), + Some((filename_no_extension, extension)) => { + format!("{filename_no_extension}-{timestep_idx}.{extension}") + } + }, + } +} + +fn run(args: Args) -> Result<()> { + let clip_weights = args.clip_weights(); + let vae_weights = args.vae_weights(); + let unet_weights = args.unet_weights(); + let Args { + prompt, + uncond_prompt, + cpu, + height, + width, + n_steps, + tokenizer, + final_image, + sliced_attention_size, + num_samples, + sd_version, + .. 
+ } = args; + let sd_config = match sd_version { + StableDiffusionVersion::V1_5 => { + stable_diffusion::StableDiffusionConfig::v1_5(sliced_attention_size, height, width) + } + StableDiffusionVersion::V2_1 => { + stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size, height, width) + } + }; + + let scheduler = sd_config.build_scheduler(n_steps)?; + let device = candle_examples::device(cpu)?; + + let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?; + let pad_id = match &sd_config.clip.pad_with { + Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(), + None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(), + }; + println!("Running with prompt \"{prompt}\"."); + let mut tokens = tokenizer + .encode(prompt, true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + while tokens.len() < sd_config.clip.max_position_embeddings { + tokens.push(pad_id) + } + let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?; + + let mut uncond_tokens = tokenizer + .encode(uncond_prompt, true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + while uncond_tokens.len() < sd_config.clip.max_position_embeddings { + uncond_tokens.push(pad_id) + } + let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), &device)?.unsqueeze(0)?; + + println!("Building the Clip transformer."); + let text_model = sd_config.build_clip_transformer(&clip_weights, &device)?; + let text_embeddings = text_model.forward(&tokens)?; + let uncond_embeddings = text_model.forward(&uncond_tokens)?; + let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?; + + println!("Building the autoencoder."); + let vae = sd_config.build_vae(&vae_weights, &device)?; + println!("Building the unet."); + let unet = sd_config.build_unet(&unet_weights, &device, 4)?; + + let bsize = 1; + for idx in 0..num_samples { + let mut latents = Tensor::randn( + 0f32, + 1f32, + (bsize, 4, sd_config.height / 8, sd_config.width / 8), + &device, + )?; + + // scale the initial noise by the standard deviation required by the scheduler + latents = (latents * scheduler.init_noise_sigma())?; + + for (timestep_index, ×tep) in scheduler.timesteps().iter().enumerate() { + println!("Timestep {timestep_index}/{n_steps}"); + let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?; + + let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)?; + let noise_pred = + unet.forward(&latent_model_input, timestep as f64, &text_embeddings)?; + let noise_pred = noise_pred.chunk(2, 0)?; + let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]); + let noise_pred = + (noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * GUIDANCE_SCALE)?)?; + latents = scheduler.step(&noise_pred, timestep, &latents)?; + + if args.intermediary_images { + let image = vae.decode(&(&latents / 0.18215)?)?; + let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?; + let image = (image * 255.)?.to_dtype(DType::U8)?; + let image_filename = + output_filename(&final_image, idx + 1, num_samples, Some(timestep_index + 1)); + crate::utils::save_image(&image, image_filename)? + } + } + + println!( + "Generating the final image for sample {}/{}.", + idx + 1, + num_samples + ); + let image = vae.decode(&(&latents / 0.18215)?)?; + // TODO: Add the clamping between 0 and 1. + let image = ((image / 2.)? 
+ 0.5)?.to_device(&Device::Cpu)?; + let image = (image * 255.)?.to_dtype(DType::U8)?; + let image_filename = output_filename(&final_image, idx + 1, num_samples, None); + crate::utils::save_image(&image, image_filename)? + } + Ok(()) +} + +fn main() -> Result<()> { + let args = Args::parse(); + run(args) +} diff --git a/candle-examples/examples/stable-diffusion/resnet.rs b/candle-examples/examples/stable-diffusion/resnet.rs new file mode 100644 index 00000000..7790dcf9 --- /dev/null +++ b/candle-examples/examples/stable-diffusion/resnet.rs @@ -0,0 +1,129 @@ +#![allow(dead_code)] +//! ResNet Building Blocks +//! +//! Some Residual Network blocks used in UNet models. +//! +//! Denoising Diffusion Implicit Models, K. He and al, 2015. +//! https://arxiv.org/abs/1512.03385 +use candle::{Result, Tensor, D}; +use candle_nn as nn; + +/// Configuration for a ResNet block. +#[derive(Debug, Clone, Copy)] +pub struct ResnetBlock2DConfig { + /// The number of output channels, defaults to the number of input channels. + pub out_channels: Option<usize>, + pub temb_channels: Option<usize>, + /// The number of groups to use in group normalization. + pub groups: usize, + pub groups_out: Option<usize>, + /// The epsilon to be used in the group normalization operations. + pub eps: f64, + /// Whether to use a 2D convolution in the skip connection. When using None, + /// such a convolution is used if the number of input channels is different from + /// the number of output channels. + pub use_in_shortcut: Option<bool>, + // non_linearity: silu + /// The final output is scaled by dividing by this value. + pub output_scale_factor: f64, +} + +impl Default for ResnetBlock2DConfig { + fn default() -> Self { + Self { + out_channels: None, + temb_channels: Some(512), + groups: 32, + groups_out: None, + eps: 1e-6, + use_in_shortcut: None, + output_scale_factor: 1., + } + } +} + +#[derive(Debug)] +pub struct ResnetBlock2D { + norm1: nn::GroupNorm, + conv1: nn::Conv2d, + norm2: nn::GroupNorm, + conv2: nn::Conv2d, + time_emb_proj: Option<nn::Linear>, + conv_shortcut: Option<nn::Conv2d>, + config: ResnetBlock2DConfig, +} + +impl ResnetBlock2D { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + config: ResnetBlock2DConfig, + ) -> Result<Self> { + let out_channels = config.out_channels.unwrap_or(in_channels); + let conv_cfg = nn::Conv2dConfig { + stride: 1, + padding: 1, + }; + let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?; + let conv1 = nn::conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?; + let groups_out = config.groups_out.unwrap_or(config.groups); + let norm2 = nn::group_norm(groups_out, out_channels, config.eps, vs.pp("norm2"))?; + let conv2 = nn::conv2d(out_channels, out_channels, 3, conv_cfg, vs.pp("conv2"))?; + let use_in_shortcut = config + .use_in_shortcut + .unwrap_or(in_channels != out_channels); + let conv_shortcut = if use_in_shortcut { + let conv_cfg = nn::Conv2dConfig { + stride: 1, + padding: 0, + }; + Some(nn::conv2d( + in_channels, + out_channels, + 1, + conv_cfg, + vs.pp("conv_shortcut"), + )?) 
+ } else { + None + }; + let time_emb_proj = match config.temb_channels { + None => None, + Some(temb_channels) => Some(nn::linear( + temb_channels, + out_channels, + vs.pp("time_emb_proj"), + )?), + }; + Ok(Self { + norm1, + conv1, + norm2, + conv2, + time_emb_proj, + config, + conv_shortcut, + }) + } + + pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> { + let shortcut_xs = match &self.conv_shortcut { + Some(conv_shortcut) => conv_shortcut.forward(xs)?, + None => xs.clone(), + }; + let xs = self.norm1.forward(xs)?; + let xs = self.conv1.forward(&nn::ops::silu(&xs)?)?; + let xs = match (temb, &self.time_emb_proj) { + (Some(temb), Some(time_emb_proj)) => time_emb_proj + .forward(&nn::ops::silu(temb)?)? + .unsqueeze(D::Minus1)? + .unsqueeze(D::Minus1)? + .broadcast_add(&xs)?, + _ => xs, + }; + let xs = self + .conv2 + .forward(&nn::ops::silu(&self.norm2.forward(&xs)?)?)?; + (shortcut_xs + xs)? / self.config.output_scale_factor + } +} diff --git a/candle-examples/examples/stable-diffusion/schedulers.rs b/candle-examples/examples/stable-diffusion/schedulers.rs new file mode 100644 index 00000000..3f6a1d72 --- /dev/null +++ b/candle-examples/examples/stable-diffusion/schedulers.rs @@ -0,0 +1,45 @@ +#![allow(dead_code)] +//! # Diffusion pipelines and models +//! +//! Noise schedulers can be used to set the trade-off between +//! inference speed and quality. + +use candle::{Result, Tensor}; + +/// This represents how beta ranges from its minimum value to the maximum +/// during training. +#[derive(Debug, Clone, Copy)] +pub enum BetaSchedule { + /// Linear interpolation. + Linear, + /// Linear interpolation of the square root of beta. + ScaledLinear, + /// Glide cosine schedule + SquaredcosCapV2, +} + +#[derive(Debug, Clone, Copy)] +pub enum PredictionType { + Epsilon, + VPrediction, + Sample, +} + +/// Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of +/// `(1-beta)` over time from `t = [0,1]`. +/// +/// Contains a function `alpha_bar` that takes an argument `t` and transforms it to the cumulative product of `(1-beta)` +/// up to that part of the diffusion process. 
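+///
+/// Each `beta_t` below is computed as `1 - alpha_bar(t_{i+1}) / alpha_bar(t_i)` and capped at `max_beta`,
+/// where `alpha_bar` is the squared-cosine schedule `cos((t + 0.008) / 1.008 * pi / 2)^2`.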
+pub(crate) fn betas_for_alpha_bar(num_diffusion_timesteps: usize, max_beta: f64) -> Result<Tensor> { + let alpha_bar = |time_step: usize| { + f64::cos((time_step as f64 + 0.008) / 1.008 * std::f64::consts::FRAC_PI_2).powi(2) + }; + let mut betas = Vec::with_capacity(num_diffusion_timesteps); + for i in 0..num_diffusion_timesteps { + let t1 = i / num_diffusion_timesteps; + let t2 = (i + 1) / num_diffusion_timesteps; + betas.push((1.0 - alpha_bar(t2) / alpha_bar(t1)).min(max_beta)); + } + let betas_len = betas.len(); + Tensor::from_vec(betas, betas_len, &candle::Device::Cpu) +} diff --git a/candle-examples/examples/stable-diffusion/stable_diffusion.rs b/candle-examples/examples/stable-diffusion/stable_diffusion.rs new file mode 100644 index 00000000..c250ed56 --- /dev/null +++ b/candle-examples/examples/stable-diffusion/stable_diffusion.rs @@ -0,0 +1,212 @@ +#![allow(dead_code)] +use crate::schedulers::PredictionType; +use crate::{clip, ddim, unet_2d, vae}; +use candle::{DType, Device, Result}; +use candle_nn as nn; + +#[derive(Clone, Debug)] +pub struct StableDiffusionConfig { + pub width: usize, + pub height: usize, + pub clip: clip::Config, + autoencoder: vae::AutoEncoderKLConfig, + unet: unet_2d::UNet2DConditionModelConfig, + scheduler: ddim::DDIMSchedulerConfig, +} + +impl StableDiffusionConfig { + pub fn v1_5( + sliced_attention_size: Option<usize>, + height: Option<usize>, + width: Option<usize>, + ) -> Self { + let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig { + out_channels, + use_cross_attn, + attention_head_dim, + }; + // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/unet/config.json + let unet = unet_2d::UNet2DConditionModelConfig { + blocks: vec![ + bc(320, true, 8), + bc(640, true, 8), + bc(1280, true, 8), + bc(1280, false, 8), + ], + center_input_sample: false, + cross_attention_dim: 768, + downsample_padding: 1, + flip_sin_to_cos: true, + freq_shift: 0., + layers_per_block: 2, + mid_block_scale_factor: 1., + norm_eps: 1e-5, + norm_num_groups: 32, + sliced_attention_size, + use_linear_projection: false, + }; + let autoencoder = vae::AutoEncoderKLConfig { + block_out_channels: vec![128, 256, 512, 512], + layers_per_block: 2, + latent_channels: 4, + norm_num_groups: 32, + }; + let height = if let Some(height) = height { + assert_eq!(height % 8, 0, "heigh has to be divisible by 8"); + height + } else { + 512 + }; + + let width = if let Some(width) = width { + assert_eq!(width % 8, 0, "width has to be divisible by 8"); + width + } else { + 512 + }; + + Self { + width, + height, + clip: clip::Config::v1_5(), + autoencoder, + scheduler: Default::default(), + unet, + } + } + + fn v2_1_( + sliced_attention_size: Option<usize>, + height: Option<usize>, + width: Option<usize>, + prediction_type: PredictionType, + ) -> Self { + let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig { + out_channels, + use_cross_attn, + attention_head_dim, + }; + // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/unet/config.json + let unet = unet_2d::UNet2DConditionModelConfig { + blocks: vec![ + bc(320, true, 5), + bc(640, true, 10), + bc(1280, true, 20), + bc(1280, false, 20), + ], + center_input_sample: false, + cross_attention_dim: 1024, + downsample_padding: 1, + flip_sin_to_cos: true, + freq_shift: 0., + layers_per_block: 2, + mid_block_scale_factor: 1., + norm_eps: 1e-5, + norm_num_groups: 32, + sliced_attention_size, + use_linear_projection: true, + }; + // 
https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/vae/config.json + let autoencoder = vae::AutoEncoderKLConfig { + block_out_channels: vec![128, 256, 512, 512], + layers_per_block: 2, + latent_channels: 4, + norm_num_groups: 32, + }; + let scheduler = ddim::DDIMSchedulerConfig { + prediction_type, + ..Default::default() + }; + + let height = if let Some(height) = height { + assert_eq!(height % 8, 0, "heigh has to be divisible by 8"); + height + } else { + 768 + }; + + let width = if let Some(width) = width { + assert_eq!(width % 8, 0, "width has to be divisible by 8"); + width + } else { + 768 + }; + + Self { + width, + height, + clip: clip::Config::v2_1(), + autoencoder, + scheduler, + unet, + } + } + + pub fn v2_1( + sliced_attention_size: Option<usize>, + height: Option<usize>, + width: Option<usize>, + ) -> Self { + // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/scheduler/scheduler_config.json + Self::v2_1_( + sliced_attention_size, + height, + width, + PredictionType::VPrediction, + ) + } + + pub fn v2_1_inpaint( + sliced_attention_size: Option<usize>, + height: Option<usize>, + width: Option<usize>, + ) -> Self { + // https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/blob/main/scheduler/scheduler_config.json + // This uses a PNDM scheduler rather than DDIM but the biggest difference is the prediction + // type being "epsilon" by default and not "v_prediction". + Self::v2_1_( + sliced_attention_size, + height, + width, + PredictionType::Epsilon, + ) + } + + pub fn build_vae(&self, vae_weights: &str, device: &Device) -> Result<vae::AutoEncoderKL> { + let weights = unsafe { candle::safetensors::MmapedFile::new(vae_weights)? }; + let weights = weights.deserialize()?; + let vs_ae = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device); + // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json + let autoencoder = vae::AutoEncoderKL::new(vs_ae, 3, 3, self.autoencoder.clone())?; + Ok(autoencoder) + } + + pub fn build_unet( + &self, + unet_weights: &str, + device: &Device, + in_channels: usize, + ) -> Result<unet_2d::UNet2DConditionModel> { + let weights = unsafe { candle::safetensors::MmapedFile::new(unet_weights)? }; + let weights = weights.deserialize()?; + let vs_unet = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device); + let unet = unet_2d::UNet2DConditionModel::new(vs_unet, in_channels, 4, self.unet.clone())?; + Ok(unet) + } + + pub fn build_scheduler(&self, n_steps: usize) -> Result<ddim::DDIMScheduler> { + ddim::DDIMScheduler::new(n_steps, self.scheduler) + } + + pub fn build_clip_transformer( + &self, + clip_weights: &str, + device: &Device, + ) -> Result<clip::ClipTextTransformer> { + let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? }; + let weights = weights.deserialize()?; + let vs = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device); + let text_model = clip::ClipTextTransformer::new(vs, &self.clip)?; + Ok(text_model) + } +} diff --git a/candle-examples/examples/stable-diffusion/unet_2d.rs b/candle-examples/examples/stable-diffusion/unet_2d.rs new file mode 100644 index 00000000..8ebd1876 --- /dev/null +++ b/candle-examples/examples/stable-diffusion/unet_2d.rs @@ -0,0 +1,383 @@ +#![allow(dead_code)] +//! 2D UNet Denoising Models +//! +//! The 2D Unet models take as input a noisy sample and the current diffusion +//! timestep and return a denoised version of the input. 
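+//!
+//! The forward pass embeds the timestep, applies an input convolution, runs the down blocks while
+//! collecting residuals, applies the cross-attention mid block, then the up blocks that consume
+//! those residuals, and finishes with group norm, SiLU and an output convolution.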
+use crate::embeddings::{TimestepEmbedding, Timesteps}; +use crate::unet_2d_blocks::*; +use candle::{DType, Result, Tensor}; +use candle_nn as nn; + +#[derive(Debug, Clone, Copy)] +pub struct BlockConfig { + pub out_channels: usize, + pub use_cross_attn: bool, + pub attention_head_dim: usize, +} + +#[derive(Debug, Clone)] +pub struct UNet2DConditionModelConfig { + pub center_input_sample: bool, + pub flip_sin_to_cos: bool, + pub freq_shift: f64, + pub blocks: Vec<BlockConfig>, + pub layers_per_block: usize, + pub downsample_padding: usize, + pub mid_block_scale_factor: f64, + pub norm_num_groups: usize, + pub norm_eps: f64, + pub cross_attention_dim: usize, + pub sliced_attention_size: Option<usize>, + pub use_linear_projection: bool, +} + +impl Default for UNet2DConditionModelConfig { + fn default() -> Self { + Self { + center_input_sample: false, + flip_sin_to_cos: true, + freq_shift: 0., + blocks: vec![ + BlockConfig { + out_channels: 320, + use_cross_attn: true, + attention_head_dim: 8, + }, + BlockConfig { + out_channels: 640, + use_cross_attn: true, + attention_head_dim: 8, + }, + BlockConfig { + out_channels: 1280, + use_cross_attn: true, + attention_head_dim: 8, + }, + BlockConfig { + out_channels: 1280, + use_cross_attn: false, + attention_head_dim: 8, + }, + ], + layers_per_block: 2, + downsample_padding: 1, + mid_block_scale_factor: 1., + norm_num_groups: 32, + norm_eps: 1e-5, + cross_attention_dim: 1280, + sliced_attention_size: None, + use_linear_projection: false, + } + } +} + +#[derive(Debug)] +pub(crate) enum UNetDownBlock { + Basic(DownBlock2D), + CrossAttn(CrossAttnDownBlock2D), +} + +#[derive(Debug)] +enum UNetUpBlock { + Basic(UpBlock2D), + CrossAttn(CrossAttnUpBlock2D), +} + +#[derive(Debug)] +pub struct UNet2DConditionModel { + conv_in: nn::Conv2d, + time_proj: Timesteps, + time_embedding: TimestepEmbedding, + down_blocks: Vec<UNetDownBlock>, + mid_block: UNetMidBlock2DCrossAttn, + up_blocks: Vec<UNetUpBlock>, + conv_norm_out: nn::GroupNorm, + conv_out: nn::Conv2d, + config: UNet2DConditionModelConfig, +} + +impl UNet2DConditionModel { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + out_channels: usize, + config: UNet2DConditionModelConfig, + ) -> Result<Self> { + let n_blocks = config.blocks.len(); + let b_channels = config.blocks[0].out_channels; + let bl_channels = config.blocks.last().unwrap().out_channels; + let bl_attention_head_dim = config.blocks.last().unwrap().attention_head_dim; + let time_embed_dim = b_channels * 4; + let conv_cfg = nn::Conv2dConfig { + stride: 1, + padding: 1, + }; + let conv_in = nn::conv2d(in_channels, b_channels, 3, conv_cfg, vs.pp("conv_in"))?; + + let time_proj = Timesteps::new(b_channels, config.flip_sin_to_cos, config.freq_shift); + let time_embedding = + TimestepEmbedding::new(vs.pp("time_embedding"), b_channels, time_embed_dim)?; + + let vs_db = vs.pp("down_blocks"); + let down_blocks = (0..n_blocks) + .map(|i| { + let BlockConfig { + out_channels, + use_cross_attn, + attention_head_dim, + } = config.blocks[i]; + + // Enable automatic attention slicing if the config sliced_attention_size is set to 0. 
+ let sliced_attention_size = match config.sliced_attention_size { + Some(0) => Some(attention_head_dim / 2), + _ => config.sliced_attention_size, + }; + + let in_channels = if i > 0 { + config.blocks[i - 1].out_channels + } else { + b_channels + }; + let db_cfg = DownBlock2DConfig { + num_layers: config.layers_per_block, + resnet_eps: config.norm_eps, + resnet_groups: config.norm_num_groups, + add_downsample: i < n_blocks - 1, + downsample_padding: config.downsample_padding, + ..Default::default() + }; + if use_cross_attn { + let config = CrossAttnDownBlock2DConfig { + downblock: db_cfg, + attn_num_head_channels: attention_head_dim, + cross_attention_dim: config.cross_attention_dim, + sliced_attention_size, + use_linear_projection: config.use_linear_projection, + }; + let block = CrossAttnDownBlock2D::new( + vs_db.pp(&i.to_string()), + in_channels, + out_channels, + Some(time_embed_dim), + config, + )?; + Ok(UNetDownBlock::CrossAttn(block)) + } else { + let block = DownBlock2D::new( + vs_db.pp(&i.to_string()), + in_channels, + out_channels, + Some(time_embed_dim), + db_cfg, + )?; + Ok(UNetDownBlock::Basic(block)) + } + }) + .collect::<Result<Vec<_>>>()?; + + let mid_cfg = UNetMidBlock2DCrossAttnConfig { + resnet_eps: config.norm_eps, + output_scale_factor: config.mid_block_scale_factor, + cross_attn_dim: config.cross_attention_dim, + attn_num_head_channels: bl_attention_head_dim, + resnet_groups: Some(config.norm_num_groups), + use_linear_projection: config.use_linear_projection, + ..Default::default() + }; + let mid_block = UNetMidBlock2DCrossAttn::new( + vs.pp("mid_block"), + bl_channels, + Some(time_embed_dim), + mid_cfg, + )?; + + let vs_ub = vs.pp("up_blocks"); + let up_blocks = (0..n_blocks) + .map(|i| { + let BlockConfig { + out_channels, + use_cross_attn, + attention_head_dim, + } = config.blocks[n_blocks - 1 - i]; + + // Enable automatic attention slicing if the config sliced_attention_size is set to 0. 
+ let sliced_attention_size = match config.sliced_attention_size { + Some(0) => Some(attention_head_dim / 2), + _ => config.sliced_attention_size, + }; + + let prev_out_channels = if i > 0 { + config.blocks[n_blocks - i].out_channels + } else { + bl_channels + }; + let in_channels = { + let index = if i == n_blocks - 1 { + 0 + } else { + n_blocks - i - 2 + }; + config.blocks[index].out_channels + }; + let ub_cfg = UpBlock2DConfig { + num_layers: config.layers_per_block + 1, + resnet_eps: config.norm_eps, + resnet_groups: config.norm_num_groups, + add_upsample: i < n_blocks - 1, + ..Default::default() + }; + if use_cross_attn { + let config = CrossAttnUpBlock2DConfig { + upblock: ub_cfg, + attn_num_head_channels: attention_head_dim, + cross_attention_dim: config.cross_attention_dim, + sliced_attention_size, + use_linear_projection: config.use_linear_projection, + }; + let block = CrossAttnUpBlock2D::new( + vs_ub.pp(&i.to_string()), + in_channels, + prev_out_channels, + out_channels, + Some(time_embed_dim), + config, + )?; + Ok(UNetUpBlock::CrossAttn(block)) + } else { + let block = UpBlock2D::new( + vs_ub.pp(&i.to_string()), + in_channels, + prev_out_channels, + out_channels, + Some(time_embed_dim), + ub_cfg, + )?; + Ok(UNetUpBlock::Basic(block)) + } + }) + .collect::<Result<Vec<_>>>()?; + + let conv_norm_out = nn::group_norm( + config.norm_num_groups, + b_channels, + config.norm_eps, + vs.pp("conv_norm_out"), + )?; + let conv_out = nn::conv2d(b_channels, out_channels, 3, conv_cfg, vs.pp("conv_out"))?; + Ok(Self { + conv_in, + time_proj, + time_embedding, + down_blocks, + mid_block, + up_blocks, + conv_norm_out, + conv_out, + config, + }) + } +} + +impl UNet2DConditionModel { + pub fn forward( + &self, + xs: &Tensor, + timestep: f64, + encoder_hidden_states: &Tensor, + ) -> Result<Tensor> { + self.forward_with_additional_residuals(xs, timestep, encoder_hidden_states, None, None) + } + + pub fn forward_with_additional_residuals( + &self, + xs: &Tensor, + timestep: f64, + encoder_hidden_states: &Tensor, + down_block_additional_residuals: Option<&[Tensor]>, + mid_block_additional_residual: Option<&Tensor>, + ) -> Result<Tensor> { + let (bsize, _channels, height, width) = xs.dims4()?; + let device = xs.device(); + let n_blocks = self.config.blocks.len(); + let num_upsamplers = n_blocks - 1; + let default_overall_up_factor = 2usize.pow(num_upsamplers as u32); + let forward_upsample_size = + height % default_overall_up_factor != 0 || width % default_overall_up_factor != 0; + // 0. center input if necessary + let xs = if self.config.center_input_sample { + ((xs * 2.0)? - 1.0)? + } else { + xs.clone() + }; + // 1. time + let emb = (Tensor::ones(bsize, DType::F32, device)? * timestep)?; + let emb = self.time_proj.forward(&emb)?; + let emb = self.time_embedding.forward(&emb)?; + // 2. pre-process + let xs = self.conv_in.forward(&xs)?; + // 3. down + let mut down_block_res_xs = vec![xs.clone()]; + let mut xs = xs; + for down_block in self.down_blocks.iter() { + let (_xs, res_xs) = match down_block { + UNetDownBlock::Basic(b) => b.forward(&xs, Some(&emb))?, + UNetDownBlock::CrossAttn(b) => { + b.forward(&xs, Some(&emb), Some(encoder_hidden_states))? + } + }; + down_block_res_xs.extend(res_xs); + xs = _xs; + } + + let new_down_block_res_xs = + if let Some(down_block_additional_residuals) = down_block_additional_residuals { + let mut v = vec![]; + // A previous version of this code had a bug because of the addition being made + // in place via += hence modifying the input of the mid block. 
+ for (i, residuals) in down_block_additional_residuals.iter().enumerate() { + v.push((&down_block_res_xs[i] + residuals)?) + } + v + } else { + down_block_res_xs + }; + let mut down_block_res_xs = new_down_block_res_xs; + + // 4. mid + let xs = self + .mid_block + .forward(&xs, Some(&emb), Some(encoder_hidden_states))?; + let xs = match mid_block_additional_residual { + None => xs, + Some(m) => (m + xs)?, + }; + // 5. up + let mut xs = xs; + let mut upsample_size = None; + for (i, up_block) in self.up_blocks.iter().enumerate() { + let n_resnets = match up_block { + UNetUpBlock::Basic(b) => b.resnets.len(), + UNetUpBlock::CrossAttn(b) => b.upblock.resnets.len(), + }; + let res_xs = down_block_res_xs.split_off(down_block_res_xs.len() - n_resnets); + if i < n_blocks - 1 && forward_upsample_size { + let (_, _, h, w) = down_block_res_xs.last().unwrap().dims4()?; + upsample_size = Some((h, w)) + } + xs = match up_block { + UNetUpBlock::Basic(b) => b.forward(&xs, &res_xs, Some(&emb), upsample_size)?, + UNetUpBlock::CrossAttn(b) => b.forward( + &xs, + &res_xs, + Some(&emb), + upsample_size, + Some(encoder_hidden_states), + )?, + }; + } + // 6. post-process + let xs = self.conv_norm_out.forward(&xs)?; + let xs = nn::ops::silu(&xs)?; + self.conv_out.forward(&xs) + } +} diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs new file mode 100644 index 00000000..82d5fad5 --- /dev/null +++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs @@ -0,0 +1,808 @@ +#![allow(dead_code)] +//! 2D UNet Building Blocks +//! +use crate::attention::{ + AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig, +}; +use crate::resnet::{ResnetBlock2D, ResnetBlock2DConfig}; +use candle::{Result, Tensor, D}; +use candle_nn as nn; + +#[derive(Debug)] +struct Downsample2D { + conv: Option<nn::Conv2d>, + padding: usize, +} + +impl Downsample2D { + fn new( + vs: nn::VarBuilder, + in_channels: usize, + use_conv: bool, + out_channels: usize, + padding: usize, + ) -> Result<Self> { + let conv = if use_conv { + let config = nn::Conv2dConfig { stride: 2, padding }; + let conv = nn::conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?; + Some(conv) + } else { + None + }; + Ok(Downsample2D { conv, padding }) + } +} + +impl Downsample2D { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + match &self.conv { + None => xs.avg_pool2d((2, 2), (2, 2)), + Some(conv) => { + if self.padding == 0 { + let xs = xs + .pad_with_zeros(D::Minus1, 0, 1)? + .pad_with_zeros(D::Minus2, 0, 1)?; + conv.forward(&xs) + } else { + conv.forward(xs) + } + } + } + } +} + +// This does not support the conv-transpose mode. +#[derive(Debug)] +struct Upsample2D { + conv: nn::Conv2d, +} + +impl Upsample2D { + fn new(vs: nn::VarBuilder, in_channels: usize, out_channels: usize) -> Result<Self> { + let config = nn::Conv2dConfig { + padding: 1, + ..Default::default() + }; + let conv = nn::conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?; + Ok(Self { conv }) + } +} + +impl Upsample2D { + fn forward(&self, xs: &Tensor, size: Option<(usize, usize)>) -> Result<Tensor> { + let xs = match size { + None => { + let (_bsize, _channels, h, w) = xs.dims4()?; + xs.upsample_nearest2d(2 * h, 2 * w)? 
+ } + Some((h, w)) => xs.upsample_nearest2d(h, w)?, + }; + self.conv.forward(&xs) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct DownEncoderBlock2DConfig { + pub num_layers: usize, + pub resnet_eps: f64, + pub resnet_groups: usize, + pub output_scale_factor: f64, + pub add_downsample: bool, + pub downsample_padding: usize, +} + +impl Default for DownEncoderBlock2DConfig { + fn default() -> Self { + Self { + num_layers: 1, + resnet_eps: 1e-6, + resnet_groups: 32, + output_scale_factor: 1., + add_downsample: true, + downsample_padding: 1, + } + } +} + +#[derive(Debug)] +pub struct DownEncoderBlock2D { + resnets: Vec<ResnetBlock2D>, + downsampler: Option<Downsample2D>, + pub config: DownEncoderBlock2DConfig, +} + +impl DownEncoderBlock2D { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + out_channels: usize, + config: DownEncoderBlock2DConfig, + ) -> Result<Self> { + let resnets: Vec<_> = { + let vs = vs.pp("resnets"); + let conv_cfg = ResnetBlock2DConfig { + eps: config.resnet_eps, + out_channels: Some(out_channels), + groups: config.resnet_groups, + output_scale_factor: config.output_scale_factor, + temb_channels: None, + ..Default::default() + }; + (0..(config.num_layers)) + .map(|i| { + let in_channels = if i == 0 { in_channels } else { out_channels }; + ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg) + }) + .collect::<Result<Vec<_>>>()? + }; + let downsampler = if config.add_downsample { + let downsample = Downsample2D::new( + vs.pp("downsamplers").pp("0"), + out_channels, + true, + out_channels, + config.downsample_padding, + )?; + Some(downsample) + } else { + None + }; + Ok(Self { + resnets, + downsampler, + config, + }) + } +} + +impl DownEncoderBlock2D { + pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let mut xs = xs.clone(); + for resnet in self.resnets.iter() { + xs = resnet.forward(&xs, None)? + } + match &self.downsampler { + Some(downsampler) => downsampler.forward(&xs), + None => Ok(xs), + } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct UpDecoderBlock2DConfig { + pub num_layers: usize, + pub resnet_eps: f64, + pub resnet_groups: usize, + pub output_scale_factor: f64, + pub add_upsample: bool, +} + +impl Default for UpDecoderBlock2DConfig { + fn default() -> Self { + Self { + num_layers: 1, + resnet_eps: 1e-6, + resnet_groups: 32, + output_scale_factor: 1., + add_upsample: true, + } + } +} + +#[derive(Debug)] +pub struct UpDecoderBlock2D { + resnets: Vec<ResnetBlock2D>, + upsampler: Option<Upsample2D>, + pub config: UpDecoderBlock2DConfig, +} + +impl UpDecoderBlock2D { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + out_channels: usize, + config: UpDecoderBlock2DConfig, + ) -> Result<Self> { + let resnets: Vec<_> = { + let vs = vs.pp("resnets"); + let conv_cfg = ResnetBlock2DConfig { + out_channels: Some(out_channels), + eps: config.resnet_eps, + groups: config.resnet_groups, + output_scale_factor: config.output_scale_factor, + temb_channels: None, + ..Default::default() + }; + (0..(config.num_layers)) + .map(|i| { + let in_channels = if i == 0 { in_channels } else { out_channels }; + ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg) + }) + .collect::<Result<Vec<_>>>()? 
+ }; + let upsampler = if config.add_upsample { + let upsample = + Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?; + Some(upsample) + } else { + None + }; + Ok(Self { + resnets, + upsampler, + config, + }) + } +} + +impl UpDecoderBlock2D { + pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let mut xs = xs.clone(); + for resnet in self.resnets.iter() { + xs = resnet.forward(&xs, None)? + } + match &self.upsampler { + Some(upsampler) => upsampler.forward(&xs, None), + None => Ok(xs), + } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct UNetMidBlock2DConfig { + pub num_layers: usize, + pub resnet_eps: f64, + pub resnet_groups: Option<usize>, + pub attn_num_head_channels: Option<usize>, + // attention_type "default" + pub output_scale_factor: f64, +} + +impl Default for UNetMidBlock2DConfig { + fn default() -> Self { + Self { + num_layers: 1, + resnet_eps: 1e-6, + resnet_groups: Some(32), + attn_num_head_channels: Some(1), + output_scale_factor: 1., + } + } +} + +#[derive(Debug)] +pub struct UNetMidBlock2D { + resnet: ResnetBlock2D, + attn_resnets: Vec<(AttentionBlock, ResnetBlock2D)>, + pub config: UNetMidBlock2DConfig, +} + +impl UNetMidBlock2D { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + temb_channels: Option<usize>, + config: UNetMidBlock2DConfig, + ) -> Result<Self> { + let vs_resnets = vs.pp("resnets"); + let vs_attns = vs.pp("attentions"); + let resnet_groups = config + .resnet_groups + .unwrap_or_else(|| usize::min(in_channels / 4, 32)); + let resnet_cfg = ResnetBlock2DConfig { + eps: config.resnet_eps, + groups: resnet_groups, + output_scale_factor: config.output_scale_factor, + temb_channels, + ..Default::default() + }; + let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?; + let attn_cfg = AttentionBlockConfig { + num_head_channels: config.attn_num_head_channels, + num_groups: resnet_groups, + rescale_output_factor: config.output_scale_factor, + eps: config.resnet_eps, + }; + let mut attn_resnets = vec![]; + for index in 0..config.num_layers { + let attn = AttentionBlock::new(vs_attns.pp(&index.to_string()), in_channels, attn_cfg)?; + let resnet = ResnetBlock2D::new( + vs_resnets.pp(&(index + 1).to_string()), + in_channels, + resnet_cfg, + )?; + attn_resnets.push((attn, resnet)) + } + Ok(Self { + resnet, + attn_resnets, + config, + }) + } + + pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> { + let mut xs = self.resnet.forward(xs, temb)?; + for (attn, resnet) in self.attn_resnets.iter() { + xs = resnet.forward(&attn.forward(&xs)?, temb)? 
+ } + Ok(xs) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct UNetMidBlock2DCrossAttnConfig { + pub num_layers: usize, + pub resnet_eps: f64, + pub resnet_groups: Option<usize>, + pub attn_num_head_channels: usize, + // attention_type "default" + pub output_scale_factor: f64, + pub cross_attn_dim: usize, + pub sliced_attention_size: Option<usize>, + pub use_linear_projection: bool, +} + +impl Default for UNetMidBlock2DCrossAttnConfig { + fn default() -> Self { + Self { + num_layers: 1, + resnet_eps: 1e-6, + resnet_groups: Some(32), + attn_num_head_channels: 1, + output_scale_factor: 1., + cross_attn_dim: 1280, + sliced_attention_size: None, // Sliced attention disabled + use_linear_projection: false, + } + } +} + +#[derive(Debug)] +pub struct UNetMidBlock2DCrossAttn { + resnet: ResnetBlock2D, + attn_resnets: Vec<(SpatialTransformer, ResnetBlock2D)>, + pub config: UNetMidBlock2DCrossAttnConfig, +} + +impl UNetMidBlock2DCrossAttn { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + temb_channels: Option<usize>, + config: UNetMidBlock2DCrossAttnConfig, + ) -> Result<Self> { + let vs_resnets = vs.pp("resnets"); + let vs_attns = vs.pp("attentions"); + let resnet_groups = config + .resnet_groups + .unwrap_or_else(|| usize::min(in_channels / 4, 32)); + let resnet_cfg = ResnetBlock2DConfig { + eps: config.resnet_eps, + groups: resnet_groups, + output_scale_factor: config.output_scale_factor, + temb_channels, + ..Default::default() + }; + let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?; + let n_heads = config.attn_num_head_channels; + let attn_cfg = SpatialTransformerConfig { + depth: 1, + num_groups: resnet_groups, + context_dim: Some(config.cross_attn_dim), + sliced_attention_size: config.sliced_attention_size, + use_linear_projection: config.use_linear_projection, + }; + let mut attn_resnets = vec![]; + for index in 0..config.num_layers { + let attn = SpatialTransformer::new( + vs_attns.pp(&index.to_string()), + in_channels, + n_heads, + in_channels / n_heads, + attn_cfg, + )?; + let resnet = ResnetBlock2D::new( + vs_resnets.pp(&(index + 1).to_string()), + in_channels, + resnet_cfg, + )?; + attn_resnets.push((attn, resnet)) + } + Ok(Self { + resnet, + attn_resnets, + config, + }) + } + + pub fn forward( + &self, + xs: &Tensor, + temb: Option<&Tensor>, + encoder_hidden_states: Option<&Tensor>, + ) -> Result<Tensor> { + let mut xs = self.resnet.forward(xs, temb)?; + for (attn, resnet) in self.attn_resnets.iter() { + xs = resnet.forward(&attn.forward(&xs, encoder_hidden_states)?, temb)? 
+ } + Ok(xs) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct DownBlock2DConfig { + pub num_layers: usize, + pub resnet_eps: f64, + // resnet_time_scale_shift: "default" + // resnet_act_fn: "swish" + pub resnet_groups: usize, + pub output_scale_factor: f64, + pub add_downsample: bool, + pub downsample_padding: usize, +} + +impl Default for DownBlock2DConfig { + fn default() -> Self { + Self { + num_layers: 1, + resnet_eps: 1e-6, + resnet_groups: 32, + output_scale_factor: 1., + add_downsample: true, + downsample_padding: 1, + } + } +} + +#[derive(Debug)] +pub struct DownBlock2D { + resnets: Vec<ResnetBlock2D>, + downsampler: Option<Downsample2D>, + pub config: DownBlock2DConfig, +} + +impl DownBlock2D { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + out_channels: usize, + temb_channels: Option<usize>, + config: DownBlock2DConfig, + ) -> Result<Self> { + let vs_resnets = vs.pp("resnets"); + let resnet_cfg = ResnetBlock2DConfig { + out_channels: Some(out_channels), + eps: config.resnet_eps, + output_scale_factor: config.output_scale_factor, + temb_channels, + ..Default::default() + }; + let resnets = (0..config.num_layers) + .map(|i| { + let in_channels = if i == 0 { in_channels } else { out_channels }; + ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg) + }) + .collect::<Result<Vec<_>>>()?; + let downsampler = if config.add_downsample { + let downsampler = Downsample2D::new( + vs.pp("downsamplers").pp("0"), + out_channels, + true, + out_channels, + config.downsample_padding, + )?; + Some(downsampler) + } else { + None + }; + Ok(Self { + resnets, + downsampler, + config, + }) + } + + pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<(Tensor, Vec<Tensor>)> { + let mut xs = xs.clone(); + let mut output_states = vec![]; + for resnet in self.resnets.iter() { + xs = resnet.forward(&xs, temb)?; + output_states.push(xs.clone()); + } + let xs = match &self.downsampler { + Some(downsampler) => { + let xs = downsampler.forward(&xs)?; + output_states.push(xs.clone()); + xs + } + None => xs, + }; + Ok((xs, output_states)) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct CrossAttnDownBlock2DConfig { + pub downblock: DownBlock2DConfig, + pub attn_num_head_channels: usize, + pub cross_attention_dim: usize, + // attention_type: "default" + pub sliced_attention_size: Option<usize>, + pub use_linear_projection: bool, +} + +impl Default for CrossAttnDownBlock2DConfig { + fn default() -> Self { + Self { + downblock: Default::default(), + attn_num_head_channels: 1, + cross_attention_dim: 1280, + sliced_attention_size: None, + use_linear_projection: false, + } + } +} + +#[derive(Debug)] +pub struct CrossAttnDownBlock2D { + downblock: DownBlock2D, + attentions: Vec<SpatialTransformer>, + pub config: CrossAttnDownBlock2DConfig, +} + +impl CrossAttnDownBlock2D { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + out_channels: usize, + temb_channels: Option<usize>, + config: CrossAttnDownBlock2DConfig, + ) -> Result<Self> { + let downblock = DownBlock2D::new( + vs.clone(), + in_channels, + out_channels, + temb_channels, + config.downblock, + )?; + let n_heads = config.attn_num_head_channels; + let cfg = SpatialTransformerConfig { + depth: 1, + context_dim: Some(config.cross_attention_dim), + num_groups: config.downblock.resnet_groups, + sliced_attention_size: config.sliced_attention_size, + use_linear_projection: config.use_linear_projection, + }; + let vs_attn = vs.pp("attentions"); + let attentions = (0..config.downblock.num_layers) + .map(|i| { 
+ SpatialTransformer::new( + vs_attn.pp(&i.to_string()), + out_channels, + n_heads, + out_channels / n_heads, + cfg, + ) + }) + .collect::<Result<Vec<_>>>()?; + Ok(Self { + downblock, + attentions, + config, + }) + } + + pub fn forward( + &self, + xs: &Tensor, + temb: Option<&Tensor>, + encoder_hidden_states: Option<&Tensor>, + ) -> Result<(Tensor, Vec<Tensor>)> { + let mut output_states = vec![]; + let mut xs = xs.clone(); + for (resnet, attn) in self.downblock.resnets.iter().zip(self.attentions.iter()) { + xs = resnet.forward(&xs, temb)?; + xs = attn.forward(&xs, encoder_hidden_states)?; + output_states.push(xs.clone()); + } + let xs = match &self.downblock.downsampler { + Some(downsampler) => { + let xs = downsampler.forward(&xs)?; + output_states.push(xs.clone()); + xs + } + None => xs, + }; + Ok((xs, output_states)) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct UpBlock2DConfig { + pub num_layers: usize, + pub resnet_eps: f64, + // resnet_time_scale_shift: "default" + // resnet_act_fn: "swish" + pub resnet_groups: usize, + pub output_scale_factor: f64, + pub add_upsample: bool, +} + +impl Default for UpBlock2DConfig { + fn default() -> Self { + Self { + num_layers: 1, + resnet_eps: 1e-6, + resnet_groups: 32, + output_scale_factor: 1., + add_upsample: true, + } + } +} + +#[derive(Debug)] +pub struct UpBlock2D { + pub resnets: Vec<ResnetBlock2D>, + upsampler: Option<Upsample2D>, + pub config: UpBlock2DConfig, +} + +impl UpBlock2D { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + prev_output_channels: usize, + out_channels: usize, + temb_channels: Option<usize>, + config: UpBlock2DConfig, + ) -> Result<Self> { + let vs_resnets = vs.pp("resnets"); + let resnet_cfg = ResnetBlock2DConfig { + out_channels: Some(out_channels), + temb_channels, + eps: config.resnet_eps, + output_scale_factor: config.output_scale_factor, + ..Default::default() + }; + let resnets = (0..config.num_layers) + .map(|i| { + let res_skip_channels = if i == config.num_layers - 1 { + in_channels + } else { + out_channels + }; + let resnet_in_channels = if i == 0 { + prev_output_channels + } else { + out_channels + }; + let in_channels = resnet_in_channels + res_skip_channels; + ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg) + }) + .collect::<Result<Vec<_>>>()?; + let upsampler = if config.add_upsample { + let upsampler = + Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?; + Some(upsampler) + } else { + None + }; + Ok(Self { + resnets, + upsampler, + config, + }) + } + + pub fn forward( + &self, + xs: &Tensor, + res_xs: &[Tensor], + temb: Option<&Tensor>, + upsample_size: Option<(usize, usize)>, + ) -> Result<Tensor> { + let mut xs = xs.clone(); + for (index, resnet) in self.resnets.iter().enumerate() { + xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?; + xs = resnet.forward(&xs, temb)?; + } + match &self.upsampler { + Some(upsampler) => upsampler.forward(&xs, upsample_size), + None => Ok(xs), + } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct CrossAttnUpBlock2DConfig { + pub upblock: UpBlock2DConfig, + pub attn_num_head_channels: usize, + pub cross_attention_dim: usize, + // attention_type: "default" + pub sliced_attention_size: Option<usize>, + pub use_linear_projection: bool, +} + +impl Default for CrossAttnUpBlock2DConfig { + fn default() -> Self { + Self { + upblock: Default::default(), + attn_num_head_channels: 1, + cross_attention_dim: 1280, + sliced_attention_size: None, + use_linear_projection: false, + } + } +} + 
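+// Decoder-side counterpart of CrossAttnDownBlock2D: wraps an UpBlock2D and runs a
+// SpatialTransformer (cross-attention) layer after each resnet, concatenating the
+// matching skip connection from the down path before every resnet.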
+#[derive(Debug)] +pub struct CrossAttnUpBlock2D { + pub upblock: UpBlock2D, + pub attentions: Vec<SpatialTransformer>, + pub config: CrossAttnUpBlock2DConfig, +} + +impl CrossAttnUpBlock2D { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + prev_output_channels: usize, + out_channels: usize, + temb_channels: Option<usize>, + config: CrossAttnUpBlock2DConfig, + ) -> Result<Self> { + let upblock = UpBlock2D::new( + vs.clone(), + in_channels, + prev_output_channels, + out_channels, + temb_channels, + config.upblock, + )?; + let n_heads = config.attn_num_head_channels; + let cfg = SpatialTransformerConfig { + depth: 1, + context_dim: Some(config.cross_attention_dim), + num_groups: config.upblock.resnet_groups, + sliced_attention_size: config.sliced_attention_size, + use_linear_projection: config.use_linear_projection, + }; + let vs_attn = vs.pp("attentions"); + let attentions = (0..config.upblock.num_layers) + .map(|i| { + SpatialTransformer::new( + vs_attn.pp(&i.to_string()), + out_channels, + n_heads, + out_channels / n_heads, + cfg, + ) + }) + .collect::<Result<Vec<_>>>()?; + Ok(Self { + upblock, + attentions, + config, + }) + } + + pub fn forward( + &self, + xs: &Tensor, + res_xs: &[Tensor], + temb: Option<&Tensor>, + upsample_size: Option<(usize, usize)>, + encoder_hidden_states: Option<&Tensor>, + ) -> Result<Tensor> { + let mut xs = xs.clone(); + for (index, resnet) in self.upblock.resnets.iter().enumerate() { + xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?; + xs = resnet.forward(&xs, temb)?; + xs = self.attentions[index].forward(&xs, encoder_hidden_states)?; + } + match &self.upblock.upsampler { + Some(upsampler) => upsampler.forward(&xs, upsample_size), + None => Ok(xs), + } + } +} diff --git a/candle-examples/examples/stable-diffusion/utils.rs b/candle-examples/examples/stable-diffusion/utils.rs new file mode 100644 index 00000000..ef4dd956 --- /dev/null +++ b/candle-examples/examples/stable-diffusion/utils.rs @@ -0,0 +1,31 @@ +use candle::{Device, Result, Tensor}; + +pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> { + if steps < 1 { + candle::bail!("cannot use linspace with steps {steps} <= 1") + } + let delta = (stop - start) / (steps - 1) as f64; + let vs = (0..steps) + .map(|step| start + step as f64 * delta) + .collect::<Vec<_>>(); + Tensor::from_vec(vs, steps, &Device::Cpu) +} + +/// Saves an image to disk using the image crate, this expects an input with shape +/// (c, width, height). +pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> { + let p = p.as_ref(); + let (channel, width, height) = img.dims3()?; + if channel != 3 { + candle::bail!("save_image expects an input of shape (3, width, height)") + } + let img = img.transpose(0, 1)?.t()?.flatten_all()?; + let pixels = img.to_vec1::<u8>()?; + let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> = + match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) { + Some(image) => image, + None => candle::bail!("error saving image {p:?}"), + }; + image.save(p).map_err(candle::Error::wrap)?; + Ok(()) +} diff --git a/candle-examples/examples/stable-diffusion/vae.rs b/candle-examples/examples/stable-diffusion/vae.rs new file mode 100644 index 00000000..7a10d932 --- /dev/null +++ b/candle-examples/examples/stable-diffusion/vae.rs @@ -0,0 +1,378 @@ +#![allow(dead_code)] +//! # Variational Auto-Encoder (VAE) Models. +//! +//! Auto-encoder models compress their input to a usually smaller latent space +//! 
before expanding it back to its original shape. This results in the latent values +//! compressing the original information. +use crate::unet_2d_blocks::{ + DownEncoderBlock2D, DownEncoderBlock2DConfig, UNetMidBlock2D, UNetMidBlock2DConfig, + UpDecoderBlock2D, UpDecoderBlock2DConfig, +}; +use candle::{Result, Tensor}; +use candle_nn as nn; + +#[derive(Debug, Clone)] +struct EncoderConfig { + // down_block_types: DownEncoderBlock2D + block_out_channels: Vec<usize>, + layers_per_block: usize, + norm_num_groups: usize, + double_z: bool, +} + +impl Default for EncoderConfig { + fn default() -> Self { + Self { + block_out_channels: vec![64], + layers_per_block: 2, + norm_num_groups: 32, + double_z: true, + } + } +} + +#[derive(Debug)] +struct Encoder { + conv_in: nn::Conv2d, + down_blocks: Vec<DownEncoderBlock2D>, + mid_block: UNetMidBlock2D, + conv_norm_out: nn::GroupNorm, + conv_out: nn::Conv2d, + #[allow(dead_code)] + config: EncoderConfig, +} + +impl Encoder { + fn new( + vs: nn::VarBuilder, + in_channels: usize, + out_channels: usize, + config: EncoderConfig, + ) -> Result<Self> { + let conv_cfg = nn::Conv2dConfig { + stride: 1, + padding: 1, + }; + let conv_in = nn::conv2d( + in_channels, + config.block_out_channels[0], + 3, + conv_cfg, + vs.pp("conv_in"), + )?; + let mut down_blocks = vec![]; + let vs_down_blocks = vs.pp("down_blocks"); + for index in 0..config.block_out_channels.len() { + let out_channels = config.block_out_channels[index]; + let in_channels = if index > 0 { + config.block_out_channels[index - 1] + } else { + config.block_out_channels[0] + }; + let is_final = index + 1 == config.block_out_channels.len(); + let cfg = DownEncoderBlock2DConfig { + num_layers: config.layers_per_block, + resnet_eps: 1e-6, + resnet_groups: config.norm_num_groups, + add_downsample: !is_final, + downsample_padding: 0, + ..Default::default() + }; + let down_block = DownEncoderBlock2D::new( + vs_down_blocks.pp(&index.to_string()), + in_channels, + out_channels, + cfg, + )?; + down_blocks.push(down_block) + } + let last_block_out_channels = *config.block_out_channels.last().unwrap(); + let mid_cfg = UNetMidBlock2DConfig { + resnet_eps: 1e-6, + output_scale_factor: 1., + attn_num_head_channels: None, + resnet_groups: Some(config.norm_num_groups), + ..Default::default() + }; + let mid_block = + UNetMidBlock2D::new(vs.pp("mid_block"), last_block_out_channels, None, mid_cfg)?; + let conv_norm_out = nn::group_norm( + config.norm_num_groups, + last_block_out_channels, + 1e-6, + vs.pp("conv_norm_out"), + )?; + let conv_out_channels = if config.double_z { + 2 * out_channels + } else { + out_channels + }; + let conv_cfg = nn::Conv2dConfig { + padding: 1, + ..Default::default() + }; + let conv_out = nn::conv2d( + last_block_out_channels, + conv_out_channels, + 3, + conv_cfg, + vs.pp("conv_out"), + )?; + Ok(Self { + conv_in, + down_blocks, + mid_block, + conv_norm_out, + conv_out, + config, + }) + } +} + +impl Encoder { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let mut xs = self.conv_in.forward(xs)?; + for down_block in self.down_blocks.iter() { + xs = down_block.forward(&xs)? 
+ } + let xs = self.mid_block.forward(&xs, None)?; + let xs = self.conv_norm_out.forward(&xs)?; + let xs = nn::ops::silu(&xs)?; + self.conv_out.forward(&xs) + } +} + +#[derive(Debug, Clone)] +struct DecoderConfig { + // up_block_types: UpDecoderBlock2D + block_out_channels: Vec<usize>, + layers_per_block: usize, + norm_num_groups: usize, +} + +impl Default for DecoderConfig { + fn default() -> Self { + Self { + block_out_channels: vec![64], + layers_per_block: 2, + norm_num_groups: 32, + } + } +} + +#[derive(Debug)] +struct Decoder { + conv_in: nn::Conv2d, + up_blocks: Vec<UpDecoderBlock2D>, + mid_block: UNetMidBlock2D, + conv_norm_out: nn::GroupNorm, + conv_out: nn::Conv2d, + #[allow(dead_code)] + config: DecoderConfig, +} + +impl Decoder { + fn new( + vs: nn::VarBuilder, + in_channels: usize, + out_channels: usize, + config: DecoderConfig, + ) -> Result<Self> { + let n_block_out_channels = config.block_out_channels.len(); + let last_block_out_channels = *config.block_out_channels.last().unwrap(); + let conv_cfg = nn::Conv2dConfig { + stride: 1, + padding: 1, + }; + let conv_in = nn::conv2d( + in_channels, + last_block_out_channels, + 3, + conv_cfg, + vs.pp("conv_in"), + )?; + let mid_cfg = UNetMidBlock2DConfig { + resnet_eps: 1e-6, + output_scale_factor: 1., + attn_num_head_channels: None, + resnet_groups: Some(config.norm_num_groups), + ..Default::default() + }; + let mid_block = + UNetMidBlock2D::new(vs.pp("mid_block"), last_block_out_channels, None, mid_cfg)?; + let mut up_blocks = vec![]; + let vs_up_blocks = vs.pp("up_blocks"); + let reversed_block_out_channels: Vec<_> = + config.block_out_channels.iter().copied().rev().collect(); + for index in 0..n_block_out_channels { + let out_channels = reversed_block_out_channels[index]; + let in_channels = if index > 0 { + reversed_block_out_channels[index - 1] + } else { + reversed_block_out_channels[0] + }; + let is_final = index + 1 == n_block_out_channels; + let cfg = UpDecoderBlock2DConfig { + num_layers: config.layers_per_block + 1, + resnet_eps: 1e-6, + resnet_groups: config.norm_num_groups, + add_upsample: !is_final, + ..Default::default() + }; + let up_block = UpDecoderBlock2D::new( + vs_up_blocks.pp(&index.to_string()), + in_channels, + out_channels, + cfg, + )?; + up_blocks.push(up_block) + } + let conv_norm_out = nn::group_norm( + config.norm_num_groups, + config.block_out_channels[0], + 1e-6, + vs.pp("conv_norm_out"), + )?; + let conv_cfg = nn::Conv2dConfig { + padding: 1, + ..Default::default() + }; + let conv_out = nn::conv2d( + config.block_out_channels[0], + out_channels, + 3, + conv_cfg, + vs.pp("conv_out"), + )?; + Ok(Self { + conv_in, + up_blocks, + mid_block, + conv_norm_out, + conv_out, + config, + }) + } +} + +impl Decoder { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let mut xs = self.mid_block.forward(&self.conv_in.forward(xs)?, None)?; + for up_block in self.up_blocks.iter() { + xs = up_block.forward(&xs)? 
+ }
+ let xs = self.conv_norm_out.forward(&xs)?;
+ let xs = nn::ops::silu(&xs)?;
+ self.conv_out.forward(&xs)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct AutoEncoderKLConfig {
+ pub block_out_channels: Vec<usize>,
+ pub layers_per_block: usize,
+ pub latent_channels: usize,
+ pub norm_num_groups: usize,
+}
+
+impl Default for AutoEncoderKLConfig {
+ fn default() -> Self {
+ Self {
+ block_out_channels: vec![64],
+ layers_per_block: 1,
+ latent_channels: 4,
+ norm_num_groups: 32,
+ }
+ }
+}
+
+pub struct DiagonalGaussianDistribution {
+ mean: Tensor,
+ std: Tensor,
+}
+
+impl DiagonalGaussianDistribution {
+ pub fn new(parameters: &Tensor) -> Result<Self> {
+ let mut parameters = parameters.chunk(2, 1)?.into_iter();
+ let mean = parameters.next().unwrap();
+ let logvar = parameters.next().unwrap();
+ let std = (logvar * 0.5)?.exp()?;
+ Ok(DiagonalGaussianDistribution { mean, std })
+ }
+
+ pub fn sample(&self) -> Result<Tensor> {
+ let sample = Tensor::randn(0., 1f32, self.mean.shape(), self.mean.device());
+ &self.mean + &self.std * sample
+ }
+}
+
+// https://github.com/huggingface/diffusers/blob/970e30606c2944e3286f56e8eb6d3dc6d1eb85f7/src/diffusers/models/vae.py#L485
+// This implementation is specific to the config used in stable-diffusion-v1-5
+// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
+#[derive(Debug)]
+pub struct AutoEncoderKL {
+ encoder: Encoder,
+ decoder: Decoder,
+ quant_conv: nn::Conv2d,
+ post_quant_conv: nn::Conv2d,
+ pub config: AutoEncoderKLConfig,
+}
+
+impl AutoEncoderKL {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ config: AutoEncoderKLConfig,
+ ) -> Result<Self> {
+ let latent_channels = config.latent_channels;
+ let encoder_cfg = EncoderConfig {
+ block_out_channels: config.block_out_channels.clone(),
+ layers_per_block: config.layers_per_block,
+ norm_num_groups: config.norm_num_groups,
+ double_z: true,
+ };
+ let encoder = Encoder::new(vs.pp("encoder"), in_channels, latent_channels, encoder_cfg)?;
+ let decoder_cfg = DecoderConfig {
+ block_out_channels: config.block_out_channels.clone(),
+ layers_per_block: config.layers_per_block,
+ norm_num_groups: config.norm_num_groups,
+ };
+ let decoder = Decoder::new(vs.pp("decoder"), latent_channels, out_channels, decoder_cfg)?;
+ let conv_cfg = Default::default();
+ let quant_conv = nn::conv2d(
+ 2 * latent_channels,
+ 2 * latent_channels,
+ 1,
+ conv_cfg,
+ vs.pp("quant_conv"),
+ )?;
+ let post_quant_conv = nn::conv2d(
+ latent_channels,
+ latent_channels,
+ 1,
+ conv_cfg,
+ vs.pp("post_quant_conv"),
+ )?;
+ Ok(Self {
+ encoder,
+ decoder,
+ quant_conv,
+ post_quant_conv,
+ config,
+ })
+ }
+
+ /// Returns the distribution in the latent space.
+ pub fn encode(&self, xs: &Tensor) -> Result<DiagonalGaussianDistribution> {
+ let xs = self.encoder.forward(xs)?;
+ let parameters = self.quant_conv.forward(&xs)?;
+ DiagonalGaussianDistribution::new(&parameters)
+ }
+
+ /// Takes as input some sampled values.
+ pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = self.post_quant_conv.forward(xs)?;
+ self.decoder.forward(&xs)
+ }
+}
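
The diff ends here. As a quick sketch of how the pieces above fit together (not part of the patch; the helper name `vae_roundtrip`, the config values, and the 0.18215 latent scale are assumptions based on the usual stable-diffusion v1.x setup), the auto-encoder could be driven like this:

    use candle::{Result, Tensor};
    use candle_nn::VarBuilder;

    // Hypothetical helper (not part of the patch): round-trip an image through the VAE.
    // `vs` must point at the VAE weights, `cfg` must match the checkpoint (for
    // stable-diffusion v1.x this is roughly block_out_channels [128, 256, 512, 512],
    // layers_per_block 2, latent_channels 4), and `img` is assumed to be a
    // (1, 3, h, w) f32 tensor scaled to [-1, 1]. The 0.18215 latent scale is the
    // value conventionally used by the v1.x configs.
    fn vae_roundtrip(vs: VarBuilder, cfg: AutoEncoderKLConfig, img: &Tensor) -> Result<Tensor> {
        let vae = AutoEncoderKL::new(vs, 3, 3, cfg)?;
        // encode returns the latent distribution; sample from it and apply the scale.
        let latents = (vae.encode(img)?.sample()? * 0.18215)?;
        // Undo the scaling before decoding back to image space.
        vae.decode(&(latents / 0.18215)?)
    }

Note that `encode` hands back the full `DiagonalGaussianDistribution` rather than a sample, so the stochastic step stays under the caller's control.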