summaryrefslogtreecommitdiff
path: root/candle-examples/examples
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-06 18:49:43 +0200
committerGitHub <noreply@github.com>2023-08-06 17:49:43 +0100
commitd34039e35267b3f4de83770f8da4ea31491bcec5 (patch)
tree6efc4796859a04223d2211303cc12b3a97321fcc /candle-examples/examples
parent93cfe5642f473889d1df62ccb8f1740f77523dd3 (diff)
downloadcandle-d34039e35267b3f4de83770f8da4ea31491bcec5.tar.gz
candle-d34039e35267b3f4de83770f8da4ea31491bcec5.tar.bz2
candle-d34039e35267b3f4de83770f8da4ea31491bcec5.zip
Add a stable diffusion example (#328)
* Start adding a stable-diffusion example. * Proper computation of the causal mask. * Add the chunk operation. * Work in progress: port the attention module. * Add some dummy modules for conv2d and group-norm, get the attention module to compile. * Re-enable the 2d convolution. * Add the embeddings module. * Add the resnet module. * Add the unet blocks. * Add the unet. * And add the variational auto-encoder. * Use the pad function from utils.
Diffstat (limited to 'candle-examples/examples')
-rw-r--r--candle-examples/examples/stable-diffusion/attention.rs445
-rw-r--r--candle-examples/examples/stable-diffusion/clip.rs304
-rw-r--r--candle-examples/examples/stable-diffusion/embeddings.rs65
-rw-r--r--candle-examples/examples/stable-diffusion/main.rs30
-rw-r--r--candle-examples/examples/stable-diffusion/resnet.rs129
-rw-r--r--candle-examples/examples/stable-diffusion/unet_2d.rs383
-rw-r--r--candle-examples/examples/stable-diffusion/unet_2d_blocks.rs809
-rw-r--r--candle-examples/examples/stable-diffusion/utils.rs17
-rw-r--r--candle-examples/examples/stable-diffusion/vae.rs378
9 files changed, 2560 insertions, 0 deletions
diff --git a/candle-examples/examples/stable-diffusion/attention.rs b/candle-examples/examples/stable-diffusion/attention.rs
new file mode 100644
index 00000000..83e7ef34
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/attention.rs
@@ -0,0 +1,445 @@
+#![allow(dead_code)]
+//! Attention Based Building Blocks
+use candle::{IndexOp, Result, Tensor, D};
+use candle_nn as nn;
+
+#[derive(Debug)]
+struct GeGlu {
+ proj: nn::Linear,
+}
+
+impl GeGlu {
+ fn new(vs: nn::VarBuilder, dim_in: usize, dim_out: usize) -> Result<Self> {
+ let proj = nn::linear(dim_in, dim_out * 2, vs.pp("proj"))?;
+ Ok(Self { proj })
+ }
+}
+
+impl GeGlu {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?;
+ &hidden_states_and_gate[0] * hidden_states_and_gate[1].gelu()?
+ }
+}
+
+/// A feed-forward layer.
+#[derive(Debug)]
+struct FeedForward {
+ project_in: GeGlu,
+ linear: nn::Linear,
+}
+
+impl FeedForward {
+ // The glu parameter in the python code is unused?
+ // https://github.com/huggingface/diffusers/blob/d3d22ce5a894becb951eec03e663951b28d45135/src/diffusers/models/attention.py#L347
+ /// Creates a new feed-forward layer based on some given input dimension, some
+ /// output dimension, and a multiplier to be used for the intermediary layer.
+ fn new(vs: nn::VarBuilder, dim: usize, dim_out: Option<usize>, mult: usize) -> Result<Self> {
+ let inner_dim = dim * mult;
+ let dim_out = dim_out.unwrap_or(dim);
+ let vs = vs.pp("net");
+ let project_in = GeGlu::new(vs.pp("0"), dim, inner_dim)?;
+ let linear = nn::linear(inner_dim, dim_out, vs.pp("2"))?;
+ Ok(Self { project_in, linear })
+ }
+}
+
+impl FeedForward {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = self.project_in.forward(xs)?;
+ self.linear.forward(&xs)
+ }
+}
+
+#[derive(Debug)]
+struct CrossAttention {
+ to_q: nn::Linear,
+ to_k: nn::Linear,
+ to_v: nn::Linear,
+ to_out: nn::Linear,
+ heads: usize,
+ scale: f64,
+ slice_size: Option<usize>,
+}
+
+impl CrossAttention {
+ // Defaults should be heads = 8, dim_head = 64, context_dim = None
+ fn new(
+ vs: nn::VarBuilder,
+ query_dim: usize,
+ context_dim: Option<usize>,
+ heads: usize,
+ dim_head: usize,
+ slice_size: Option<usize>,
+ ) -> Result<Self> {
+ let inner_dim = dim_head * heads;
+ let context_dim = context_dim.unwrap_or(query_dim);
+ let scale = 1.0 / f64::sqrt(dim_head as f64);
+ let to_q = nn::linear_no_bias(query_dim, inner_dim, vs.pp("to_q"))?;
+ let to_k = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_k"))?;
+ let to_v = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_v"))?;
+ let to_out = nn::linear(inner_dim, query_dim, vs.pp("to_out.0"))?;
+ Ok(Self {
+ to_q,
+ to_k,
+ to_v,
+ to_out,
+ heads,
+ scale,
+ slice_size,
+ })
+ }
+
+ fn reshape_heads_to_batch_dim(&self, xs: &Tensor) -> Result<Tensor> {
+ let (batch_size, seq_len, dim) = xs.dims3()?;
+ xs.reshape((batch_size, seq_len, self.heads, dim / self.heads))?
+ .transpose(1, 2)?
+ .reshape((batch_size * self.heads, seq_len, dim / self.heads))
+ }
+
+ fn reshape_batch_dim_to_heads(&self, xs: &Tensor) -> Result<Tensor> {
+ let (batch_size, seq_len, dim) = xs.dims3()?;
+ xs.reshape((batch_size / self.heads, self.heads, seq_len, dim))?
+ .transpose(1, 2)?
+ .reshape((batch_size / self.heads, seq_len, dim * self.heads))
+ }
+
+ fn sliced_attention(
+ &self,
+ query: &Tensor,
+ key: &Tensor,
+ value: &Tensor,
+ slice_size: usize,
+ ) -> Result<Tensor> {
+ let batch_size_attention = query.dim(0)?;
+ let mut hidden_states = Vec::with_capacity(batch_size_attention / slice_size);
+
+ for i in 0..batch_size_attention / slice_size {
+ let start_idx = i * slice_size;
+ let end_idx = (i + 1) * slice_size;
+
+ let xs = query
+ .i(start_idx..end_idx)?
+ .matmul(&(key.i(start_idx..end_idx)?.t()? * self.scale)?)?;
+ let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(&value.i(start_idx..end_idx)?)?;
+ hidden_states.push(xs)
+ }
+ let hidden_states = Tensor::stack(&hidden_states, 0)?;
+ self.reshape_batch_dim_to_heads(&hidden_states)
+ }
+
+ fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> {
+ let xs = query.matmul(&(key.transpose(D::Minus1, D::Minus2)? * self.scale)?)?;
+ let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(value)?;
+ self.reshape_batch_dim_to_heads(&xs)
+ }
+
+ fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
+ let query = self.to_q.forward(xs)?;
+ let context = context.unwrap_or(xs);
+ let key = self.to_k.forward(context)?;
+ let value = self.to_v.forward(context)?;
+ let query = self.reshape_heads_to_batch_dim(&query)?;
+ let key = self.reshape_heads_to_batch_dim(&key)?;
+ let value = self.reshape_heads_to_batch_dim(&value)?;
+ let xs = match self.slice_size {
+ None => self.attention(&query, &key, &value)?,
+ Some(slice_size) => {
+ if query.dim(0)? / slice_size <= 1 {
+ self.attention(&query, &key, &value)?
+ } else {
+ self.sliced_attention(&query, &key, &value, slice_size)?
+ }
+ }
+ };
+ self.to_out.forward(&xs)
+ }
+}
+
+/// A basic Transformer block.
+#[derive(Debug)]
+struct BasicTransformerBlock {
+ attn1: CrossAttention,
+ ff: FeedForward,
+ attn2: CrossAttention,
+ norm1: nn::LayerNorm,
+ norm2: nn::LayerNorm,
+ norm3: nn::LayerNorm,
+}
+
+impl BasicTransformerBlock {
+ fn new(
+ vs: nn::VarBuilder,
+ dim: usize,
+ n_heads: usize,
+ d_head: usize,
+ context_dim: Option<usize>,
+ sliced_attention_size: Option<usize>,
+ ) -> Result<Self> {
+ let attn1 = CrossAttention::new(
+ vs.pp("attn1"),
+ dim,
+ None,
+ n_heads,
+ d_head,
+ sliced_attention_size,
+ )?;
+ let ff = FeedForward::new(vs.pp("ff"), dim, None, 4)?;
+ let attn2 = CrossAttention::new(
+ vs.pp("attn2"),
+ dim,
+ context_dim,
+ n_heads,
+ d_head,
+ sliced_attention_size,
+ )?;
+ let norm1 = nn::layer_norm(dim, 1e-5, vs.pp("norm1"))?;
+ let norm2 = nn::layer_norm(dim, 1e-5, vs.pp("norm2"))?;
+ let norm3 = nn::layer_norm(dim, 1e-5, vs.pp("norm3"))?;
+ Ok(Self {
+ attn1,
+ ff,
+ attn2,
+ norm1,
+ norm2,
+ norm3,
+ })
+ }
+
+ fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
+ let xs = (self.attn1.forward(&self.norm1.forward(xs)?, None)? + xs)?;
+ let xs = (self.attn2.forward(&self.norm2.forward(&xs)?, context)? + xs)?;
+ self.ff.forward(&self.norm3.forward(&xs)?)? + xs
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct SpatialTransformerConfig {
+ pub depth: usize,
+ pub num_groups: usize,
+ pub context_dim: Option<usize>,
+ pub sliced_attention_size: Option<usize>,
+ pub use_linear_projection: bool,
+}
+
+impl Default for SpatialTransformerConfig {
+ fn default() -> Self {
+ Self {
+ depth: 1,
+ num_groups: 32,
+ context_dim: None,
+ sliced_attention_size: None,
+ use_linear_projection: false,
+ }
+ }
+}
+
+#[derive(Debug)]
+enum Proj {
+ Conv2d(nn::Conv2d),
+ Linear(nn::Linear),
+}
+
+// Aka Transformer2DModel
+#[derive(Debug)]
+pub struct SpatialTransformer {
+ norm: nn::GroupNorm,
+ proj_in: Proj,
+ transformer_blocks: Vec<BasicTransformerBlock>,
+ proj_out: Proj,
+ pub config: SpatialTransformerConfig,
+}
+
+impl SpatialTransformer {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ n_heads: usize,
+ d_head: usize,
+ config: SpatialTransformerConfig,
+ ) -> Result<Self> {
+ let inner_dim = n_heads * d_head;
+ let norm = nn::group_norm(config.num_groups, in_channels, 1e-6, vs.pp("norm"))?;
+ let proj_in = if config.use_linear_projection {
+ Proj::Linear(nn::linear(in_channels, inner_dim, vs.pp("proj_in"))?)
+ } else {
+ Proj::Conv2d(nn::conv2d(
+ in_channels,
+ inner_dim,
+ 1,
+ Default::default(),
+ vs.pp("proj_in"),
+ )?)
+ };
+ let mut transformer_blocks = vec![];
+ let vs_tb = vs.pp("transformer_blocks");
+ for index in 0..config.depth {
+ let tb = BasicTransformerBlock::new(
+ vs_tb.pp(&index.to_string()),
+ inner_dim,
+ n_heads,
+ d_head,
+ config.context_dim,
+ config.sliced_attention_size,
+ )?;
+ transformer_blocks.push(tb)
+ }
+ let proj_out = if config.use_linear_projection {
+ Proj::Linear(nn::linear(in_channels, inner_dim, vs.pp("proj_out"))?)
+ } else {
+ Proj::Conv2d(nn::conv2d(
+ inner_dim,
+ in_channels,
+ 1,
+ Default::default(),
+ vs.pp("proj_out"),
+ )?)
+ };
+ Ok(Self {
+ norm,
+ proj_in,
+ transformer_blocks,
+ proj_out,
+ config,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
+ let (batch, _channel, height, weight) = xs.dims4()?;
+ let residual = xs;
+ let xs = self.norm.forward(xs)?;
+ let (inner_dim, xs) = match &self.proj_in {
+ Proj::Conv2d(p) => {
+ let xs = p.forward(&xs)?;
+ let inner_dim = xs.dim(1)?;
+ let xs = xs
+ .transpose(1, 2)?
+ .t()?
+ .reshape((batch, height * weight, inner_dim))?;
+ (inner_dim, xs)
+ }
+ Proj::Linear(p) => {
+ let inner_dim = xs.dim(1)?;
+ let xs = xs
+ .transpose(1, 2)?
+ .t()?
+ .reshape((batch, height * weight, inner_dim))?;
+ (inner_dim, p.forward(&xs)?)
+ }
+ };
+ let mut xs = xs;
+ for block in self.transformer_blocks.iter() {
+ xs = block.forward(&xs, context)?
+ }
+ let xs = match &self.proj_out {
+ Proj::Conv2d(p) => p.forward(
+ &xs.reshape((batch, height, weight, inner_dim))?
+ .t()?
+ .transpose(1, 2)?,
+ )?,
+ Proj::Linear(p) => p
+ .forward(&xs)?
+ .reshape((batch, height, weight, inner_dim))?
+ .t()?
+ .transpose(1, 2)?,
+ };
+ xs + residual
+ }
+}
+
+/// Configuration for an attention block.
+#[derive(Debug, Clone, Copy)]
+pub struct AttentionBlockConfig {
+ pub num_head_channels: Option<usize>,
+ pub num_groups: usize,
+ pub rescale_output_factor: f64,
+ pub eps: f64,
+}
+
+impl Default for AttentionBlockConfig {
+ fn default() -> Self {
+ Self {
+ num_head_channels: None,
+ num_groups: 32,
+ rescale_output_factor: 1.,
+ eps: 1e-5,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct AttentionBlock {
+ group_norm: nn::GroupNorm,
+ query: nn::Linear,
+ key: nn::Linear,
+ value: nn::Linear,
+ proj_attn: nn::Linear,
+ channels: usize,
+ num_heads: usize,
+ config: AttentionBlockConfig,
+}
+
+impl AttentionBlock {
+ pub fn new(vs: nn::VarBuilder, channels: usize, config: AttentionBlockConfig) -> Result<Self> {
+ let num_head_channels = config.num_head_channels.unwrap_or(channels);
+ let num_heads = channels / num_head_channels;
+ let group_norm =
+ nn::group_norm(config.num_groups, channels, config.eps, vs.pp("group_norm"))?;
+ let query = nn::linear(channels, channels, vs.pp("query"))?;
+ let key = nn::linear(channels, channels, vs.pp("key"))?;
+ let value = nn::linear(channels, channels, vs.pp("value"))?;
+ let proj_attn = nn::linear(channels, channels, vs.pp("proj_attn"))?;
+ Ok(Self {
+ group_norm,
+ query,
+ key,
+ value,
+ proj_attn,
+ channels,
+ num_heads,
+ config,
+ })
+ }
+
+ fn transpose_for_scores(&self, xs: Tensor) -> Result<Tensor> {
+ let (batch, t, h_times_d) = xs.dims3()?;
+ xs.reshape((batch, t, self.num_heads, h_times_d / self.num_heads))?
+ .transpose(1, 2)
+ }
+}
+
+impl AttentionBlock {
+ pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let residual = xs;
+ let (batch, channel, height, width) = xs.dims4()?;
+ let xs = self
+ .group_norm
+ .forward(xs)?
+ .reshape((batch, channel, height * width))?
+ .transpose(1, 2)?;
+
+ let query_proj = self.query.forward(&xs)?;
+ let key_proj = self.key.forward(&xs)?;
+ let value_proj = self.value.forward(&xs)?;
+
+ let query_states = self.transpose_for_scores(query_proj)?;
+ let key_states = self.transpose_for_scores(key_proj)?;
+ let value_states = self.transpose_for_scores(value_proj)?;
+
+ let scale = f64::powf((self.channels as f64) / (self.num_heads as f64), -0.25);
+ let attention_scores =
+ // TODO: Check that this needs two multiplication by `scale`.
+ (query_states * scale)?.matmul(&(key_states.t()? * scale)?)?;
+ let attention_probs = nn::ops::softmax(&attention_scores, D::Minus1)?;
+
+ let xs = attention_probs.matmul(&value_states)?;
+ let xs = xs.transpose(1, 2)?.contiguous()?;
+ let xs = xs.flatten_from(D::Minus2)?;
+ let xs = self
+ .proj_attn
+ .forward(&xs)?
+ .t()?
+ .reshape((batch, channel, height, width))?;
+ (xs + residual)? / self.config.rescale_output_factor
+ }
+}
diff --git a/candle-examples/examples/stable-diffusion/clip.rs b/candle-examples/examples/stable-diffusion/clip.rs
new file mode 100644
index 00000000..86313c52
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/clip.rs
@@ -0,0 +1,304 @@
+#![allow(dead_code)]
+//! Contrastive Language-Image Pre-Training
+//!
+//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
+//! pairs of images with related texts.
+//!
+//! https://github.com/openai/CLIP
+use candle::{Device, Result, Tensor, D};
+
+#[derive(Debug, Clone, Copy)]
+pub enum Activation {
+ QuickGelu,
+ Gelu,
+}
+
+impl Activation {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ match self {
+ Activation::QuickGelu => xs * crate::utils::sigmoid(&(xs * 1.702f64)?)?,
+ Activation::Gelu => xs.gelu(),
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct Config {
+ vocab_size: usize,
+ embed_dim: usize, // aka config.hidden_size
+ activation: Activation, // aka config.hidden_act
+ intermediate_size: usize,
+ max_position_embeddings: usize,
+ // The character to use for padding, use EOS when not set.
+ pad_with: Option<String>,
+ num_hidden_layers: usize,
+ num_attention_heads: usize,
+ #[allow(dead_code)]
+ projection_dim: usize,
+}
+
+impl Config {
+ // The config details can be found in the "text_config" section of this json file:
+ // https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
+ pub fn v1_5() -> Self {
+ Self {
+ vocab_size: 49408,
+ embed_dim: 768,
+ intermediate_size: 3072,
+ max_position_embeddings: 77,
+ pad_with: None,
+ num_hidden_layers: 12,
+ num_attention_heads: 12,
+ projection_dim: 768,
+ activation: Activation::QuickGelu,
+ }
+ }
+
+ // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/text_encoder/config.json
+ pub fn v2_1() -> Self {
+ Self {
+ vocab_size: 49408,
+ embed_dim: 1024,
+ intermediate_size: 4096,
+ max_position_embeddings: 77,
+ pad_with: Some("!".to_string()),
+ num_hidden_layers: 23,
+ num_attention_heads: 16,
+ projection_dim: 512,
+ activation: Activation::Gelu,
+ }
+ }
+}
+
+// CLIP Text Model
+// https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py
+#[derive(Debug)]
+struct ClipTextEmbeddings {
+ token_embedding: candle_nn::Embedding,
+ position_embedding: candle_nn::Embedding,
+ position_ids: Tensor,
+}
+
+impl ClipTextEmbeddings {
+ fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
+ let token_embedding =
+ candle_nn::embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?;
+ let position_embedding = candle_nn::embedding(
+ c.max_position_embeddings,
+ c.embed_dim,
+ vs.pp("position_embedding"),
+ )?;
+ let position_ids =
+ Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(1)?;
+ Ok(ClipTextEmbeddings {
+ token_embedding,
+ position_embedding,
+ position_ids,
+ })
+ }
+}
+
+impl ClipTextEmbeddings {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let token_embedding = self.token_embedding.forward(xs)?;
+ let position_embedding = self.position_embedding.forward(&self.position_ids)?;
+ token_embedding + position_embedding
+ }
+}
+
+#[derive(Debug)]
+struct ClipAttention {
+ k_proj: candle_nn::Linear,
+ v_proj: candle_nn::Linear,
+ q_proj: candle_nn::Linear,
+ out_proj: candle_nn::Linear,
+ head_dim: usize,
+ scale: f64,
+ num_attention_heads: usize,
+}
+
+impl ClipAttention {
+ fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
+ let embed_dim = c.embed_dim;
+ let num_attention_heads = c.num_attention_heads;
+ let k_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("k_proj"))?;
+ let v_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("v_proj"))?;
+ let q_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("q_proj"))?;
+ let out_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?;
+ let head_dim = embed_dim / num_attention_heads;
+ let scale = (head_dim as f64).powf(-0.5);
+ Ok(ClipAttention {
+ k_proj,
+ v_proj,
+ q_proj,
+ out_proj,
+ head_dim,
+ scale,
+ num_attention_heads,
+ })
+ }
+
+ fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result<Tensor> {
+ xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
+ .transpose(1, 2)?
+ .contiguous()
+ }
+
+ fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
+ let (bsz, seq_len, embed_dim) = xs.dims3()?;
+ let query_states = (self.q_proj.forward(xs)? * self.scale)?;
+ let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
+ let query_states = self
+ .shape(&query_states, seq_len, bsz)?
+ .reshape(proj_shape)?;
+ let key_states = self
+ .shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
+ .reshape(proj_shape)?;
+ let value_states = self
+ .shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
+ .reshape(proj_shape)?;
+ let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
+
+ let src_len = key_states.dim(1)?;
+ let attn_weights =
+ (attn_weights.reshape((bsz, self.num_attention_heads, seq_len, src_len))?
+ + causal_attention_mask)?;
+ let attn_weights =
+ attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
+ let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
+
+ let attn_output = attn_weights.matmul(&value_states)?;
+ let attn_output = attn_output
+ .reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
+ .transpose(1, 2)?
+ .reshape((bsz, seq_len, embed_dim))?;
+ self.out_proj.forward(&attn_output)
+ }
+}
+
+#[derive(Debug)]
+struct ClipMlp {
+ fc1: candle_nn::Linear,
+ fc2: candle_nn::Linear,
+ activation: Activation,
+}
+
+impl ClipMlp {
+ fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
+ let fc1 = candle_nn::linear(c.embed_dim, c.intermediate_size, vs.pp("fc1"))?;
+ let fc2 = candle_nn::linear(c.intermediate_size, c.embed_dim, vs.pp("fc2"))?;
+ Ok(ClipMlp {
+ fc1,
+ fc2,
+ activation: c.activation,
+ })
+ }
+}
+
+impl ClipMlp {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = self.fc1.forward(xs)?;
+ self.fc2.forward(&self.activation.forward(&xs)?)
+ }
+}
+
+#[derive(Debug)]
+struct ClipEncoderLayer {
+ self_attn: ClipAttention,
+ layer_norm1: candle_nn::LayerNorm,
+ mlp: ClipMlp,
+ layer_norm2: candle_nn::LayerNorm,
+}
+
+impl ClipEncoderLayer {
+ fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
+ let self_attn = ClipAttention::new(vs.pp("self_attn"), c)?;
+ let layer_norm1 = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("layer_norm1"))?;
+ let mlp = ClipMlp::new(vs.pp("mlp"), c)?;
+ let layer_norm2 = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("layer_norm2"))?;
+ Ok(ClipEncoderLayer {
+ self_attn,
+ layer_norm1,
+ mlp,
+ layer_norm2,
+ })
+ }
+
+ fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
+ let residual = xs;
+ let xs = self.layer_norm1.forward(xs)?;
+ let xs = self.self_attn.forward(&xs, causal_attention_mask)?;
+ let xs = (xs + residual)?;
+
+ let residual = &xs;
+ let xs = self.layer_norm2.forward(&xs)?;
+ let xs = self.mlp.forward(&xs)?;
+ xs + residual
+ }
+}
+
+#[derive(Debug)]
+struct ClipEncoder {
+ layers: Vec<ClipEncoderLayer>,
+}
+
+impl ClipEncoder {
+ fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
+ let vs = vs.pp("layers");
+ let mut layers: Vec<ClipEncoderLayer> = Vec::new();
+ for index in 0..c.num_hidden_layers {
+ let layer = ClipEncoderLayer::new(vs.pp(&index.to_string()), c)?;
+ layers.push(layer)
+ }
+ Ok(ClipEncoder { layers })
+ }
+
+ fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
+ let mut xs = xs.clone();
+ for layer in self.layers.iter() {
+ xs = layer.forward(&xs, causal_attention_mask)?
+ }
+ Ok(xs)
+ }
+}
+
+/// A CLIP transformer based model.
+#[derive(Debug)]
+pub struct ClipTextTransformer {
+ embeddings: ClipTextEmbeddings,
+ encoder: ClipEncoder,
+ final_layer_norm: candle_nn::LayerNorm,
+}
+
+impl ClipTextTransformer {
+ pub fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
+ let vs = vs.pp("text_model");
+ let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?;
+ let encoder = ClipEncoder::new(vs.pp("encoder"), c)?;
+ let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("final_layer_norm"))?;
+ Ok(ClipTextTransformer {
+ embeddings,
+ encoder,
+ final_layer_norm,
+ })
+ }
+
+ // https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678
+ fn build_causal_attention_mask(bsz: usize, seq_len: usize, device: &Device) -> Result<Tensor> {
+ let mask: Vec<_> = (0..seq_len)
+ .flat_map(|i| (0..seq_len).map(move |j| u8::from(j > i)))
+ .collect();
+ let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
+ mask.broadcast_as((bsz, seq_len, seq_len))
+ }
+}
+
+impl ClipTextTransformer {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let (bsz, seq_len) = xs.dims2()?;
+ let xs = self.embeddings.forward(xs)?;
+ let causal_attention_mask = Self::build_causal_attention_mask(bsz, seq_len, xs.device())?;
+ let xs = self.encoder.forward(&xs, &causal_attention_mask)?;
+ self.final_layer_norm.forward(&xs)
+ }
+}
diff --git a/candle-examples/examples/stable-diffusion/embeddings.rs b/candle-examples/examples/stable-diffusion/embeddings.rs
new file mode 100644
index 00000000..f8a4f351
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/embeddings.rs
@@ -0,0 +1,65 @@
+#![allow(dead_code)]
+use candle::{Result, Tensor, D};
+use candle_nn as nn;
+
+#[derive(Debug)]
+pub struct TimestepEmbedding {
+ linear_1: nn::Linear,
+ linear_2: nn::Linear,
+}
+
+impl TimestepEmbedding {
+ // act_fn: "silu"
+ pub fn new(vs: nn::VarBuilder, channel: usize, time_embed_dim: usize) -> Result<Self> {
+ let linear_1 = nn::linear(channel, time_embed_dim, vs.pp("linear_1"))?;
+ let linear_2 = nn::linear(time_embed_dim, time_embed_dim, vs.pp("linear_2"))?;
+ Ok(Self { linear_1, linear_2 })
+ }
+}
+
+impl TimestepEmbedding {
+ pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = nn::ops::silu(&self.linear_1.forward(xs)?)?;
+ self.linear_2.forward(&xs)
+ }
+}
+
+#[derive(Debug)]
+pub struct Timesteps {
+ num_channels: usize,
+ flip_sin_to_cos: bool,
+ downscale_freq_shift: f64,
+}
+
+impl Timesteps {
+ pub fn new(num_channels: usize, flip_sin_to_cos: bool, downscale_freq_shift: f64) -> Self {
+ Self {
+ num_channels,
+ flip_sin_to_cos,
+ downscale_freq_shift,
+ }
+ }
+}
+
+impl Timesteps {
+ pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let half_dim = (self.num_channels / 2) as u32;
+ let exponent =
+ (Tensor::arange(0, half_dim, xs.device())?.to_dtype(xs.dtype())? * -f64::ln(10000.))?;
+ let exponent = (exponent / (half_dim as f64 - self.downscale_freq_shift))?;
+ let emb = exponent.exp()?;
+ // emb = timesteps[:, None].float() * emb[None, :]
+ let emb = (xs.unsqueeze(D::Minus1)? * emb.unsqueeze(0)?)?;
+ let (cos, sin) = (emb.cos()?, emb.sin()?);
+ let emb = if self.flip_sin_to_cos {
+ Tensor::cat(&[&cos, &sin], D::Minus1)?
+ } else {
+ Tensor::cat(&[&sin, &cos], D::Minus1)?
+ };
+ if self.num_channels % 2 == 1 {
+ crate::utils::pad(&emb) // ([0, 1, 0, 0], 'constant', None)
+ } else {
+ Ok(emb)
+ }
+ }
+}
diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs
new file mode 100644
index 00000000..31848f38
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/main.rs
@@ -0,0 +1,30 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+mod attention;
+mod clip;
+mod embeddings;
+mod resnet;
+mod unet_2d;
+mod unet_2d_blocks;
+mod utils;
+mod vae;
+
+use anyhow::Result;
+use clap::Parser;
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ /// Run on CPU rather than on GPU.
+ #[arg(long)]
+ cpu: bool,
+
+ #[arg(long)]
+ prompt: String,
+}
+
+fn main() -> Result<()> {
+ let _args = Args::parse();
+ Ok(())
+}
diff --git a/candle-examples/examples/stable-diffusion/resnet.rs b/candle-examples/examples/stable-diffusion/resnet.rs
new file mode 100644
index 00000000..b6696083
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/resnet.rs
@@ -0,0 +1,129 @@
+#![allow(dead_code)]
+//! ResNet Building Blocks
+//!
+//! Some Residual Network blocks used in UNet models.
+//!
+//! Denoising Diffusion Implicit Models, K. He and al, 2015.
+//! https://arxiv.org/abs/1512.03385
+use candle::{Result, Tensor, D};
+use candle_nn as nn;
+
+/// Configuration for a ResNet block.
+#[derive(Debug, Clone, Copy)]
+pub struct ResnetBlock2DConfig {
+ /// The number of output channels, defaults to the number of input channels.
+ pub out_channels: Option<usize>,
+ pub temb_channels: Option<usize>,
+ /// The number of groups to use in group normalization.
+ pub groups: usize,
+ pub groups_out: Option<usize>,
+ /// The epsilon to be used in the group normalization operations.
+ pub eps: f64,
+ /// Whether to use a 2D convolution in the skip connection. When using None,
+ /// such a convolution is used if the number of input channels is different from
+ /// the number of output channels.
+ pub use_in_shortcut: Option<bool>,
+ // non_linearity: silu
+ /// The final output is scaled by dividing by this value.
+ pub output_scale_factor: f64,
+}
+
+impl Default for ResnetBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ out_channels: None,
+ temb_channels: Some(512),
+ groups: 32,
+ groups_out: None,
+ eps: 1e-6,
+ use_in_shortcut: None,
+ output_scale_factor: 1.,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct ResnetBlock2D {
+ norm1: nn::GroupNorm,
+ conv1: nn::Conv2d,
+ norm2: nn::GroupNorm,
+ conv2: nn::Conv2d,
+ time_emb_proj: Option<nn::Linear>,
+ conv_shortcut: Option<nn::Conv2d>,
+ config: ResnetBlock2DConfig,
+}
+
+impl ResnetBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ config: ResnetBlock2DConfig,
+ ) -> Result<Self> {
+ let out_channels = config.out_channels.unwrap_or(in_channels);
+ let conv_cfg = nn::Conv2dConfig {
+ stride: 1,
+ padding: 1,
+ };
+ let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
+ let conv1 = nn::conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
+ let groups_out = config.groups_out.unwrap_or(config.groups);
+ let norm2 = nn::group_norm(groups_out, out_channels, config.eps, vs.pp("norm2"))?;
+ let conv2 = nn::conv2d(out_channels, out_channels, 3, conv_cfg, vs.pp("conv2"))?;
+ let use_in_shortcut = config
+ .use_in_shortcut
+ .unwrap_or(in_channels != out_channels);
+ let conv_shortcut = if use_in_shortcut {
+ let conv_cfg = nn::Conv2dConfig {
+ stride: 1,
+ padding: 0,
+ };
+ Some(nn::conv2d(
+ in_channels,
+ out_channels,
+ 1,
+ conv_cfg,
+ vs.pp("conv_shortcut"),
+ )?)
+ } else {
+ None
+ };
+ let time_emb_proj = match config.temb_channels {
+ None => None,
+ Some(temb_channels) => Some(nn::linear(
+ temb_channels,
+ out_channels,
+ vs.pp("time_emb_proj"),
+ )?),
+ };
+ Ok(Self {
+ norm1,
+ conv1,
+ norm2,
+ conv2,
+ time_emb_proj,
+ config,
+ conv_shortcut,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
+ let shortcut_xs = match &self.conv_shortcut {
+ Some(conv_shortcut) => conv_shortcut.forward(xs)?,
+ None => xs.clone(),
+ };
+ let xs = self.norm1.forward(xs)?;
+ let xs = self.conv1.forward(&nn::ops::silu(&xs)?)?;
+ let xs = match (temb, &self.time_emb_proj) {
+ (Some(temb), Some(time_emb_proj)) => time_emb_proj
+ .forward(&nn::ops::silu(temb)?)?
+ .unsqueeze(D::Minus1)?
+ .unsqueeze(D::Minus1)?
+ .add(&xs)?,
+ _ => xs,
+ };
+ let xs = self
+ .conv2
+ .forward(&nn::ops::silu(&self.norm2.forward(&xs)?)?)?;
+ (shortcut_xs + xs)? / self.config.output_scale_factor
+ }
+}
diff --git a/candle-examples/examples/stable-diffusion/unet_2d.rs b/candle-examples/examples/stable-diffusion/unet_2d.rs
new file mode 100644
index 00000000..8ebd1876
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/unet_2d.rs
@@ -0,0 +1,383 @@
+#![allow(dead_code)]
+//! 2D UNet Denoising Models
+//!
+//! The 2D Unet models take as input a noisy sample and the current diffusion
+//! timestep and return a denoised version of the input.
+use crate::embeddings::{TimestepEmbedding, Timesteps};
+use crate::unet_2d_blocks::*;
+use candle::{DType, Result, Tensor};
+use candle_nn as nn;
+
+#[derive(Debug, Clone, Copy)]
+pub struct BlockConfig {
+ pub out_channels: usize,
+ pub use_cross_attn: bool,
+ pub attention_head_dim: usize,
+}
+
+#[derive(Debug, Clone)]
+pub struct UNet2DConditionModelConfig {
+ pub center_input_sample: bool,
+ pub flip_sin_to_cos: bool,
+ pub freq_shift: f64,
+ pub blocks: Vec<BlockConfig>,
+ pub layers_per_block: usize,
+ pub downsample_padding: usize,
+ pub mid_block_scale_factor: f64,
+ pub norm_num_groups: usize,
+ pub norm_eps: f64,
+ pub cross_attention_dim: usize,
+ pub sliced_attention_size: Option<usize>,
+ pub use_linear_projection: bool,
+}
+
+impl Default for UNet2DConditionModelConfig {
+ fn default() -> Self {
+ Self {
+ center_input_sample: false,
+ flip_sin_to_cos: true,
+ freq_shift: 0.,
+ blocks: vec![
+ BlockConfig {
+ out_channels: 320,
+ use_cross_attn: true,
+ attention_head_dim: 8,
+ },
+ BlockConfig {
+ out_channels: 640,
+ use_cross_attn: true,
+ attention_head_dim: 8,
+ },
+ BlockConfig {
+ out_channels: 1280,
+ use_cross_attn: true,
+ attention_head_dim: 8,
+ },
+ BlockConfig {
+ out_channels: 1280,
+ use_cross_attn: false,
+ attention_head_dim: 8,
+ },
+ ],
+ layers_per_block: 2,
+ downsample_padding: 1,
+ mid_block_scale_factor: 1.,
+ norm_num_groups: 32,
+ norm_eps: 1e-5,
+ cross_attention_dim: 1280,
+ sliced_attention_size: None,
+ use_linear_projection: false,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub(crate) enum UNetDownBlock {
+ Basic(DownBlock2D),
+ CrossAttn(CrossAttnDownBlock2D),
+}
+
+#[derive(Debug)]
+enum UNetUpBlock {
+ Basic(UpBlock2D),
+ CrossAttn(CrossAttnUpBlock2D),
+}
+
+#[derive(Debug)]
+pub struct UNet2DConditionModel {
+ conv_in: nn::Conv2d,
+ time_proj: Timesteps,
+ time_embedding: TimestepEmbedding,
+ down_blocks: Vec<UNetDownBlock>,
+ mid_block: UNetMidBlock2DCrossAttn,
+ up_blocks: Vec<UNetUpBlock>,
+ conv_norm_out: nn::GroupNorm,
+ conv_out: nn::Conv2d,
+ config: UNet2DConditionModelConfig,
+}
+
+impl UNet2DConditionModel {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ config: UNet2DConditionModelConfig,
+ ) -> Result<Self> {
+ let n_blocks = config.blocks.len();
+ let b_channels = config.blocks[0].out_channels;
+ let bl_channels = config.blocks.last().unwrap().out_channels;
+ let bl_attention_head_dim = config.blocks.last().unwrap().attention_head_dim;
+ let time_embed_dim = b_channels * 4;
+ let conv_cfg = nn::Conv2dConfig {
+ stride: 1,
+ padding: 1,
+ };
+ let conv_in = nn::conv2d(in_channels, b_channels, 3, conv_cfg, vs.pp("conv_in"))?;
+
+ let time_proj = Timesteps::new(b_channels, config.flip_sin_to_cos, config.freq_shift);
+ let time_embedding =
+ TimestepEmbedding::new(vs.pp("time_embedding"), b_channels, time_embed_dim)?;
+
+ let vs_db = vs.pp("down_blocks");
+ let down_blocks = (0..n_blocks)
+ .map(|i| {
+ let BlockConfig {
+ out_channels,
+ use_cross_attn,
+ attention_head_dim,
+ } = config.blocks[i];
+
+ // Enable automatic attention slicing if the config sliced_attention_size is set to 0.
+ let sliced_attention_size = match config.sliced_attention_size {
+ Some(0) => Some(attention_head_dim / 2),
+ _ => config.sliced_attention_size,
+ };
+
+ let in_channels = if i > 0 {
+ config.blocks[i - 1].out_channels
+ } else {
+ b_channels
+ };
+ let db_cfg = DownBlock2DConfig {
+ num_layers: config.layers_per_block,
+ resnet_eps: config.norm_eps,
+ resnet_groups: config.norm_num_groups,
+ add_downsample: i < n_blocks - 1,
+ downsample_padding: config.downsample_padding,
+ ..Default::default()
+ };
+ if use_cross_attn {
+ let config = CrossAttnDownBlock2DConfig {
+ downblock: db_cfg,
+ attn_num_head_channels: attention_head_dim,
+ cross_attention_dim: config.cross_attention_dim,
+ sliced_attention_size,
+ use_linear_projection: config.use_linear_projection,
+ };
+ let block = CrossAttnDownBlock2D::new(
+ vs_db.pp(&i.to_string()),
+ in_channels,
+ out_channels,
+ Some(time_embed_dim),
+ config,
+ )?;
+ Ok(UNetDownBlock::CrossAttn(block))
+ } else {
+ let block = DownBlock2D::new(
+ vs_db.pp(&i.to_string()),
+ in_channels,
+ out_channels,
+ Some(time_embed_dim),
+ db_cfg,
+ )?;
+ Ok(UNetDownBlock::Basic(block))
+ }
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ let mid_cfg = UNetMidBlock2DCrossAttnConfig {
+ resnet_eps: config.norm_eps,
+ output_scale_factor: config.mid_block_scale_factor,
+ cross_attn_dim: config.cross_attention_dim,
+ attn_num_head_channels: bl_attention_head_dim,
+ resnet_groups: Some(config.norm_num_groups),
+ use_linear_projection: config.use_linear_projection,
+ ..Default::default()
+ };
+ let mid_block = UNetMidBlock2DCrossAttn::new(
+ vs.pp("mid_block"),
+ bl_channels,
+ Some(time_embed_dim),
+ mid_cfg,
+ )?;
+
+ let vs_ub = vs.pp("up_blocks");
+ let up_blocks = (0..n_blocks)
+ .map(|i| {
+ let BlockConfig {
+ out_channels,
+ use_cross_attn,
+ attention_head_dim,
+ } = config.blocks[n_blocks - 1 - i];
+
+ // Enable automatic attention slicing if the config sliced_attention_size is set to 0.
+ let sliced_attention_size = match config.sliced_attention_size {
+ Some(0) => Some(attention_head_dim / 2),
+ _ => config.sliced_attention_size,
+ };
+
+ let prev_out_channels = if i > 0 {
+ config.blocks[n_blocks - i].out_channels
+ } else {
+ bl_channels
+ };
+ let in_channels = {
+ let index = if i == n_blocks - 1 {
+ 0
+ } else {
+ n_blocks - i - 2
+ };
+ config.blocks[index].out_channels
+ };
+ let ub_cfg = UpBlock2DConfig {
+ num_layers: config.layers_per_block + 1,
+ resnet_eps: config.norm_eps,
+ resnet_groups: config.norm_num_groups,
+ add_upsample: i < n_blocks - 1,
+ ..Default::default()
+ };
+ if use_cross_attn {
+ let config = CrossAttnUpBlock2DConfig {
+ upblock: ub_cfg,
+ attn_num_head_channels: attention_head_dim,
+ cross_attention_dim: config.cross_attention_dim,
+ sliced_attention_size,
+ use_linear_projection: config.use_linear_projection,
+ };
+ let block = CrossAttnUpBlock2D::new(
+ vs_ub.pp(&i.to_string()),
+ in_channels,
+ prev_out_channels,
+ out_channels,
+ Some(time_embed_dim),
+ config,
+ )?;
+ Ok(UNetUpBlock::CrossAttn(block))
+ } else {
+ let block = UpBlock2D::new(
+ vs_ub.pp(&i.to_string()),
+ in_channels,
+ prev_out_channels,
+ out_channels,
+ Some(time_embed_dim),
+ ub_cfg,
+ )?;
+ Ok(UNetUpBlock::Basic(block))
+ }
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ let conv_norm_out = nn::group_norm(
+ config.norm_num_groups,
+ b_channels,
+ config.norm_eps,
+ vs.pp("conv_norm_out"),
+ )?;
+ let conv_out = nn::conv2d(b_channels, out_channels, 3, conv_cfg, vs.pp("conv_out"))?;
+ Ok(Self {
+ conv_in,
+ time_proj,
+ time_embedding,
+ down_blocks,
+ mid_block,
+ up_blocks,
+ conv_norm_out,
+ conv_out,
+ config,
+ })
+ }
+}
+
+impl UNet2DConditionModel {
+ pub fn forward(
+ &self,
+ xs: &Tensor,
+ timestep: f64,
+ encoder_hidden_states: &Tensor,
+ ) -> Result<Tensor> {
+ self.forward_with_additional_residuals(xs, timestep, encoder_hidden_states, None, None)
+ }
+
+ pub fn forward_with_additional_residuals(
+ &self,
+ xs: &Tensor,
+ timestep: f64,
+ encoder_hidden_states: &Tensor,
+ down_block_additional_residuals: Option<&[Tensor]>,
+ mid_block_additional_residual: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let (bsize, _channels, height, width) = xs.dims4()?;
+ let device = xs.device();
+ let n_blocks = self.config.blocks.len();
+ let num_upsamplers = n_blocks - 1;
+ let default_overall_up_factor = 2usize.pow(num_upsamplers as u32);
+ let forward_upsample_size =
+ height % default_overall_up_factor != 0 || width % default_overall_up_factor != 0;
+ // 0. center input if necessary
+ let xs = if self.config.center_input_sample {
+ ((xs * 2.0)? - 1.0)?
+ } else {
+ xs.clone()
+ };
+ // 1. time
+ let emb = (Tensor::ones(bsize, DType::F32, device)? * timestep)?;
+ let emb = self.time_proj.forward(&emb)?;
+ let emb = self.time_embedding.forward(&emb)?;
+ // 2. pre-process
+ let xs = self.conv_in.forward(&xs)?;
+ // 3. down
+ let mut down_block_res_xs = vec![xs.clone()];
+ let mut xs = xs;
+ for down_block in self.down_blocks.iter() {
+ let (_xs, res_xs) = match down_block {
+ UNetDownBlock::Basic(b) => b.forward(&xs, Some(&emb))?,
+ UNetDownBlock::CrossAttn(b) => {
+ b.forward(&xs, Some(&emb), Some(encoder_hidden_states))?
+ }
+ };
+ down_block_res_xs.extend(res_xs);
+ xs = _xs;
+ }
+
+ let new_down_block_res_xs =
+ if let Some(down_block_additional_residuals) = down_block_additional_residuals {
+ let mut v = vec![];
+ // A previous version of this code had a bug because of the addition being made
+ // in place via += hence modifying the input of the mid block.
+ for (i, residuals) in down_block_additional_residuals.iter().enumerate() {
+ v.push((&down_block_res_xs[i] + residuals)?)
+ }
+ v
+ } else {
+ down_block_res_xs
+ };
+ let mut down_block_res_xs = new_down_block_res_xs;
+
+ // 4. mid
+ let xs = self
+ .mid_block
+ .forward(&xs, Some(&emb), Some(encoder_hidden_states))?;
+ let xs = match mid_block_additional_residual {
+ None => xs,
+ Some(m) => (m + xs)?,
+ };
+ // 5. up
+ let mut xs = xs;
+ let mut upsample_size = None;
+ for (i, up_block) in self.up_blocks.iter().enumerate() {
+ let n_resnets = match up_block {
+ UNetUpBlock::Basic(b) => b.resnets.len(),
+ UNetUpBlock::CrossAttn(b) => b.upblock.resnets.len(),
+ };
+ let res_xs = down_block_res_xs.split_off(down_block_res_xs.len() - n_resnets);
+ if i < n_blocks - 1 && forward_upsample_size {
+ let (_, _, h, w) = down_block_res_xs.last().unwrap().dims4()?;
+ upsample_size = Some((h, w))
+ }
+ xs = match up_block {
+ UNetUpBlock::Basic(b) => b.forward(&xs, &res_xs, Some(&emb), upsample_size)?,
+ UNetUpBlock::CrossAttn(b) => b.forward(
+ &xs,
+ &res_xs,
+ Some(&emb),
+ upsample_size,
+ Some(encoder_hidden_states),
+ )?,
+ };
+ }
+ // 6. post-process
+ let xs = self.conv_norm_out.forward(&xs)?;
+ let xs = nn::ops::silu(&xs)?;
+ self.conv_out.forward(&xs)
+ }
+}
diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
new file mode 100644
index 00000000..8dd6cf26
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
@@ -0,0 +1,809 @@
+#![allow(dead_code)]
+//! 2D UNet Building Blocks
+//!
+use crate::attention::{
+ AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
+};
+use crate::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
+use candle::{Result, Tensor};
+use candle_nn as nn;
+
+#[derive(Debug)]
+struct Downsample2D {
+ conv: Option<nn::Conv2d>,
+ padding: usize,
+}
+
+impl Downsample2D {
+ fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ use_conv: bool,
+ out_channels: usize,
+ padding: usize,
+ ) -> Result<Self> {
+ let conv = if use_conv {
+ let config = nn::Conv2dConfig { stride: 2, padding };
+ let conv = nn::conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
+ Some(conv)
+ } else {
+ None
+ };
+ Ok(Downsample2D { conv, padding })
+ }
+}
+
+impl Downsample2D {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ match &self.conv {
+ None => crate::utils::avg_pool2d(xs), // [2, 2], [2, 2], [0, 0], false, true, None),
+ Some(conv) => {
+ if self.padding == 0 {
+ let xs = crate::utils::pad(xs)?; // [0, 1, 0, 1], "constant", Some(0.))?;
+ conv.forward(&xs)
+ } else {
+ conv.forward(xs)
+ }
+ }
+ }
+ }
+}
+
+// This does not support the conv-transpose mode.
+#[derive(Debug)]
+struct Upsample2D {
+ conv: nn::Conv2d,
+}
+
+impl Upsample2D {
+ fn new(vs: nn::VarBuilder, in_channels: usize, out_channels: usize) -> Result<Self> {
+ let config = nn::Conv2dConfig {
+ padding: 1,
+ ..Default::default()
+ };
+ let conv = nn::conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
+ Ok(Self { conv })
+ }
+}
+
+impl Upsample2D {
+ fn forward(&self, xs: &Tensor, size: Option<(usize, usize)>) -> Result<Tensor> {
+ let xs = match size {
+ None => {
+ // The following does not work and it's tricky to pass no fixed
+ // dimensions so hack our way around this.
+ // xs.upsample_nearest2d(&[], Some(2.), Some(2.)
+ let (_bsize, _channels, _h, _w) = xs.dims4()?;
+ crate::utils::upsample_nearest2d(xs)? // [2 * h, 2 * w], Some(2.), Some(2.))
+ }
+ Some((_h, _w)) => crate::utils::upsample_nearest2d(xs)?, // [h, w], None, None),
+ };
+ self.conv.forward(&xs)
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct DownEncoderBlock2DConfig {
+ pub num_layers: usize,
+ pub resnet_eps: f64,
+ pub resnet_groups: usize,
+ pub output_scale_factor: f64,
+ pub add_downsample: bool,
+ pub downsample_padding: usize,
+}
+
+impl Default for DownEncoderBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ num_layers: 1,
+ resnet_eps: 1e-6,
+ resnet_groups: 32,
+ output_scale_factor: 1.,
+ add_downsample: true,
+ downsample_padding: 1,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct DownEncoderBlock2D {
+ resnets: Vec<ResnetBlock2D>,
+ downsampler: Option<Downsample2D>,
+ pub config: DownEncoderBlock2DConfig,
+}
+
+impl DownEncoderBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ config: DownEncoderBlock2DConfig,
+ ) -> Result<Self> {
+ let resnets: Vec<_> = {
+ let vs = vs.pp("resnets");
+ let conv_cfg = ResnetBlock2DConfig {
+ eps: config.resnet_eps,
+ out_channels: Some(out_channels),
+ groups: config.resnet_groups,
+ output_scale_factor: config.output_scale_factor,
+ temb_channels: None,
+ ..Default::default()
+ };
+ (0..(config.num_layers))
+ .map(|i| {
+ let in_channels = if i == 0 { in_channels } else { out_channels };
+ ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg)
+ })
+ .collect::<Result<Vec<_>>>()?
+ };
+ let downsampler = if config.add_downsample {
+ let downsample = Downsample2D::new(
+ vs.pp("downsamplers").pp("0"),
+ out_channels,
+ true,
+ out_channels,
+ config.downsample_padding,
+ )?;
+ Some(downsample)
+ } else {
+ None
+ };
+ Ok(Self {
+ resnets,
+ downsampler,
+ config,
+ })
+ }
+}
+
+impl DownEncoderBlock2D {
+ pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut xs = xs.clone();
+ for resnet in self.resnets.iter() {
+ xs = resnet.forward(&xs, None)?
+ }
+ match &self.downsampler {
+ Some(downsampler) => downsampler.forward(&xs),
+ None => Ok(xs),
+ }
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct UpDecoderBlock2DConfig {
+ pub num_layers: usize,
+ pub resnet_eps: f64,
+ pub resnet_groups: usize,
+ pub output_scale_factor: f64,
+ pub add_upsample: bool,
+}
+
+impl Default for UpDecoderBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ num_layers: 1,
+ resnet_eps: 1e-6,
+ resnet_groups: 32,
+ output_scale_factor: 1.,
+ add_upsample: true,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct UpDecoderBlock2D {
+ resnets: Vec<ResnetBlock2D>,
+ upsampler: Option<Upsample2D>,
+ pub config: UpDecoderBlock2DConfig,
+}
+
+impl UpDecoderBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ config: UpDecoderBlock2DConfig,
+ ) -> Result<Self> {
+ let resnets: Vec<_> = {
+ let vs = vs.pp("resnets");
+ let conv_cfg = ResnetBlock2DConfig {
+ out_channels: Some(out_channels),
+ eps: config.resnet_eps,
+ groups: config.resnet_groups,
+ output_scale_factor: config.output_scale_factor,
+ temb_channels: None,
+ ..Default::default()
+ };
+ (0..(config.num_layers))
+ .map(|i| {
+ let in_channels = if i == 0 { in_channels } else { out_channels };
+ ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg)
+ })
+ .collect::<Result<Vec<_>>>()?
+ };
+ let upsampler = if config.add_upsample {
+ let upsample =
+ Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?;
+ Some(upsample)
+ } else {
+ None
+ };
+ Ok(Self {
+ resnets,
+ upsampler,
+ config,
+ })
+ }
+}
+
+impl UpDecoderBlock2D {
+ pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut xs = xs.clone();
+ for resnet in self.resnets.iter() {
+ xs = resnet.forward(&xs, None)?
+ }
+ match &self.upsampler {
+ Some(upsampler) => upsampler.forward(&xs, None),
+ None => Ok(xs),
+ }
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct UNetMidBlock2DConfig {
+ pub num_layers: usize,
+ pub resnet_eps: f64,
+ pub resnet_groups: Option<usize>,
+ pub attn_num_head_channels: Option<usize>,
+ // attention_type "default"
+ pub output_scale_factor: f64,
+}
+
+impl Default for UNetMidBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ num_layers: 1,
+ resnet_eps: 1e-6,
+ resnet_groups: Some(32),
+ attn_num_head_channels: Some(1),
+ output_scale_factor: 1.,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct UNetMidBlock2D {
+ resnet: ResnetBlock2D,
+ attn_resnets: Vec<(AttentionBlock, ResnetBlock2D)>,
+ pub config: UNetMidBlock2DConfig,
+}
+
+impl UNetMidBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ temb_channels: Option<usize>,
+ config: UNetMidBlock2DConfig,
+ ) -> Result<Self> {
+ let vs_resnets = vs.pp("resnets");
+ let vs_attns = vs.pp("attentions");
+ let resnet_groups = config
+ .resnet_groups
+ .unwrap_or_else(|| usize::min(in_channels / 4, 32));
+ let resnet_cfg = ResnetBlock2DConfig {
+ eps: config.resnet_eps,
+ groups: resnet_groups,
+ output_scale_factor: config.output_scale_factor,
+ temb_channels,
+ ..Default::default()
+ };
+ let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
+ let attn_cfg = AttentionBlockConfig {
+ num_head_channels: config.attn_num_head_channels,
+ num_groups: resnet_groups,
+ rescale_output_factor: config.output_scale_factor,
+ eps: config.resnet_eps,
+ };
+ let mut attn_resnets = vec![];
+ for index in 0..config.num_layers {
+ let attn = AttentionBlock::new(vs_attns.pp(&index.to_string()), in_channels, attn_cfg)?;
+ let resnet = ResnetBlock2D::new(
+ vs_resnets.pp(&(index + 1).to_string()),
+ in_channels,
+ resnet_cfg,
+ )?;
+ attn_resnets.push((attn, resnet))
+ }
+ Ok(Self {
+ resnet,
+ attn_resnets,
+ config,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
+ let mut xs = self.resnet.forward(xs, temb)?;
+ for (attn, resnet) in self.attn_resnets.iter() {
+ xs = resnet.forward(&attn.forward(&xs)?, temb)?
+ }
+ Ok(xs)
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct UNetMidBlock2DCrossAttnConfig {
+ pub num_layers: usize,
+ pub resnet_eps: f64,
+ pub resnet_groups: Option<usize>,
+ pub attn_num_head_channels: usize,
+ // attention_type "default"
+ pub output_scale_factor: f64,
+ pub cross_attn_dim: usize,
+ pub sliced_attention_size: Option<usize>,
+ pub use_linear_projection: bool,
+}
+
+impl Default for UNetMidBlock2DCrossAttnConfig {
+ fn default() -> Self {
+ Self {
+ num_layers: 1,
+ resnet_eps: 1e-6,
+ resnet_groups: Some(32),
+ attn_num_head_channels: 1,
+ output_scale_factor: 1.,
+ cross_attn_dim: 1280,
+ sliced_attention_size: None, // Sliced attention disabled
+ use_linear_projection: false,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct UNetMidBlock2DCrossAttn {
+ resnet: ResnetBlock2D,
+ attn_resnets: Vec<(SpatialTransformer, ResnetBlock2D)>,
+ pub config: UNetMidBlock2DCrossAttnConfig,
+}
+
+impl UNetMidBlock2DCrossAttn {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ temb_channels: Option<usize>,
+ config: UNetMidBlock2DCrossAttnConfig,
+ ) -> Result<Self> {
+ let vs_resnets = vs.pp("resnets");
+ let vs_attns = vs.pp("attentions");
+ let resnet_groups = config
+ .resnet_groups
+ .unwrap_or_else(|| usize::min(in_channels / 4, 32));
+ let resnet_cfg = ResnetBlock2DConfig {
+ eps: config.resnet_eps,
+ groups: resnet_groups,
+ output_scale_factor: config.output_scale_factor,
+ temb_channels,
+ ..Default::default()
+ };
+ let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
+ let n_heads = config.attn_num_head_channels;
+ let attn_cfg = SpatialTransformerConfig {
+ depth: 1,
+ num_groups: resnet_groups,
+ context_dim: Some(config.cross_attn_dim),
+ sliced_attention_size: config.sliced_attention_size,
+ use_linear_projection: config.use_linear_projection,
+ };
+ let mut attn_resnets = vec![];
+ for index in 0..config.num_layers {
+ let attn = SpatialTransformer::new(
+ vs_attns.pp(&index.to_string()),
+ in_channels,
+ n_heads,
+ in_channels / n_heads,
+ attn_cfg,
+ )?;
+ let resnet = ResnetBlock2D::new(
+ vs_resnets.pp(&(index + 1).to_string()),
+ in_channels,
+ resnet_cfg,
+ )?;
+ attn_resnets.push((attn, resnet))
+ }
+ Ok(Self {
+ resnet,
+ attn_resnets,
+ config,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ xs: &Tensor,
+ temb: Option<&Tensor>,
+ encoder_hidden_states: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let mut xs = self.resnet.forward(xs, temb)?;
+ for (attn, resnet) in self.attn_resnets.iter() {
+ xs = resnet.forward(&attn.forward(&xs, encoder_hidden_states)?, temb)?
+ }
+ Ok(xs)
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct DownBlock2DConfig {
+ pub num_layers: usize,
+ pub resnet_eps: f64,
+ // resnet_time_scale_shift: "default"
+ // resnet_act_fn: "swish"
+ pub resnet_groups: usize,
+ pub output_scale_factor: f64,
+ pub add_downsample: bool,
+ pub downsample_padding: usize,
+}
+
+impl Default for DownBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ num_layers: 1,
+ resnet_eps: 1e-6,
+ resnet_groups: 32,
+ output_scale_factor: 1.,
+ add_downsample: true,
+ downsample_padding: 1,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct DownBlock2D {
+ resnets: Vec<ResnetBlock2D>,
+ downsampler: Option<Downsample2D>,
+ pub config: DownBlock2DConfig,
+}
+
+impl DownBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ temb_channels: Option<usize>,
+ config: DownBlock2DConfig,
+ ) -> Result<Self> {
+ let vs_resnets = vs.pp("resnets");
+ let resnet_cfg = ResnetBlock2DConfig {
+ out_channels: Some(out_channels),
+ eps: config.resnet_eps,
+ output_scale_factor: config.output_scale_factor,
+ temb_channels,
+ ..Default::default()
+ };
+ let resnets = (0..config.num_layers)
+ .map(|i| {
+ let in_channels = if i == 0 { in_channels } else { out_channels };
+ ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
+ })
+ .collect::<Result<Vec<_>>>()?;
+ let downsampler = if config.add_downsample {
+ let downsampler = Downsample2D::new(
+ vs.pp("downsamplers").pp("0"),
+ out_channels,
+ true,
+ out_channels,
+ config.downsample_padding,
+ )?;
+ Some(downsampler)
+ } else {
+ None
+ };
+ Ok(Self {
+ resnets,
+ downsampler,
+ config,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<(Tensor, Vec<Tensor>)> {
+ let mut xs = xs.clone();
+ let mut output_states = vec![];
+ for resnet in self.resnets.iter() {
+ xs = resnet.forward(&xs, temb)?;
+ output_states.push(xs.clone());
+ }
+ let xs = match &self.downsampler {
+ Some(downsampler) => {
+ let xs = downsampler.forward(&xs)?;
+ output_states.push(xs.clone());
+ xs
+ }
+ None => xs,
+ };
+ Ok((xs, output_states))
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct CrossAttnDownBlock2DConfig {
+ pub downblock: DownBlock2DConfig,
+ pub attn_num_head_channels: usize,
+ pub cross_attention_dim: usize,
+ // attention_type: "default"
+ pub sliced_attention_size: Option<usize>,
+ pub use_linear_projection: bool,
+}
+
+impl Default for CrossAttnDownBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ downblock: Default::default(),
+ attn_num_head_channels: 1,
+ cross_attention_dim: 1280,
+ sliced_attention_size: None,
+ use_linear_projection: false,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct CrossAttnDownBlock2D {
+ downblock: DownBlock2D,
+ attentions: Vec<SpatialTransformer>,
+ pub config: CrossAttnDownBlock2DConfig,
+}
+
+impl CrossAttnDownBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ temb_channels: Option<usize>,
+ config: CrossAttnDownBlock2DConfig,
+ ) -> Result<Self> {
+ let downblock = DownBlock2D::new(
+ vs.clone(),
+ in_channels,
+ out_channels,
+ temb_channels,
+ config.downblock,
+ )?;
+ let n_heads = config.attn_num_head_channels;
+ let cfg = SpatialTransformerConfig {
+ depth: 1,
+ context_dim: Some(config.cross_attention_dim),
+ num_groups: config.downblock.resnet_groups,
+ sliced_attention_size: config.sliced_attention_size,
+ use_linear_projection: config.use_linear_projection,
+ };
+ let vs_attn = vs.pp("attentions");
+ let attentions = (0..config.downblock.num_layers)
+ .map(|i| {
+ SpatialTransformer::new(
+ vs_attn.pp(&i.to_string()),
+ out_channels,
+ n_heads,
+ out_channels / n_heads,
+ cfg,
+ )
+ })
+ .collect::<Result<Vec<_>>>()?;
+ Ok(Self {
+ downblock,
+ attentions,
+ config,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ xs: &Tensor,
+ temb: Option<&Tensor>,
+ encoder_hidden_states: Option<&Tensor>,
+ ) -> Result<(Tensor, Vec<Tensor>)> {
+ let mut output_states = vec![];
+ let mut xs = xs.clone();
+ for (resnet, attn) in self.downblock.resnets.iter().zip(self.attentions.iter()) {
+ xs = resnet.forward(&xs, temb)?;
+ xs = attn.forward(&xs, encoder_hidden_states)?;
+ output_states.push(xs.clone());
+ }
+ let xs = match &self.downblock.downsampler {
+ Some(downsampler) => {
+ let xs = downsampler.forward(&xs)?;
+ output_states.push(xs.clone());
+ xs
+ }
+ None => xs,
+ };
+ Ok((xs, output_states))
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct UpBlock2DConfig {
+ pub num_layers: usize,
+ pub resnet_eps: f64,
+ // resnet_time_scale_shift: "default"
+ // resnet_act_fn: "swish"
+ pub resnet_groups: usize,
+ pub output_scale_factor: f64,
+ pub add_upsample: bool,
+}
+
+impl Default for UpBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ num_layers: 1,
+ resnet_eps: 1e-6,
+ resnet_groups: 32,
+ output_scale_factor: 1.,
+ add_upsample: true,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct UpBlock2D {
+ pub resnets: Vec<ResnetBlock2D>,
+ upsampler: Option<Upsample2D>,
+ pub config: UpBlock2DConfig,
+}
+
+impl UpBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ prev_output_channels: usize,
+ out_channels: usize,
+ temb_channels: Option<usize>,
+ config: UpBlock2DConfig,
+ ) -> Result<Self> {
+ let vs_resnets = vs.pp("resnets");
+ let resnet_cfg = ResnetBlock2DConfig {
+ out_channels: Some(out_channels),
+ temb_channels,
+ eps: config.resnet_eps,
+ output_scale_factor: config.output_scale_factor,
+ ..Default::default()
+ };
+ let resnets = (0..config.num_layers)
+ .map(|i| {
+ let res_skip_channels = if i == config.num_layers - 1 {
+ in_channels
+ } else {
+ out_channels
+ };
+ let resnet_in_channels = if i == 0 {
+ prev_output_channels
+ } else {
+ out_channels
+ };
+ let in_channels = resnet_in_channels + res_skip_channels;
+ ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
+ })
+ .collect::<Result<Vec<_>>>()?;
+ let upsampler = if config.add_upsample {
+ let upsampler =
+ Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?;
+ Some(upsampler)
+ } else {
+ None
+ };
+ Ok(Self {
+ resnets,
+ upsampler,
+ config,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ xs: &Tensor,
+ res_xs: &[Tensor],
+ temb: Option<&Tensor>,
+ upsample_size: Option<(usize, usize)>,
+ ) -> Result<Tensor> {
+ let mut xs = xs.clone();
+ for (index, resnet) in self.resnets.iter().enumerate() {
+ xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
+ xs = resnet.forward(&xs, temb)?;
+ }
+ match &self.upsampler {
+ Some(upsampler) => upsampler.forward(&xs, upsample_size),
+ None => Ok(xs),
+ }
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct CrossAttnUpBlock2DConfig {
+ pub upblock: UpBlock2DConfig,
+ pub attn_num_head_channels: usize,
+ pub cross_attention_dim: usize,
+ // attention_type: "default"
+ pub sliced_attention_size: Option<usize>,
+ pub use_linear_projection: bool,
+}
+
+impl Default for CrossAttnUpBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ upblock: Default::default(),
+ attn_num_head_channels: 1,
+ cross_attention_dim: 1280,
+ sliced_attention_size: None,
+ use_linear_projection: false,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct CrossAttnUpBlock2D {
+ pub upblock: UpBlock2D,
+ pub attentions: Vec<SpatialTransformer>,
+ pub config: CrossAttnUpBlock2DConfig,
+}
+
+impl CrossAttnUpBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ prev_output_channels: usize,
+ out_channels: usize,
+ temb_channels: Option<usize>,
+ config: CrossAttnUpBlock2DConfig,
+ ) -> Result<Self> {
+ let upblock = UpBlock2D::new(
+ vs.clone(),
+ in_channels,
+ prev_output_channels,
+ out_channels,
+ temb_channels,
+ config.upblock,
+ )?;
+ let n_heads = config.attn_num_head_channels;
+ let cfg = SpatialTransformerConfig {
+ depth: 1,
+ context_dim: Some(config.cross_attention_dim),
+ num_groups: config.upblock.resnet_groups,
+ sliced_attention_size: config.sliced_attention_size,
+ use_linear_projection: config.use_linear_projection,
+ };
+ let vs_attn = vs.pp("attentions");
+ let attentions = (0..config.upblock.num_layers)
+ .map(|i| {
+ SpatialTransformer::new(
+ vs_attn.pp(&i.to_string()),
+ out_channels,
+ n_heads,
+ out_channels / n_heads,
+ cfg,
+ )
+ })
+ .collect::<Result<Vec<_>>>()?;
+ Ok(Self {
+ upblock,
+ attentions,
+ config,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ xs: &Tensor,
+ res_xs: &[Tensor],
+ temb: Option<&Tensor>,
+ upsample_size: Option<(usize, usize)>,
+ encoder_hidden_states: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let mut xs = xs.clone();
+ for (index, resnet) in self.upblock.resnets.iter().enumerate() {
+ xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
+ xs = resnet.forward(&xs, temb)?;
+ xs = self.attentions[index].forward(&xs, encoder_hidden_states)?;
+ }
+ match &self.upblock.upsampler {
+ Some(upsampler) => upsampler.forward(&xs, upsample_size),
+ None => Ok(xs),
+ }
+ }
+}
diff --git a/candle-examples/examples/stable-diffusion/utils.rs b/candle-examples/examples/stable-diffusion/utils.rs
new file mode 100644
index 00000000..aa49e560
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/utils.rs
@@ -0,0 +1,17 @@
+use candle::{Result, Tensor};
+
+pub fn sigmoid(_: &Tensor) -> Result<Tensor> {
+ todo!()
+}
+
+pub fn avg_pool2d(_: &Tensor) -> Result<Tensor> {
+ todo!()
+}
+
+pub fn pad(_: &Tensor) -> Result<Tensor> {
+ todo!()
+}
+
+pub fn upsample_nearest2d(_: &Tensor) -> Result<Tensor> {
+ todo!()
+}
diff --git a/candle-examples/examples/stable-diffusion/vae.rs b/candle-examples/examples/stable-diffusion/vae.rs
new file mode 100644
index 00000000..7a10d932
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/vae.rs
@@ -0,0 +1,378 @@
+#![allow(dead_code)]
+//! # Variational Auto-Encoder (VAE) Models.
+//!
+//! Auto-encoder models compress their input to a usually smaller latent space
+//! before expanding it back to its original shape. This results in the latent values
+//! compressing the original information.
+use crate::unet_2d_blocks::{
+ DownEncoderBlock2D, DownEncoderBlock2DConfig, UNetMidBlock2D, UNetMidBlock2DConfig,
+ UpDecoderBlock2D, UpDecoderBlock2DConfig,
+};
+use candle::{Result, Tensor};
+use candle_nn as nn;
+
+#[derive(Debug, Clone)]
+struct EncoderConfig {
+ // down_block_types: DownEncoderBlock2D
+ block_out_channels: Vec<usize>,
+ layers_per_block: usize,
+ norm_num_groups: usize,
+ double_z: bool,
+}
+
+impl Default for EncoderConfig {
+ fn default() -> Self {
+ Self {
+ block_out_channels: vec![64],
+ layers_per_block: 2,
+ norm_num_groups: 32,
+ double_z: true,
+ }
+ }
+}
+
+#[derive(Debug)]
+struct Encoder {
+ conv_in: nn::Conv2d,
+ down_blocks: Vec<DownEncoderBlock2D>,
+ mid_block: UNetMidBlock2D,
+ conv_norm_out: nn::GroupNorm,
+ conv_out: nn::Conv2d,
+ #[allow(dead_code)]
+ config: EncoderConfig,
+}
+
+impl Encoder {
+ fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ config: EncoderConfig,
+ ) -> Result<Self> {
+ let conv_cfg = nn::Conv2dConfig {
+ stride: 1,
+ padding: 1,
+ };
+ let conv_in = nn::conv2d(
+ in_channels,
+ config.block_out_channels[0],
+ 3,
+ conv_cfg,
+ vs.pp("conv_in"),
+ )?;
+ let mut down_blocks = vec![];
+ let vs_down_blocks = vs.pp("down_blocks");
+ for index in 0..config.block_out_channels.len() {
+ let out_channels = config.block_out_channels[index];
+ let in_channels = if index > 0 {
+ config.block_out_channels[index - 1]
+ } else {
+ config.block_out_channels[0]
+ };
+ let is_final = index + 1 == config.block_out_channels.len();
+ let cfg = DownEncoderBlock2DConfig {
+ num_layers: config.layers_per_block,
+ resnet_eps: 1e-6,
+ resnet_groups: config.norm_num_groups,
+ add_downsample: !is_final,
+ downsample_padding: 0,
+ ..Default::default()
+ };
+ let down_block = DownEncoderBlock2D::new(
+ vs_down_blocks.pp(&index.to_string()),
+ in_channels,
+ out_channels,
+ cfg,
+ )?;
+ down_blocks.push(down_block)
+ }
+ let last_block_out_channels = *config.block_out_channels.last().unwrap();
+ let mid_cfg = UNetMidBlock2DConfig {
+ resnet_eps: 1e-6,
+ output_scale_factor: 1.,
+ attn_num_head_channels: None,
+ resnet_groups: Some(config.norm_num_groups),
+ ..Default::default()
+ };
+ let mid_block =
+ UNetMidBlock2D::new(vs.pp("mid_block"), last_block_out_channels, None, mid_cfg)?;
+ let conv_norm_out = nn::group_norm(
+ config.norm_num_groups,
+ last_block_out_channels,
+ 1e-6,
+ vs.pp("conv_norm_out"),
+ )?;
+ let conv_out_channels = if config.double_z {
+ 2 * out_channels
+ } else {
+ out_channels
+ };
+ let conv_cfg = nn::Conv2dConfig {
+ padding: 1,
+ ..Default::default()
+ };
+ let conv_out = nn::conv2d(
+ last_block_out_channels,
+ conv_out_channels,
+ 3,
+ conv_cfg,
+ vs.pp("conv_out"),
+ )?;
+ Ok(Self {
+ conv_in,
+ down_blocks,
+ mid_block,
+ conv_norm_out,
+ conv_out,
+ config,
+ })
+ }
+}
+
+impl Encoder {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut xs = self.conv_in.forward(xs)?;
+ for down_block in self.down_blocks.iter() {
+ xs = down_block.forward(&xs)?
+ }
+ let xs = self.mid_block.forward(&xs, None)?;
+ let xs = self.conv_norm_out.forward(&xs)?;
+ let xs = nn::ops::silu(&xs)?;
+ self.conv_out.forward(&xs)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct DecoderConfig {
+ // up_block_types: UpDecoderBlock2D
+ block_out_channels: Vec<usize>,
+ layers_per_block: usize,
+ norm_num_groups: usize,
+}
+
+impl Default for DecoderConfig {
+ fn default() -> Self {
+ Self {
+ block_out_channels: vec![64],
+ layers_per_block: 2,
+ norm_num_groups: 32,
+ }
+ }
+}
+
+#[derive(Debug)]
+struct Decoder {
+ conv_in: nn::Conv2d,
+ up_blocks: Vec<UpDecoderBlock2D>,
+ mid_block: UNetMidBlock2D,
+ conv_norm_out: nn::GroupNorm,
+ conv_out: nn::Conv2d,
+ #[allow(dead_code)]
+ config: DecoderConfig,
+}
+
+impl Decoder {
+ fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ config: DecoderConfig,
+ ) -> Result<Self> {
+ let n_block_out_channels = config.block_out_channels.len();
+ let last_block_out_channels = *config.block_out_channels.last().unwrap();
+ let conv_cfg = nn::Conv2dConfig {
+ stride: 1,
+ padding: 1,
+ };
+ let conv_in = nn::conv2d(
+ in_channels,
+ last_block_out_channels,
+ 3,
+ conv_cfg,
+ vs.pp("conv_in"),
+ )?;
+ let mid_cfg = UNetMidBlock2DConfig {
+ resnet_eps: 1e-6,
+ output_scale_factor: 1.,
+ attn_num_head_channels: None,
+ resnet_groups: Some(config.norm_num_groups),
+ ..Default::default()
+ };
+ let mid_block =
+ UNetMidBlock2D::new(vs.pp("mid_block"), last_block_out_channels, None, mid_cfg)?;
+ let mut up_blocks = vec![];
+ let vs_up_blocks = vs.pp("up_blocks");
+ let reversed_block_out_channels: Vec<_> =
+ config.block_out_channels.iter().copied().rev().collect();
+ for index in 0..n_block_out_channels {
+ let out_channels = reversed_block_out_channels[index];
+ let in_channels = if index > 0 {
+ reversed_block_out_channels[index - 1]
+ } else {
+ reversed_block_out_channels[0]
+ };
+ let is_final = index + 1 == n_block_out_channels;
+ let cfg = UpDecoderBlock2DConfig {
+ num_layers: config.layers_per_block + 1,
+ resnet_eps: 1e-6,
+ resnet_groups: config.norm_num_groups,
+ add_upsample: !is_final,
+ ..Default::default()
+ };
+ let up_block = UpDecoderBlock2D::new(
+ vs_up_blocks.pp(&index.to_string()),
+ in_channels,
+ out_channels,
+ cfg,
+ )?;
+ up_blocks.push(up_block)
+ }
+ let conv_norm_out = nn::group_norm(
+ config.norm_num_groups,
+ config.block_out_channels[0],
+ 1e-6,
+ vs.pp("conv_norm_out"),
+ )?;
+ let conv_cfg = nn::Conv2dConfig {
+ padding: 1,
+ ..Default::default()
+ };
+ let conv_out = nn::conv2d(
+ config.block_out_channels[0],
+ out_channels,
+ 3,
+ conv_cfg,
+ vs.pp("conv_out"),
+ )?;
+ Ok(Self {
+ conv_in,
+ up_blocks,
+ mid_block,
+ conv_norm_out,
+ conv_out,
+ config,
+ })
+ }
+}
+
+impl Decoder {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut xs = self.mid_block.forward(&self.conv_in.forward(xs)?, None)?;
+ for up_block in self.up_blocks.iter() {
+ xs = up_block.forward(&xs)?
+ }
+ let xs = self.conv_norm_out.forward(&xs)?;
+ let xs = nn::ops::silu(&xs)?;
+ self.conv_out.forward(&xs)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct AutoEncoderKLConfig {
+ pub block_out_channels: Vec<usize>,
+ pub layers_per_block: usize,
+ pub latent_channels: usize,
+ pub norm_num_groups: usize,
+}
+
+impl Default for AutoEncoderKLConfig {
+ fn default() -> Self {
+ Self {
+ block_out_channels: vec![64],
+ layers_per_block: 1,
+ latent_channels: 4,
+ norm_num_groups: 32,
+ }
+ }
+}
+
+pub struct DiagonalGaussianDistribution {
+ mean: Tensor,
+ std: Tensor,
+}
+
+impl DiagonalGaussianDistribution {
+ pub fn new(parameters: &Tensor) -> Result<Self> {
+ let mut parameters = parameters.chunk(2, 1)?.into_iter();
+ let mean = parameters.next().unwrap();
+ let logvar = parameters.next().unwrap();
+ let std = (logvar * 0.5)?.exp()?;
+ Ok(DiagonalGaussianDistribution { mean, std })
+ }
+
+ pub fn sample(&self) -> Result<Tensor> {
+ let sample = Tensor::randn(0., 1f32, self.mean.shape(), self.mean.device());
+ &self.mean + &self.std * sample
+ }
+}
+
+// https://github.com/huggingface/diffusers/blob/970e30606c2944e3286f56e8eb6d3dc6d1eb85f7/src/diffusers/models/vae.py#L485
+// This implementation is specific to the config used in stable-diffusion-v1-5
+// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
+#[derive(Debug)]
+pub struct AutoEncoderKL {
+ encoder: Encoder,
+ decoder: Decoder,
+ quant_conv: nn::Conv2d,
+ post_quant_conv: nn::Conv2d,
+ pub config: AutoEncoderKLConfig,
+}
+
+impl AutoEncoderKL {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ config: AutoEncoderKLConfig,
+ ) -> Result<Self> {
+ let latent_channels = config.latent_channels;
+ let encoder_cfg = EncoderConfig {
+ block_out_channels: config.block_out_channels.clone(),
+ layers_per_block: config.layers_per_block,
+ norm_num_groups: config.norm_num_groups,
+ double_z: true,
+ };
+ let encoder = Encoder::new(vs.pp("encoder"), in_channels, latent_channels, encoder_cfg)?;
+ let decoder_cfg = DecoderConfig {
+ block_out_channels: config.block_out_channels.clone(),
+ layers_per_block: config.layers_per_block,
+ norm_num_groups: config.norm_num_groups,
+ };
+ let decoder = Decoder::new(vs.pp("decoder"), latent_channels, out_channels, decoder_cfg)?;
+ let conv_cfg = Default::default();
+ let quant_conv = nn::conv2d(
+ 2 * latent_channels,
+ 2 * latent_channels,
+ 1,
+ conv_cfg,
+ vs.pp("quant_conv"),
+ )?;
+ let post_quant_conv = nn::conv2d(
+ latent_channels,
+ latent_channels,
+ 1,
+ conv_cfg,
+ vs.pp("post_quant_conv"),
+ )?;
+ Ok(Self {
+ encoder,
+ decoder,
+ quant_conv,
+ post_quant_conv,
+ config,
+ })
+ }
+
+ /// Returns the distribution in the latent space.
+ pub fn encode(&self, xs: &Tensor) -> Result<DiagonalGaussianDistribution> {
+ let xs = self.encoder.forward(xs)?;
+ let parameters = self.quant_conv.forward(&xs)?;
+ DiagonalGaussianDistribution::new(&parameters)
+ }
+
+ /// Takes as input some sampled values.
+ pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = self.post_quant_conv.forward(xs)?;
+ self.decoder.forward(&xs)
+ }
+}