Diffstat (limited to 'candle-examples')
-rw-r--r--  candle-examples/Cargo.toml                                     |   8
-rw-r--r--  candle-examples/examples/llama/main.rs                         |  17
-rw-r--r--  candle-examples/examples/llama/model.rs                        |  43
-rw-r--r--  candle-examples/examples/llama2-c/main.rs                      |  28
-rw-r--r--  candle-examples/examples/llama2-c/training.rs                  | 124
-rw-r--r--  candle-examples/examples/llama2-c/weights.rs                   |  25
-rw-r--r--  candle-examples/examples/mnist-training/main.rs                |   4
-rw-r--r--  candle-examples/examples/stable-diffusion/attention.rs         | 445
-rw-r--r--  candle-examples/examples/stable-diffusion/clip.rs              | 305
-rw-r--r--  candle-examples/examples/stable-diffusion/ddim.rs              | 181
-rw-r--r--  candle-examples/examples/stable-diffusion/embeddings.rs        |  65
-rw-r--r--  candle-examples/examples/stable-diffusion/main.rs              | 273
-rw-r--r--  candle-examples/examples/stable-diffusion/resnet.rs            | 129
-rw-r--r--  candle-examples/examples/stable-diffusion/schedulers.rs        |  45
-rw-r--r--  candle-examples/examples/stable-diffusion/stable_diffusion.rs  | 212
-rw-r--r--  candle-examples/examples/stable-diffusion/unet_2d.rs           | 383
-rw-r--r--  candle-examples/examples/stable-diffusion/unet_2d_blocks.rs    | 808
-rw-r--r--  candle-examples/examples/stable-diffusion/utils.rs             |  31
-rw-r--r--  candle-examples/examples/stable-diffusion/vae.rs               | 378
19 files changed, 3364 insertions, 140 deletions
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index c4e34656..54eb0be6 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -10,7 +10,9 @@ license.workspace = true
readme = "README.md"
[dependencies]
+accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.1.0", package = "candle-core" }
+candle-datasets = { path = "../candle-datasets", version = "0.1.0" }
candle-nn = { path = "../candle-nn", version = "0.1.0" }
candle-transformers = { path = "../candle-transformers", version = "0.1.0" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.1.0", optional = true }
@@ -21,6 +23,7 @@ num-traits = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
+image = { workspace = true, optional = true }
[dev-dependencies]
anyhow = { workspace = true }
@@ -42,6 +45,7 @@ anyhow = { workspace = true }
[features]
default = []
+accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
flash-attn = ["cuda", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
@@ -50,3 +54,7 @@ nccl = ["cuda", "cudarc/nccl", "dep:half"]
[[example]]
name = "llama_multiprocess"
required-features = ["cuda", "nccl", "flash-attn"]
+
+[[example]]
+name = "stable-diffusion"
+required-features = ["image"]
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index f3cf17bc..b2c4e55a 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -9,6 +9,9 @@
// In order to convert the llama weights to a .npz file, run:
// python examples/llama/convert_checkpoint.py ..../LLaMA/7B/consolidated.00.pth
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
@@ -111,6 +114,10 @@ struct Args {
#[arg(long)]
use_f32: bool,
+ /// Enable tracing (generates a trace-timestamp.json file).
+ #[arg(long)]
+ tracing: bool,
+
#[arg(long)]
model_id: Option<String>,
@@ -123,8 +130,18 @@ struct Args {
fn main() -> Result<()> {
use tokenizers::Tokenizer;
+ use tracing_chrome::ChromeLayerBuilder;
+ use tracing_subscriber::prelude::*;
let args = Args::parse();
+ let _guard = if args.tracing {
+ println!("tracing...");
+ let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+ tracing_subscriber::registry().with(chrome_layer).init();
+ Some(guard)
+ } else {
+ None
+ };
let device = candle_examples::device(args.cpu)?;
let config = if args.v1 {
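A minimal standalone sketch of the tracing setup wired in above, assuming the same tracing, tracing-chrome and tracing-subscriber crates used by the example; the guard has to stay alive for the whole run so that the trace-timestamp.json file is flushed and can then be opened in chrome://tracing:

use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // Writes a trace-<timestamp>.json file once the guard is dropped.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();

    // Any span entered while the subscriber is installed shows up in the
    // trace, e.g. the "linear", "rms-norm", "attn" and "mlp" spans added to
    // model.rs below. The span name here is purely illustrative.
    let span = tracing::span!(tracing::Level::TRACE, "forward");
    let _enter = span.enter();
    // ... run the model forward pass here ...
}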
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index ae27afc1..f5ac587e 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -1,5 +1,5 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Embedding, Linear, VarBuilder};
+use candle_nn::{Embedding, VarBuilder};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
@@ -47,6 +47,21 @@ impl Config {
}
}
+// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
+// model.
+#[derive(Debug)]
+pub struct Linear {
+ inner: candle_nn::Linear,
+ span: tracing::Span,
+}
+
+impl Linear {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(x)
+ }
+}
+
#[derive(Clone)]
pub struct Cache {
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
@@ -106,8 +121,9 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
}
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
- let weight = vb.get((size2, size1), "weight")?;
- Ok(Linear::new(weight, None))
+ let span = tracing::span!(tracing::Level::TRACE, "linear");
+ let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
+ Ok(Linear { inner, span })
}
fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
@@ -118,15 +134,18 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
struct RmsNorm {
scale: Tensor,
eps: f64,
+ span: tracing::Span,
}
impl RmsNorm {
fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
let scale = vb.get(size, "weight")?;
- Ok(Self { scale, eps })
+ Ok(Self { scale, eps, span })
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
let in_dtype = x.dtype();
// This is a no-op if x's dtype is already f32.
let x = x.to_dtype(DType::F32)?;
@@ -155,6 +174,8 @@ struct CausalSelfAttention {
head_dim: usize,
cache: Cache,
use_flash_attn: bool,
+ span: tracing::Span,
+ span_rot: tracing::Span,
}
#[cfg(feature = "flash-attn")]
@@ -175,6 +196,7 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Ten
impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
+ let _enter = self.span_rot.enter();
let (b_sz, _, seq_len, n_embd) = x.dims4()?;
let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
@@ -188,6 +210,7 @@ impl CausalSelfAttention {
}
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
+ let _enter = self.span.enter();
let (b_sz, seq_len, n_embd) = x.dims3()?;
let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?;
@@ -269,6 +292,8 @@ impl CausalSelfAttention {
}
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "attn");
+ let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
let size_in = cfg.hidden_size;
let size_q = (cfg.hidden_size / cfg.n_head) * cfg.n_head;
let size_kv = (cfg.hidden_size / cfg.n_head) * cfg.n_key_value_head;
@@ -286,6 +311,8 @@ impl CausalSelfAttention {
head_dim: cfg.hidden_size / cfg.n_head,
cache: cache.clone(),
use_flash_attn: cfg.use_flash_attn,
+ span,
+ span_rot,
})
}
}
@@ -301,15 +328,18 @@ struct Mlp {
c_fc1: Linear,
c_fc2: Linear,
c_proj: Linear,
+ span: tracing::Span,
}
impl Mlp {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
self.c_proj.forward(&x)
}
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "mlp");
let h_size = cfg.hidden_size;
let i_size = cfg.intermediate_size;
let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
@@ -319,6 +349,7 @@ impl Mlp {
c_fc1,
c_fc2,
c_proj,
+ span,
})
}
}
@@ -328,10 +359,12 @@ struct Block {
attn: CausalSelfAttention,
rms_2: RmsNorm,
mlp: Mlp,
+ span: tracing::Span,
}
impl Block {
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
+ let _enter = self.span.enter();
let residual = x;
let x = self.rms_1.forward(x)?;
let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
@@ -341,6 +374,7 @@ impl Block {
}
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "block");
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
let rms_1 = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
@@ -354,6 +388,7 @@ impl Block {
attn,
rms_2,
mlp,
+ span,
})
}
}
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index 8b64fdd2..418218b6 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -1,5 +1,8 @@
// https://github.com/karpathy/llama2.c
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
@@ -27,7 +30,7 @@ struct InferenceCmd {
#[arg(long, default_value = "")]
prompt: String,
- /// Config file in binary format.
+ /// Config file in binary or safetensors format.
#[arg(long)]
config: Option<String>,
@@ -200,7 +203,7 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
Some(inputs.and_then(|inputs| targets.map(|targets| (inputs, targets))))
}
});
- let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
+ let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
for inp_tgt in batch_iter {
let (inp, tgt) = inp_tgt?;
let logits = model.forward(&inp, 0)?;
@@ -225,11 +228,22 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let device = candle_examples::device(common_args.cpu)?;
- let mut file = std::fs::File::open(config_path)?;
- let config = Config::from_reader(&mut file)?;
- println!("{config:?}");
- let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
- let vb = weights.var_builder(&config, &device)?;
+ let is_safetensors = config_path
+ .extension()
+ .map_or(false, |v| v == "safetensors");
+ let (vb, config) = if is_safetensors {
+ let config = Config::tiny();
+ let tensors = candle::safetensors::load(config_path, &device)?;
+ let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
+ (vb, config)
+ } else {
+ let mut file = std::fs::File::open(config_path)?;
+ let config = Config::from_reader(&mut file)?;
+ println!("{config:?}");
+ let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
+ let vb = weights.var_builder(&config, &device)?;
+ (vb, config)
+ };
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, config)?;
diff --git a/candle-examples/examples/llama2-c/training.rs b/candle-examples/examples/llama2-c/training.rs
index e55c686c..3e93c786 100644
--- a/candle-examples/examples/llama2-c/training.rs
+++ b/candle-examples/examples/llama2-c/training.rs
@@ -1,118 +1,6 @@
-#![allow(dead_code)]
-#![allow(unused)]
use crate::model::{Cache, Config, Llama};
-use candle::{DType, Device, Result, Tensor};
-
-pub struct Dataset {
- valid_tokens: Vec<memmap2::Mmap>,
- train_tokens: Vec<memmap2::Mmap>,
-}
-
-fn mmap_file(p: &std::path::PathBuf) -> Result<memmap2::Mmap> {
- let file = std::fs::File::open(p)?;
- let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? };
- Ok(mmap)
-}
-
-impl Dataset {
- pub fn new<P: AsRef<std::path::Path>>(dir: P) -> Result<Self> {
- let dir = dir.as_ref();
- let mut bin_files = vec![];
- for file in std::fs::read_dir(dir)?.flatten() {
- let file = file.path();
- if let Some(extension) = file.extension() {
- if extension == "bin" {
- bin_files.push(file)
- }
- }
- }
- if bin_files.len() < 2 {
- candle::bail!("found less than two bin files in {:?}", dir)
- }
- bin_files.sort();
- let valid_tokens = mmap_file(&bin_files[0])?;
- let train_tokens = bin_files[1..]
- .iter()
- .map(mmap_file)
- .collect::<Result<Vec<_>>>()?;
- Ok(Self {
- valid_tokens: vec![valid_tokens],
- train_tokens,
- })
- }
-}
-
-struct DatasetRandomIter<'a> {
- all_tokens: &'a [memmap2::Mmap],
- tokens: Vec<&'a memmap2::Mmap>,
- current_tokens: &'a memmap2::Mmap,
- indexes_in_bytes: Vec<usize>,
- seq_len: usize,
- device: Device,
-}
-
-impl<'a> DatasetRandomIter<'a> {
- pub fn new(ds: &'a Dataset, valid: bool, seq_len: usize, device: Device) -> Self {
- use rand::seq::SliceRandom;
- use rand::thread_rng;
-
- let all_tokens = if valid {
- &ds.valid_tokens
- } else {
- &ds.train_tokens
- };
- let mut tokens = all_tokens.iter().collect::<Vec<_>>();
- tokens.shuffle(&mut thread_rng());
- let current_tokens = tokens.pop().unwrap();
- let seq_len_in_bytes = seq_len * 2;
- let mut indexes_in_bytes = (0..current_tokens.len() - seq_len_in_bytes)
- .step_by(seq_len_in_bytes)
- .collect::<Vec<_>>();
- indexes_in_bytes.shuffle(&mut thread_rng());
- Self {
- all_tokens,
- tokens,
- current_tokens,
- indexes_in_bytes,
- seq_len,
- device,
- }
- }
-}
-
-impl<'a> Iterator for DatasetRandomIter<'a> {
- type Item = Result<(Tensor, Tensor)>;
-
- fn next(&mut self) -> Option<Self::Item> {
- use byteorder::{LittleEndian, ReadBytesExt};
- use rand::seq::SliceRandom;
- use rand::thread_rng;
-
- let seq_len = self.seq_len;
- if self.indexes_in_bytes.is_empty() {
- if self.tokens.is_empty() {
- self.tokens = self.all_tokens.iter().collect();
- self.tokens.shuffle(&mut thread_rng());
- }
- self.current_tokens = self.tokens.pop().unwrap();
- let seq_len_in_bytes = self.seq_len * 2;
- self.indexes_in_bytes = (0..self.current_tokens.len() - seq_len_in_bytes)
- .step_by(seq_len_in_bytes)
- .collect::<Vec<_>>();
- self.indexes_in_bytes.shuffle(&mut thread_rng());
- }
- let start_idx = self.indexes_in_bytes.pop().unwrap();
- let bytes = &self.current_tokens[start_idx..start_idx + 2 * (seq_len + 1)];
- let mut tokens = vec![0u16; bytes.len() / 2];
- if let Err(err) = std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens) {
- return Some(Err(err.into()));
- }
- let tokens = tokens.into_iter().map(|v| v as u32).collect::<Vec<_>>();
- let inputs = Tensor::new(&tokens[..seq_len], &self.device);
- let targets = Tensor::new(&tokens[1..], &self.device);
- Some(candle::error::zip(inputs, targets))
- }
-}
+use candle::{DType, Device, Result};
+use candle_datasets::nlp::tinystories::{Dataset, DatasetRandomIter};
fn valid_loss(
dataset: &Dataset,
@@ -121,7 +9,7 @@ fn valid_loss(
device: &Device,
) -> Result<f64> {
let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone());
- let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
+ let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
let mut sum_ce = 0f64;
let mut cnt = 0usize;
for inp_tgt in batch_iter.take(50) {
@@ -139,14 +27,14 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
let dataset = Dataset::new(&args.pretokenized_dir)?;
println!(
"loaded dataset, train: {} files, valid: {} files",
- dataset.train_tokens.len(),
- dataset.valid_tokens.len()
+ dataset.train_tokens(),
+ dataset.valid_tokens()
);
let varmap = candle_nn::VarMap::new();
let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
let config = Config::tiny();
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
- let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
+ let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
let cache = Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, config)?;
diff --git a/candle-examples/examples/llama2-c/weights.rs b/candle-examples/examples/llama2-c/weights.rs
index ae1fd6d9..b78418ce 100644
--- a/candle-examples/examples/llama2-c/weights.rs
+++ b/candle-examples/examples/llama2-c/weights.rs
@@ -104,7 +104,14 @@ impl TransformerWeights {
})
}
- pub fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder> {
+ pub fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder<'static>> {
+ // TODO: As of 2023-08-04, gemm is slower than expected when multiplying a matrix of
+ // size (1, k) with the transpose of a matrix of size (k, n) as it ends up transposing the
+ // second matrix back. We detect this case here and as a temporary hack make the weight
+ // matrix column major rather than row major. This ends up speeding up text generation from
+ // 120 token/s to 220 token/s on a Ryzen 2600X.
+ let tr = device.is_cpu() && !candle::utils::has_mkl();
+ let tr = |x: Tensor| if tr { x.t()?.contiguous()?.t() } else { Ok(x) };
let mut ws = std::collections::HashMap::new();
let mut insert = |name: &str, t: Tensor| {
ws.insert(name.to_string(), t);
@@ -115,36 +122,36 @@ impl TransformerWeights {
"model.embed_tokens.weight",
self.token_embedding_table.clone(),
);
- insert("lm_head.weight", self.token_embedding_table.clone());
+ insert("lm_head.weight", tr(self.token_embedding_table.clone())?);
insert("model.norm.weight", self.rms_final_weight.clone());
for layer in 0..cfg.n_layers {
ws.insert(
format!("model.layers.{layer}.self_attn.q_proj.weight"),
- self.wq.i(layer)?,
+ tr(self.wq.i(layer)?)?,
);
ws.insert(
format!("model.layers.{layer}.self_attn.k_proj.weight"),
- self.wk.i(layer)?,
+ tr(self.wk.i(layer)?)?,
);
ws.insert(
format!("model.layers.{layer}.self_attn.v_proj.weight"),
- self.wv.i(layer)?,
+ tr(self.wv.i(layer)?)?,
);
ws.insert(
format!("model.layers.{layer}.self_attn.o_proj.weight"),
- self.wo.i(layer)?,
+ tr(self.wo.i(layer)?)?,
);
ws.insert(
format!("model.layers.{layer}.mlp.gate_proj.weight"),
- self.w1.i(layer)?,
+ tr(self.w1.i(layer)?)?,
);
ws.insert(
format!("model.layers.{layer}.mlp.down_proj.weight"),
- self.w2.i(layer)?,
+ tr(self.w2.i(layer)?)?,
);
ws.insert(
format!("model.layers.{layer}.mlp.up_proj.weight"),
- self.w3.i(layer)?,
+ tr(self.w3.i(layer)?)?,
);
ws.insert(
format!("model.layers.{layer}.input_layernorm.weight"),
diff --git a/candle-examples/examples/mnist-training/main.rs b/candle-examples/examples/mnist-training/main.rs
index e251f6e9..d9e596ce 100644
--- a/candle-examples/examples/mnist-training/main.rs
+++ b/candle-examples/examples/mnist-training/main.rs
@@ -63,7 +63,7 @@ struct TrainingArgs {
}
fn training_loop<M: Model>(
- m: candle_nn::vision::Dataset,
+ m: candle_datasets::vision::Dataset,
args: &TrainingArgs,
) -> anyhow::Result<()> {
let dev = candle::Device::cuda_if_available(0)?;
@@ -140,7 +140,7 @@ struct Args {
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
// Load the dataset
- let m = candle_nn::vision::mnist::load_dir("data")?;
+ let m = candle_datasets::vision::mnist::load_dir("data")?;
println!("train-images: {:?}", m.train_images.shape());
println!("train-labels: {:?}", m.train_labels.shape());
println!("test-images: {:?}", m.test_images.shape());
diff --git a/candle-examples/examples/stable-diffusion/attention.rs b/candle-examples/examples/stable-diffusion/attention.rs
new file mode 100644
index 00000000..83e7ef34
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/attention.rs
@@ -0,0 +1,445 @@
+#![allow(dead_code)]
+//! Attention Based Building Blocks
+use candle::{IndexOp, Result, Tensor, D};
+use candle_nn as nn;
+
+#[derive(Debug)]
+struct GeGlu {
+ proj: nn::Linear,
+}
+
+impl GeGlu {
+ fn new(vs: nn::VarBuilder, dim_in: usize, dim_out: usize) -> Result<Self> {
+ let proj = nn::linear(dim_in, dim_out * 2, vs.pp("proj"))?;
+ Ok(Self { proj })
+ }
+}
+
+impl GeGlu {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?;
+ &hidden_states_and_gate[0] * hidden_states_and_gate[1].gelu()?
+ }
+}
+
+/// A feed-forward layer.
+#[derive(Debug)]
+struct FeedForward {
+ project_in: GeGlu,
+ linear: nn::Linear,
+}
+
+impl FeedForward {
+ // The glu parameter in the python code is unused?
+ // https://github.com/huggingface/diffusers/blob/d3d22ce5a894becb951eec03e663951b28d45135/src/diffusers/models/attention.py#L347
+ /// Creates a new feed-forward layer based on some given input dimension, some
+ /// output dimension, and a multiplier to be used for the intermediary layer.
+ fn new(vs: nn::VarBuilder, dim: usize, dim_out: Option<usize>, mult: usize) -> Result<Self> {
+ let inner_dim = dim * mult;
+ let dim_out = dim_out.unwrap_or(dim);
+ let vs = vs.pp("net");
+ let project_in = GeGlu::new(vs.pp("0"), dim, inner_dim)?;
+ let linear = nn::linear(inner_dim, dim_out, vs.pp("2"))?;
+ Ok(Self { project_in, linear })
+ }
+}
+
+impl FeedForward {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = self.project_in.forward(xs)?;
+ self.linear.forward(&xs)
+ }
+}
+
+#[derive(Debug)]
+struct CrossAttention {
+ to_q: nn::Linear,
+ to_k: nn::Linear,
+ to_v: nn::Linear,
+ to_out: nn::Linear,
+ heads: usize,
+ scale: f64,
+ slice_size: Option<usize>,
+}
+
+impl CrossAttention {
+ // Defaults should be heads = 8, dim_head = 64, context_dim = None
+ fn new(
+ vs: nn::VarBuilder,
+ query_dim: usize,
+ context_dim: Option<usize>,
+ heads: usize,
+ dim_head: usize,
+ slice_size: Option<usize>,
+ ) -> Result<Self> {
+ let inner_dim = dim_head * heads;
+ let context_dim = context_dim.unwrap_or(query_dim);
+ let scale = 1.0 / f64::sqrt(dim_head as f64);
+ let to_q = nn::linear_no_bias(query_dim, inner_dim, vs.pp("to_q"))?;
+ let to_k = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_k"))?;
+ let to_v = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_v"))?;
+ let to_out = nn::linear(inner_dim, query_dim, vs.pp("to_out.0"))?;
+ Ok(Self {
+ to_q,
+ to_k,
+ to_v,
+ to_out,
+ heads,
+ scale,
+ slice_size,
+ })
+ }
+
+ fn reshape_heads_to_batch_dim(&self, xs: &Tensor) -> Result<Tensor> {
+ let (batch_size, seq_len, dim) = xs.dims3()?;
+ xs.reshape((batch_size, seq_len, self.heads, dim / self.heads))?
+ .transpose(1, 2)?
+ .reshape((batch_size * self.heads, seq_len, dim / self.heads))
+ }
+
+ fn reshape_batch_dim_to_heads(&self, xs: &Tensor) -> Result<Tensor> {
+ let (batch_size, seq_len, dim) = xs.dims3()?;
+ xs.reshape((batch_size / self.heads, self.heads, seq_len, dim))?
+ .transpose(1, 2)?
+ .reshape((batch_size / self.heads, seq_len, dim * self.heads))
+ }
+
+ fn sliced_attention(
+ &self,
+ query: &Tensor,
+ key: &Tensor,
+ value: &Tensor,
+ slice_size: usize,
+ ) -> Result<Tensor> {
+ let batch_size_attention = query.dim(0)?;
+ let mut hidden_states = Vec::with_capacity(batch_size_attention / slice_size);
+
+ for i in 0..batch_size_attention / slice_size {
+ let start_idx = i * slice_size;
+ let end_idx = (i + 1) * slice_size;
+
+ let xs = query
+ .i(start_idx..end_idx)?
+ .matmul(&(key.i(start_idx..end_idx)?.t()? * self.scale)?)?;
+ let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(&value.i(start_idx..end_idx)?)?;
+ hidden_states.push(xs)
+ }
+ let hidden_states = Tensor::stack(&hidden_states, 0)?;
+ self.reshape_batch_dim_to_heads(&hidden_states)
+ }
+
+ fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> {
+ let xs = query.matmul(&(key.transpose(D::Minus1, D::Minus2)? * self.scale)?)?;
+ let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(value)?;
+ self.reshape_batch_dim_to_heads(&xs)
+ }
+
+ fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
+ let query = self.to_q.forward(xs)?;
+ let context = context.unwrap_or(xs);
+ let key = self.to_k.forward(context)?;
+ let value = self.to_v.forward(context)?;
+ let query = self.reshape_heads_to_batch_dim(&query)?;
+ let key = self.reshape_heads_to_batch_dim(&key)?;
+ let value = self.reshape_heads_to_batch_dim(&value)?;
+ let xs = match self.slice_size {
+ None => self.attention(&query, &key, &value)?,
+ Some(slice_size) => {
+ if query.dim(0)? / slice_size <= 1 {
+ self.attention(&query, &key, &value)?
+ } else {
+ self.sliced_attention(&query, &key, &value, slice_size)?
+ }
+ }
+ };
+ self.to_out.forward(&xs)
+ }
+}
+
+/// A basic Transformer block.
+#[derive(Debug)]
+struct BasicTransformerBlock {
+ attn1: CrossAttention,
+ ff: FeedForward,
+ attn2: CrossAttention,
+ norm1: nn::LayerNorm,
+ norm2: nn::LayerNorm,
+ norm3: nn::LayerNorm,
+}
+
+impl BasicTransformerBlock {
+ fn new(
+ vs: nn::VarBuilder,
+ dim: usize,
+ n_heads: usize,
+ d_head: usize,
+ context_dim: Option<usize>,
+ sliced_attention_size: Option<usize>,
+ ) -> Result<Self> {
+ let attn1 = CrossAttention::new(
+ vs.pp("attn1"),
+ dim,
+ None,
+ n_heads,
+ d_head,
+ sliced_attention_size,
+ )?;
+ let ff = FeedForward::new(vs.pp("ff"), dim, None, 4)?;
+ let attn2 = CrossAttention::new(
+ vs.pp("attn2"),
+ dim,
+ context_dim,
+ n_heads,
+ d_head,
+ sliced_attention_size,
+ )?;
+ let norm1 = nn::layer_norm(dim, 1e-5, vs.pp("norm1"))?;
+ let norm2 = nn::layer_norm(dim, 1e-5, vs.pp("norm2"))?;
+ let norm3 = nn::layer_norm(dim, 1e-5, vs.pp("norm3"))?;
+ Ok(Self {
+ attn1,
+ ff,
+ attn2,
+ norm1,
+ norm2,
+ norm3,
+ })
+ }
+
+ fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
+ let xs = (self.attn1.forward(&self.norm1.forward(xs)?, None)? + xs)?;
+ let xs = (self.attn2.forward(&self.norm2.forward(&xs)?, context)? + xs)?;
+ self.ff.forward(&self.norm3.forward(&xs)?)? + xs
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct SpatialTransformerConfig {
+ pub depth: usize,
+ pub num_groups: usize,
+ pub context_dim: Option<usize>,
+ pub sliced_attention_size: Option<usize>,
+ pub use_linear_projection: bool,
+}
+
+impl Default for SpatialTransformerConfig {
+ fn default() -> Self {
+ Self {
+ depth: 1,
+ num_groups: 32,
+ context_dim: None,
+ sliced_attention_size: None,
+ use_linear_projection: false,
+ }
+ }
+}
+
+#[derive(Debug)]
+enum Proj {
+ Conv2d(nn::Conv2d),
+ Linear(nn::Linear),
+}
+
+// Aka Transformer2DModel
+#[derive(Debug)]
+pub struct SpatialTransformer {
+ norm: nn::GroupNorm,
+ proj_in: Proj,
+ transformer_blocks: Vec<BasicTransformerBlock>,
+ proj_out: Proj,
+ pub config: SpatialTransformerConfig,
+}
+
+impl SpatialTransformer {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ n_heads: usize,
+ d_head: usize,
+ config: SpatialTransformerConfig,
+ ) -> Result<Self> {
+ let inner_dim = n_heads * d_head;
+ let norm = nn::group_norm(config.num_groups, in_channels, 1e-6, vs.pp("norm"))?;
+ let proj_in = if config.use_linear_projection {
+ Proj::Linear(nn::linear(in_channels, inner_dim, vs.pp("proj_in"))?)
+ } else {
+ Proj::Conv2d(nn::conv2d(
+ in_channels,
+ inner_dim,
+ 1,
+ Default::default(),
+ vs.pp("proj_in"),
+ )?)
+ };
+ let mut transformer_blocks = vec![];
+ let vs_tb = vs.pp("transformer_blocks");
+ for index in 0..config.depth {
+ let tb = BasicTransformerBlock::new(
+ vs_tb.pp(&index.to_string()),
+ inner_dim,
+ n_heads,
+ d_head,
+ config.context_dim,
+ config.sliced_attention_size,
+ )?;
+ transformer_blocks.push(tb)
+ }
+ let proj_out = if config.use_linear_projection {
+ Proj::Linear(nn::linear(in_channels, inner_dim, vs.pp("proj_out"))?)
+ } else {
+ Proj::Conv2d(nn::conv2d(
+ inner_dim,
+ in_channels,
+ 1,
+ Default::default(),
+ vs.pp("proj_out"),
+ )?)
+ };
+ Ok(Self {
+ norm,
+ proj_in,
+ transformer_blocks,
+ proj_out,
+ config,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
+ let (batch, _channel, height, weight) = xs.dims4()?;
+ let residual = xs;
+ let xs = self.norm.forward(xs)?;
+ let (inner_dim, xs) = match &self.proj_in {
+ Proj::Conv2d(p) => {
+ let xs = p.forward(&xs)?;
+ let inner_dim = xs.dim(1)?;
+ let xs = xs
+ .transpose(1, 2)?
+ .t()?
+ .reshape((batch, height * weight, inner_dim))?;
+ (inner_dim, xs)
+ }
+ Proj::Linear(p) => {
+ let inner_dim = xs.dim(1)?;
+ let xs = xs
+ .transpose(1, 2)?
+ .t()?
+ .reshape((batch, height * weight, inner_dim))?;
+ (inner_dim, p.forward(&xs)?)
+ }
+ };
+ let mut xs = xs;
+ for block in self.transformer_blocks.iter() {
+ xs = block.forward(&xs, context)?
+ }
+ let xs = match &self.proj_out {
+ Proj::Conv2d(p) => p.forward(
+ &xs.reshape((batch, height, weight, inner_dim))?
+ .t()?
+ .transpose(1, 2)?,
+ )?,
+ Proj::Linear(p) => p
+ .forward(&xs)?
+ .reshape((batch, height, weight, inner_dim))?
+ .t()?
+ .transpose(1, 2)?,
+ };
+ xs + residual
+ }
+}
+
+/// Configuration for an attention block.
+#[derive(Debug, Clone, Copy)]
+pub struct AttentionBlockConfig {
+ pub num_head_channels: Option<usize>,
+ pub num_groups: usize,
+ pub rescale_output_factor: f64,
+ pub eps: f64,
+}
+
+impl Default for AttentionBlockConfig {
+ fn default() -> Self {
+ Self {
+ num_head_channels: None,
+ num_groups: 32,
+ rescale_output_factor: 1.,
+ eps: 1e-5,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct AttentionBlock {
+ group_norm: nn::GroupNorm,
+ query: nn::Linear,
+ key: nn::Linear,
+ value: nn::Linear,
+ proj_attn: nn::Linear,
+ channels: usize,
+ num_heads: usize,
+ config: AttentionBlockConfig,
+}
+
+impl AttentionBlock {
+ pub fn new(vs: nn::VarBuilder, channels: usize, config: AttentionBlockConfig) -> Result<Self> {
+ let num_head_channels = config.num_head_channels.unwrap_or(channels);
+ let num_heads = channels / num_head_channels;
+ let group_norm =
+ nn::group_norm(config.num_groups, channels, config.eps, vs.pp("group_norm"))?;
+ let query = nn::linear(channels, channels, vs.pp("query"))?;
+ let key = nn::linear(channels, channels, vs.pp("key"))?;
+ let value = nn::linear(channels, channels, vs.pp("value"))?;
+ let proj_attn = nn::linear(channels, channels, vs.pp("proj_attn"))?;
+ Ok(Self {
+ group_norm,
+ query,
+ key,
+ value,
+ proj_attn,
+ channels,
+ num_heads,
+ config,
+ })
+ }
+
+ fn transpose_for_scores(&self, xs: Tensor) -> Result<Tensor> {
+ let (batch, t, h_times_d) = xs.dims3()?;
+ xs.reshape((batch, t, self.num_heads, h_times_d / self.num_heads))?
+ .transpose(1, 2)
+ }
+}
+
+impl AttentionBlock {
+ pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let residual = xs;
+ let (batch, channel, height, width) = xs.dims4()?;
+ let xs = self
+ .group_norm
+ .forward(xs)?
+ .reshape((batch, channel, height * width))?
+ .transpose(1, 2)?;
+
+ let query_proj = self.query.forward(&xs)?;
+ let key_proj = self.key.forward(&xs)?;
+ let value_proj = self.value.forward(&xs)?;
+
+ let query_states = self.transpose_for_scores(query_proj)?;
+ let key_states = self.transpose_for_scores(key_proj)?;
+ let value_states = self.transpose_for_scores(value_proj)?;
+
+ let scale = f64::powf((self.channels as f64) / (self.num_heads as f64), -0.25);
+ let attention_scores =
+            // TODO: Check that this needs two multiplications by `scale`.
+ (query_states * scale)?.matmul(&(key_states.t()? * scale)?)?;
+ let attention_probs = nn::ops::softmax(&attention_scores, D::Minus1)?;
+
+ let xs = attention_probs.matmul(&value_states)?;
+ let xs = xs.transpose(1, 2)?.contiguous()?;
+ let xs = xs.flatten_from(D::Minus2)?;
+ let xs = self
+ .proj_attn
+ .forward(&xs)?
+ .t()?
+ .reshape((batch, channel, height, width))?;
+ (xs + residual)? / self.config.rescale_output_factor
+ }
+}
diff --git a/candle-examples/examples/stable-diffusion/clip.rs b/candle-examples/examples/stable-diffusion/clip.rs
new file mode 100644
index 00000000..ca00b417
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/clip.rs
@@ -0,0 +1,305 @@
+#![allow(dead_code)]
+//! Contrastive Language-Image Pre-Training
+//!
+//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
+//! pairs of images with related texts.
+//!
+//! https://github.com/openai/CLIP
+use candle::{Device, Result, Tensor, D};
+use candle_nn as nn;
+
+#[derive(Debug, Clone, Copy)]
+pub enum Activation {
+ QuickGelu,
+ Gelu,
+}
+
+impl Activation {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ match self {
+ Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?,
+ Activation::Gelu => xs.gelu(),
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct Config {
+ vocab_size: usize,
+ embed_dim: usize, // aka config.hidden_size
+ activation: Activation, // aka config.hidden_act
+ intermediate_size: usize,
+ pub max_position_embeddings: usize,
+ // The character to use for padding, use EOS when not set.
+ pub pad_with: Option<String>,
+ num_hidden_layers: usize,
+ num_attention_heads: usize,
+ #[allow(dead_code)]
+ projection_dim: usize,
+}
+
+impl Config {
+ // The config details can be found in the "text_config" section of this json file:
+ // https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
+ pub fn v1_5() -> Self {
+ Self {
+ vocab_size: 49408,
+ embed_dim: 768,
+ intermediate_size: 3072,
+ max_position_embeddings: 77,
+ pad_with: None,
+ num_hidden_layers: 12,
+ num_attention_heads: 12,
+ projection_dim: 768,
+ activation: Activation::QuickGelu,
+ }
+ }
+
+ // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/text_encoder/config.json
+ pub fn v2_1() -> Self {
+ Self {
+ vocab_size: 49408,
+ embed_dim: 1024,
+ intermediate_size: 4096,
+ max_position_embeddings: 77,
+ pad_with: Some("!".to_string()),
+ num_hidden_layers: 23,
+ num_attention_heads: 16,
+ projection_dim: 512,
+ activation: Activation::Gelu,
+ }
+ }
+}
+
+// CLIP Text Model
+// https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py
+#[derive(Debug)]
+struct ClipTextEmbeddings {
+ token_embedding: candle_nn::Embedding,
+ position_embedding: candle_nn::Embedding,
+ position_ids: Tensor,
+}
+
+impl ClipTextEmbeddings {
+ fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
+ let token_embedding =
+ candle_nn::embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?;
+ let position_embedding = candle_nn::embedding(
+ c.max_position_embeddings,
+ c.embed_dim,
+ vs.pp("position_embedding"),
+ )?;
+ let position_ids =
+ Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
+ Ok(ClipTextEmbeddings {
+ token_embedding,
+ position_embedding,
+ position_ids,
+ })
+ }
+}
+
+impl ClipTextEmbeddings {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let token_embedding = self.token_embedding.forward(xs)?;
+ let position_embedding = self.position_embedding.forward(&self.position_ids)?;
+ token_embedding.broadcast_add(&position_embedding)
+ }
+}
+
+#[derive(Debug)]
+struct ClipAttention {
+ k_proj: candle_nn::Linear,
+ v_proj: candle_nn::Linear,
+ q_proj: candle_nn::Linear,
+ out_proj: candle_nn::Linear,
+ head_dim: usize,
+ scale: f64,
+ num_attention_heads: usize,
+}
+
+impl ClipAttention {
+ fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
+ let embed_dim = c.embed_dim;
+ let num_attention_heads = c.num_attention_heads;
+ let k_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("k_proj"))?;
+ let v_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("v_proj"))?;
+ let q_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("q_proj"))?;
+ let out_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?;
+ let head_dim = embed_dim / num_attention_heads;
+ let scale = (head_dim as f64).powf(-0.5);
+ Ok(ClipAttention {
+ k_proj,
+ v_proj,
+ q_proj,
+ out_proj,
+ head_dim,
+ scale,
+ num_attention_heads,
+ })
+ }
+
+ fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result<Tensor> {
+ xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
+ .transpose(1, 2)?
+ .contiguous()
+ }
+
+ fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
+ let (bsz, seq_len, embed_dim) = xs.dims3()?;
+ let query_states = (self.q_proj.forward(xs)? * self.scale)?;
+ let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
+ let query_states = self
+ .shape(&query_states, seq_len, bsz)?
+ .reshape(proj_shape)?;
+ let key_states = self
+ .shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
+ .reshape(proj_shape)?;
+ let value_states = self
+ .shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
+ .reshape(proj_shape)?;
+ let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
+
+ let src_len = key_states.dim(1)?;
+ let attn_weights = attn_weights
+ .reshape((bsz, self.num_attention_heads, seq_len, src_len))?
+ .broadcast_add(causal_attention_mask)?;
+ let attn_weights =
+ attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
+ let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
+
+ let attn_output = attn_weights.matmul(&value_states)?;
+ let attn_output = attn_output
+ .reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
+ .transpose(1, 2)?
+ .reshape((bsz, seq_len, embed_dim))?;
+ self.out_proj.forward(&attn_output)
+ }
+}
+
+#[derive(Debug)]
+struct ClipMlp {
+ fc1: candle_nn::Linear,
+ fc2: candle_nn::Linear,
+ activation: Activation,
+}
+
+impl ClipMlp {
+ fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
+ let fc1 = candle_nn::linear(c.embed_dim, c.intermediate_size, vs.pp("fc1"))?;
+ let fc2 = candle_nn::linear(c.intermediate_size, c.embed_dim, vs.pp("fc2"))?;
+ Ok(ClipMlp {
+ fc1,
+ fc2,
+ activation: c.activation,
+ })
+ }
+}
+
+impl ClipMlp {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = self.fc1.forward(xs)?;
+ self.fc2.forward(&self.activation.forward(&xs)?)
+ }
+}
+
+#[derive(Debug)]
+struct ClipEncoderLayer {
+ self_attn: ClipAttention,
+ layer_norm1: candle_nn::LayerNorm,
+ mlp: ClipMlp,
+ layer_norm2: candle_nn::LayerNorm,
+}
+
+impl ClipEncoderLayer {
+ fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
+ let self_attn = ClipAttention::new(vs.pp("self_attn"), c)?;
+ let layer_norm1 = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("layer_norm1"))?;
+ let mlp = ClipMlp::new(vs.pp("mlp"), c)?;
+ let layer_norm2 = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("layer_norm2"))?;
+ Ok(ClipEncoderLayer {
+ self_attn,
+ layer_norm1,
+ mlp,
+ layer_norm2,
+ })
+ }
+
+ fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
+ let residual = xs;
+ let xs = self.layer_norm1.forward(xs)?;
+ let xs = self.self_attn.forward(&xs, causal_attention_mask)?;
+ let xs = (xs + residual)?;
+
+ let residual = &xs;
+ let xs = self.layer_norm2.forward(&xs)?;
+ let xs = self.mlp.forward(&xs)?;
+ xs + residual
+ }
+}
+
+#[derive(Debug)]
+struct ClipEncoder {
+ layers: Vec<ClipEncoderLayer>,
+}
+
+impl ClipEncoder {
+ fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
+ let vs = vs.pp("layers");
+ let mut layers: Vec<ClipEncoderLayer> = Vec::new();
+ for index in 0..c.num_hidden_layers {
+ let layer = ClipEncoderLayer::new(vs.pp(&index.to_string()), c)?;
+ layers.push(layer)
+ }
+ Ok(ClipEncoder { layers })
+ }
+
+ fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
+ let mut xs = xs.clone();
+ for layer in self.layers.iter() {
+ xs = layer.forward(&xs, causal_attention_mask)?;
+ }
+ Ok(xs)
+ }
+}
+
+/// A CLIP transformer based model.
+#[derive(Debug)]
+pub struct ClipTextTransformer {
+ embeddings: ClipTextEmbeddings,
+ encoder: ClipEncoder,
+ final_layer_norm: candle_nn::LayerNorm,
+}
+
+impl ClipTextTransformer {
+ pub fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
+ let vs = vs.pp("text_model");
+ let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?;
+ let encoder = ClipEncoder::new(vs.pp("encoder"), c)?;
+ let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("final_layer_norm"))?;
+ Ok(ClipTextTransformer {
+ embeddings,
+ encoder,
+ final_layer_norm,
+ })
+ }
+
+ // https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678
+ fn build_causal_attention_mask(bsz: usize, seq_len: usize, device: &Device) -> Result<Tensor> {
+ let mask: Vec<_> = (0..seq_len)
+ .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::MIN } else { 0. }))
+ .collect();
+ let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
+ mask.broadcast_as((bsz, seq_len, seq_len))
+ }
+}
+
+impl ClipTextTransformer {
+ pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let (bsz, seq_len) = xs.dims2()?;
+ let xs = self.embeddings.forward(xs)?;
+ let causal_attention_mask = Self::build_causal_attention_mask(bsz, seq_len, xs.device())?;
+ let xs = self.encoder.forward(&xs, &causal_attention_mask)?;
+ self.final_layer_norm.forward(&xs)
+ }
+}
diff --git a/candle-examples/examples/stable-diffusion/ddim.rs b/candle-examples/examples/stable-diffusion/ddim.rs
new file mode 100644
index 00000000..6eb6df44
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/ddim.rs
@@ -0,0 +1,181 @@
+#![allow(dead_code)]
+//! # Denoising Diffusion Implicit Models
+//!
+//! The Denoising Diffusion Implicit Models (DDIM) scheduler is similar to
+//! Denoising Diffusion Probabilistic Models (DDPM). The DDPM generative
+//! process is the reverse of a Markovian process; DDIM generalizes this to
+//! non-Markovian guidance.
+//!
+//! Denoising Diffusion Implicit Models, J. Song et al, 2020.
+//! https://arxiv.org/abs/2010.02502
+use crate::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType};
+use candle::{Result, Tensor};
+
+/// The configuration for the DDIM scheduler.
+#[derive(Debug, Clone, Copy)]
+pub struct DDIMSchedulerConfig {
+ /// The value of beta at the beginning of training.
+ pub beta_start: f64,
+ /// The value of beta at the end of training.
+ pub beta_end: f64,
+ /// How beta evolved during training.
+ pub beta_schedule: BetaSchedule,
+ /// The amount of noise to be added at each step.
+ pub eta: f64,
+ /// Adjust the indexes of the inference schedule by this value.
+ pub steps_offset: usize,
+ /// prediction type of the scheduler function, one of `epsilon` (predicting
+    /// the noise of the diffusion process), `sample` (directly predicting the noisy sample)
+ /// or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf)
+ pub prediction_type: PredictionType,
+ /// number of diffusion steps used to train the model
+ pub train_timesteps: usize,
+}
+
+impl Default for DDIMSchedulerConfig {
+ fn default() -> Self {
+ Self {
+ beta_start: 0.00085f64,
+ beta_end: 0.012f64,
+ beta_schedule: BetaSchedule::ScaledLinear,
+ eta: 0.,
+ steps_offset: 1,
+ prediction_type: PredictionType::Epsilon,
+ train_timesteps: 1000,
+ }
+ }
+}
+
+/// The DDIM scheduler.
+#[derive(Debug, Clone)]
+pub struct DDIMScheduler {
+ timesteps: Vec<usize>,
+ alphas_cumprod: Vec<f64>,
+ step_ratio: usize,
+ init_noise_sigma: f64,
+ pub config: DDIMSchedulerConfig,
+}
+
+// clip_sample: False, set_alpha_to_one: False
+impl DDIMScheduler {
+ /// Creates a new DDIM scheduler given the number of steps to be
+ /// used for inference as well as the number of steps that was used
+ /// during training.
+ pub fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result<Self> {
+ let step_ratio = config.train_timesteps / inference_steps;
+ let timesteps: Vec<usize> = (0..(inference_steps))
+ .map(|s| s * step_ratio + config.steps_offset)
+ .rev()
+ .collect();
+ let betas = match config.beta_schedule {
+ BetaSchedule::ScaledLinear => crate::utils::linspace(
+ config.beta_start.sqrt(),
+ config.beta_end.sqrt(),
+ config.train_timesteps,
+ )?
+ .sqr()?,
+ BetaSchedule::Linear => {
+ crate::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)?
+ }
+ BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?,
+ };
+ let betas = betas.to_vec1::<f64>()?;
+ let mut alphas_cumprod = Vec::with_capacity(betas.len());
+ for &beta in betas.iter() {
+ let alpha = 1.0 - beta;
+ alphas_cumprod.push(alpha * *alphas_cumprod.last().unwrap_or(&1f64))
+ }
+ Ok(Self {
+ alphas_cumprod,
+ timesteps,
+ step_ratio,
+ init_noise_sigma: 1.,
+ config,
+ })
+ }
+
+ pub fn timesteps(&self) -> &[usize] {
+ self.timesteps.as_slice()
+ }
+
+ /// Ensures interchangeability with schedulers that need to scale the denoising model input
+ /// depending on the current timestep.
+ pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> {
+ Ok(sample)
+ }
+
+ /// Performs a backward step during inference.
+ pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
+ let timestep = if timestep >= self.alphas_cumprod.len() {
+ timestep - 1
+ } else {
+ timestep
+ };
+ // https://github.com/huggingface/diffusers/blob/6e099e2c8ce4c4f5c7318e970a8c093dc5c7046e/src/diffusers/schedulers/scheduling_ddim.py#L195
+ let prev_timestep = if timestep > self.step_ratio {
+ timestep - self.step_ratio
+ } else {
+ 0
+ };
+
+ let alpha_prod_t = self.alphas_cumprod[timestep];
+ let alpha_prod_t_prev = self.alphas_cumprod[prev_timestep];
+ let beta_prod_t = 1. - alpha_prod_t;
+ let beta_prod_t_prev = 1. - alpha_prod_t_prev;
+
+ let (pred_original_sample, pred_epsilon) = match self.config.prediction_type {
+ PredictionType::Epsilon => {
+ let pred_original_sample = ((sample - (model_output * beta_prod_t.sqrt())?)?
+ * (1. / alpha_prod_t.sqrt()))?;
+ (pred_original_sample, model_output.clone())
+ }
+ PredictionType::VPrediction => {
+ let pred_original_sample =
+ ((sample * alpha_prod_t.sqrt())? - (model_output * beta_prod_t.sqrt())?)?;
+ let pred_epsilon =
+ ((model_output * alpha_prod_t.sqrt())? + (sample * beta_prod_t.sqrt())?)?;
+ (pred_original_sample, pred_epsilon)
+ }
+ PredictionType::Sample => {
+ let pred_original_sample = model_output.clone();
+ let pred_epsilon = ((sample - &pred_original_sample * alpha_prod_t.sqrt())?
+ * (1. / beta_prod_t.sqrt()))?;
+ (pred_original_sample, pred_epsilon)
+ }
+ };
+
+ let variance = (beta_prod_t_prev / beta_prod_t) * (1. - alpha_prod_t / alpha_prod_t_prev);
+ let std_dev_t = self.config.eta * variance.sqrt();
+
+ let pred_sample_direction =
+ (pred_epsilon * (1. - alpha_prod_t_prev - std_dev_t * std_dev_t).sqrt())?;
+ let prev_sample =
+ ((pred_original_sample * alpha_prod_t_prev.sqrt())? + pred_sample_direction)?;
+ if self.config.eta > 0. {
+ &prev_sample
+ + Tensor::randn(
+ 0f32,
+ std_dev_t as f32,
+ prev_sample.shape(),
+ prev_sample.device(),
+ )?
+ } else {
+ Ok(prev_sample)
+ }
+ }
+
+ pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
+ let timestep = if timestep >= self.alphas_cumprod.len() {
+ timestep - 1
+ } else {
+ timestep
+ };
+ let sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt();
+ let sqrt_one_minus_alpha_prod = (1.0 - self.alphas_cumprod[timestep]).sqrt();
+ (original * sqrt_alpha_prod)? + (noise * sqrt_one_minus_alpha_prod)?
+ }
+
+ pub fn init_noise_sigma(&self) -> f64 {
+ self.init_noise_sigma
+ }
+}
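For reference, the step function above implements the DDIM update rule from Song et al. Writing alpha_prod_t as $\bar\alpha_t$, alpha_prod_t_prev as $\bar\alpha_{t'}$ (the previous timestep in the inference schedule) and std_dev_t as $\sigma_t$, the epsilon prediction branch computes

$$\hat{x}_0 = \frac{x_t - \sqrt{1 - \bar\alpha_t}\,\epsilon_\theta(x_t)}{\sqrt{\bar\alpha_t}}, \qquad \sigma_t = \eta \sqrt{\frac{1 - \bar\alpha_{t'}}{1 - \bar\alpha_t}\left(1 - \frac{\bar\alpha_t}{\bar\alpha_{t'}}\right)},$$

$$x_{t'} = \sqrt{\bar\alpha_{t'}}\,\hat{x}_0 + \sqrt{1 - \bar\alpha_{t'} - \sigma_t^2}\,\epsilon_\theta(x_t) + \sigma_t z, \qquad z \sim \mathcal{N}(0, I),$$

with the noise term only added when eta > 0.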
diff --git a/candle-examples/examples/stable-diffusion/embeddings.rs b/candle-examples/examples/stable-diffusion/embeddings.rs
new file mode 100644
index 00000000..e3a339f5
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/embeddings.rs
@@ -0,0 +1,65 @@
+#![allow(dead_code)]
+use candle::{Result, Tensor, D};
+use candle_nn as nn;
+
+#[derive(Debug)]
+pub struct TimestepEmbedding {
+ linear_1: nn::Linear,
+ linear_2: nn::Linear,
+}
+
+impl TimestepEmbedding {
+ // act_fn: "silu"
+ pub fn new(vs: nn::VarBuilder, channel: usize, time_embed_dim: usize) -> Result<Self> {
+ let linear_1 = nn::linear(channel, time_embed_dim, vs.pp("linear_1"))?;
+ let linear_2 = nn::linear(time_embed_dim, time_embed_dim, vs.pp("linear_2"))?;
+ Ok(Self { linear_1, linear_2 })
+ }
+}
+
+impl TimestepEmbedding {
+ pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = nn::ops::silu(&self.linear_1.forward(xs)?)?;
+ self.linear_2.forward(&xs)
+ }
+}
+
+#[derive(Debug)]
+pub struct Timesteps {
+ num_channels: usize,
+ flip_sin_to_cos: bool,
+ downscale_freq_shift: f64,
+}
+
+impl Timesteps {
+ pub fn new(num_channels: usize, flip_sin_to_cos: bool, downscale_freq_shift: f64) -> Self {
+ Self {
+ num_channels,
+ flip_sin_to_cos,
+ downscale_freq_shift,
+ }
+ }
+}
+
+impl Timesteps {
+ pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let half_dim = (self.num_channels / 2) as u32;
+ let exponent =
+ (Tensor::arange(0, half_dim, xs.device())?.to_dtype(xs.dtype())? * -f64::ln(10000.))?;
+ let exponent = (exponent / (half_dim as f64 - self.downscale_freq_shift))?;
+ let emb = exponent.exp()?;
+ // emb = timesteps[:, None].float() * emb[None, :]
+ let emb = xs.unsqueeze(D::Minus1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
+ let (cos, sin) = (emb.cos()?, emb.sin()?);
+ let emb = if self.flip_sin_to_cos {
+ Tensor::cat(&[&cos, &sin], D::Minus1)?
+ } else {
+ Tensor::cat(&[&sin, &cos], D::Minus1)?
+ };
+ if self.num_channels % 2 == 1 {
+ emb.pad_with_zeros(D::Minus2, 0, 1)
+ } else {
+ Ok(emb)
+ }
+ }
+}
diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs
new file mode 100644
index 00000000..8ce0c234
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/main.rs
@@ -0,0 +1,273 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+mod attention;
+mod clip;
+mod ddim;
+mod embeddings;
+mod resnet;
+mod schedulers;
+mod stable_diffusion;
+mod unet_2d;
+mod unet_2d_blocks;
+mod utils;
+mod vae;
+
+use anyhow::{Error as E, Result};
+use candle::{DType, Device, Tensor};
+use clap::Parser;
+use tokenizers::Tokenizer;
+
+const GUIDANCE_SCALE: f64 = 7.5;
+
+#[derive(Parser)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ /// The prompt to be used for image generation.
+ #[arg(
+ long,
+ default_value = "A very realistic photo of a rusty robot walking on a sandy beach"
+ )]
+ prompt: String,
+
+ #[arg(long, default_value = "")]
+ uncond_prompt: String,
+
+ /// Run on CPU rather than on GPU.
+ #[arg(long)]
+ cpu: bool,
+
+ /// The height in pixels of the generated image.
+ #[arg(long)]
+ height: Option<usize>,
+
+ /// The width in pixels of the generated image.
+ #[arg(long)]
+ width: Option<usize>,
+
+ /// The UNet weight file, in .ot or .safetensors format.
+ #[arg(long, value_name = "FILE")]
+ unet_weights: Option<String>,
+
+ /// The CLIP weight file, in .ot or .safetensors format.
+ #[arg(long, value_name = "FILE")]
+ clip_weights: Option<String>,
+
+ /// The VAE weight file, in .ot or .safetensors format.
+ #[arg(long, value_name = "FILE")]
+ vae_weights: Option<String>,
+
+ #[arg(long, value_name = "FILE")]
+    /// The file specifying the tokenizer to be used for tokenization.
+ tokenizer: String,
+
+ /// The size of the sliced attention or 0 for automatic slicing (disabled by default)
+ #[arg(long)]
+ sliced_attention_size: Option<usize>,
+
+ /// The number of steps to run the diffusion for.
+ #[arg(long, default_value_t = 30)]
+ n_steps: usize,
+
+ /// The number of samples to generate.
+ #[arg(long, default_value_t = 1)]
+ num_samples: i64,
+
+ /// The name of the final image to generate.
+ #[arg(long, value_name = "FILE", default_value = "sd_final.png")]
+ final_image: String,
+
+ #[arg(long, value_enum, default_value = "v2-1")]
+ sd_version: StableDiffusionVersion,
+
+ /// Generate intermediary images at each step.
+ #[arg(long, action)]
+ intermediary_images: bool,
+}
+
+#[derive(Debug, Clone, Copy, clap::ValueEnum)]
+enum StableDiffusionVersion {
+ V1_5,
+ V2_1,
+}
+
+impl Args {
+ fn clip_weights(&self) -> String {
+ match &self.clip_weights {
+ Some(w) => w.clone(),
+ None => match self.sd_version {
+ StableDiffusionVersion::V1_5 => "data/pytorch_model.safetensors".to_string(),
+ StableDiffusionVersion::V2_1 => "data/clip_v2.1.safetensors".to_string(),
+ },
+ }
+ }
+
+ fn vae_weights(&self) -> String {
+ match &self.vae_weights {
+ Some(w) => w.clone(),
+ None => match self.sd_version {
+ StableDiffusionVersion::V1_5 => "data/vae.safetensors".to_string(),
+ StableDiffusionVersion::V2_1 => "data/vae_v2.1.safetensors".to_string(),
+ },
+ }
+ }
+
+ fn unet_weights(&self) -> String {
+ match &self.unet_weights {
+ Some(w) => w.clone(),
+ None => match self.sd_version {
+ StableDiffusionVersion::V1_5 => "data/unet.safetensors".to_string(),
+ StableDiffusionVersion::V2_1 => "data/unet_v2.1.safetensors".to_string(),
+ },
+ }
+ }
+}
+
+fn output_filename(
+ basename: &str,
+ sample_idx: i64,
+ num_samples: i64,
+ timestep_idx: Option<usize>,
+) -> String {
+ let filename = if num_samples > 1 {
+ match basename.rsplit_once('.') {
+ None => format!("{basename}.{sample_idx}.png"),
+ Some((filename_no_extension, extension)) => {
+ format!("{filename_no_extension}.{sample_idx}.{extension}")
+ }
+ }
+ } else {
+ basename.to_string()
+ };
+ match timestep_idx {
+ None => filename,
+ Some(timestep_idx) => match filename.rsplit_once('.') {
+ None => format!("{filename}-{timestep_idx}.png"),
+ Some((filename_no_extension, extension)) => {
+ format!("{filename_no_extension}-{timestep_idx}.{extension}")
+ }
+ },
+ }
+}
+
+fn run(args: Args) -> Result<()> {
+ let clip_weights = args.clip_weights();
+ let vae_weights = args.vae_weights();
+ let unet_weights = args.unet_weights();
+ let Args {
+ prompt,
+ uncond_prompt,
+ cpu,
+ height,
+ width,
+ n_steps,
+ tokenizer,
+ final_image,
+ sliced_attention_size,
+ num_samples,
+ sd_version,
+ ..
+ } = args;
+ let sd_config = match sd_version {
+ StableDiffusionVersion::V1_5 => {
+ stable_diffusion::StableDiffusionConfig::v1_5(sliced_attention_size, height, width)
+ }
+ StableDiffusionVersion::V2_1 => {
+ stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size, height, width)
+ }
+ };
+
+ let scheduler = sd_config.build_scheduler(n_steps)?;
+ let device = candle_examples::device(cpu)?;
+
+ let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
+ let pad_id = match &sd_config.clip.pad_with {
+ Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
+ None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(),
+ };
+ println!("Running with prompt \"{prompt}\".");
+ let mut tokens = tokenizer
+ .encode(prompt, true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
+ while tokens.len() < sd_config.clip.max_position_embeddings {
+ tokens.push(pad_id)
+ }
+ let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
+
+ let mut uncond_tokens = tokenizer
+ .encode(uncond_prompt, true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
+ while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
+ uncond_tokens.push(pad_id)
+ }
+ let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), &device)?.unsqueeze(0)?;
+
+ println!("Building the Clip transformer.");
+ let text_model = sd_config.build_clip_transformer(&clip_weights, &device)?;
+ let text_embeddings = text_model.forward(&tokens)?;
+ let uncond_embeddings = text_model.forward(&uncond_tokens)?;
+ let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?;
+
+ println!("Building the autoencoder.");
+ let vae = sd_config.build_vae(&vae_weights, &device)?;
+ println!("Building the unet.");
+ let unet = sd_config.build_unet(&unet_weights, &device, 4)?;
+
+ let bsize = 1;
+ for idx in 0..num_samples {
+ let mut latents = Tensor::randn(
+ 0f32,
+ 1f32,
+ (bsize, 4, sd_config.height / 8, sd_config.width / 8),
+ &device,
+ )?;
+
+ // scale the initial noise by the standard deviation required by the scheduler
+ latents = (latents * scheduler.init_noise_sigma())?;
+
+ for (timestep_index, &timestep) in scheduler.timesteps().iter().enumerate() {
+ println!("Timestep {timestep_index}/{n_steps}");
+ let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
+
+ let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)?;
+ let noise_pred =
+ unet.forward(&latent_model_input, timestep as f64, &text_embeddings)?;
+ let noise_pred = noise_pred.chunk(2, 0)?;
+ let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]);
+ let noise_pred =
+ (noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * GUIDANCE_SCALE)?)?;
+ latents = scheduler.step(&noise_pred, timestep, &latents)?;
+
+ if args.intermediary_images {
+ let image = vae.decode(&(&latents / 0.18215)?)?;
+ let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
+ let image = (image * 255.)?.to_dtype(DType::U8)?;
+ let image_filename =
+ output_filename(&final_image, idx + 1, num_samples, Some(timestep_index + 1));
+ crate::utils::save_image(&image, image_filename)?
+ }
+ }
+
+ println!(
+ "Generating the final image for sample {}/{}.",
+ idx + 1,
+ num_samples
+ );
+ let image = vae.decode(&(&latents / 0.18215)?)?;
+ // TODO: Add the clamping between 0 and 1.
+ let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
+ let image = (image * 255.)?.to_dtype(DType::U8)?;
+ let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
+ crate::utils::save_image(&image, image_filename)?
+ }
+ Ok(())
+}
+
+fn main() -> Result<()> {
+ let args = Args::parse();
+ run(args)
+}
diff --git a/candle-examples/examples/stable-diffusion/resnet.rs b/candle-examples/examples/stable-diffusion/resnet.rs
new file mode 100644
index 00000000..7790dcf9
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/resnet.rs
@@ -0,0 +1,129 @@
+#![allow(dead_code)]
+//! ResNet Building Blocks
+//!
+//! Some Residual Network blocks used in UNet models.
+//!
+//! Deep Residual Learning for Image Recognition, K. He et al., 2015.
+//! https://arxiv.org/abs/1512.03385
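+//!
+//! Each block applies two GroupNorm/SiLU/Conv2d stages, optionally adds a projected
+//! timestep embedding in between, and sums the result with a (possibly 1x1-convolved)
+//! skip connection before dividing by `output_scale_factor`.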
+use candle::{Result, Tensor, D};
+use candle_nn as nn;
+
+/// Configuration for a ResNet block.
+#[derive(Debug, Clone, Copy)]
+pub struct ResnetBlock2DConfig {
+ /// The number of output channels, defaults to the number of input channels.
+ pub out_channels: Option<usize>,
+ pub temb_channels: Option<usize>,
+ /// The number of groups to use in group normalization.
+ pub groups: usize,
+ pub groups_out: Option<usize>,
+ /// The epsilon to be used in the group normalization operations.
+ pub eps: f64,
+ /// Whether to use a 2D convolution in the skip connection. When using None,
+ /// such a convolution is used if the number of input channels is different from
+ /// the number of output channels.
+ pub use_in_shortcut: Option<bool>,
+ // non_linearity: silu
+ /// The final output is scaled by dividing by this value.
+ pub output_scale_factor: f64,
+}
+
+impl Default for ResnetBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ out_channels: None,
+ temb_channels: Some(512),
+ groups: 32,
+ groups_out: None,
+ eps: 1e-6,
+ use_in_shortcut: None,
+ output_scale_factor: 1.,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct ResnetBlock2D {
+ norm1: nn::GroupNorm,
+ conv1: nn::Conv2d,
+ norm2: nn::GroupNorm,
+ conv2: nn::Conv2d,
+ time_emb_proj: Option<nn::Linear>,
+ conv_shortcut: Option<nn::Conv2d>,
+ config: ResnetBlock2DConfig,
+}
+
+impl ResnetBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ config: ResnetBlock2DConfig,
+ ) -> Result<Self> {
+ let out_channels = config.out_channels.unwrap_or(in_channels);
+ let conv_cfg = nn::Conv2dConfig {
+ stride: 1,
+ padding: 1,
+ };
+ let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
+ let conv1 = nn::conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
+ let groups_out = config.groups_out.unwrap_or(config.groups);
+ let norm2 = nn::group_norm(groups_out, out_channels, config.eps, vs.pp("norm2"))?;
+ let conv2 = nn::conv2d(out_channels, out_channels, 3, conv_cfg, vs.pp("conv2"))?;
+ let use_in_shortcut = config
+ .use_in_shortcut
+ .unwrap_or(in_channels != out_channels);
+ let conv_shortcut = if use_in_shortcut {
+ let conv_cfg = nn::Conv2dConfig {
+ stride: 1,
+ padding: 0,
+ };
+ Some(nn::conv2d(
+ in_channels,
+ out_channels,
+ 1,
+ conv_cfg,
+ vs.pp("conv_shortcut"),
+ )?)
+ } else {
+ None
+ };
+ let time_emb_proj = match config.temb_channels {
+ None => None,
+ Some(temb_channels) => Some(nn::linear(
+ temb_channels,
+ out_channels,
+ vs.pp("time_emb_proj"),
+ )?),
+ };
+ Ok(Self {
+ norm1,
+ conv1,
+ norm2,
+ conv2,
+ time_emb_proj,
+ config,
+ conv_shortcut,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
+ let shortcut_xs = match &self.conv_shortcut {
+ Some(conv_shortcut) => conv_shortcut.forward(xs)?,
+ None => xs.clone(),
+ };
+ let xs = self.norm1.forward(xs)?;
+ let xs = self.conv1.forward(&nn::ops::silu(&xs)?)?;
+ let xs = match (temb, &self.time_emb_proj) {
+ (Some(temb), Some(time_emb_proj)) => time_emb_proj
+ .forward(&nn::ops::silu(temb)?)?
+ .unsqueeze(D::Minus1)?
+ .unsqueeze(D::Minus1)?
+ .broadcast_add(&xs)?,
+ _ => xs,
+ };
+ let xs = self
+ .conv2
+ .forward(&nn::ops::silu(&self.norm2.forward(&xs)?)?)?;
+ (shortcut_xs + xs)? / self.config.output_scale_factor
+ }
+}
diff --git a/candle-examples/examples/stable-diffusion/schedulers.rs b/candle-examples/examples/stable-diffusion/schedulers.rs
new file mode 100644
index 00000000..3f6a1d72
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/schedulers.rs
@@ -0,0 +1,45 @@
+#![allow(dead_code)]
+//! # Noise schedulers
+//!
+//! Noise schedulers control the trade-off between inference speed and
+//! output quality in the diffusion process.
+
+use candle::{Result, Tensor};
+
+/// This represents how beta ranges from its minimum value to the maximum
+/// during training.
+#[derive(Debug, Clone, Copy)]
+pub enum BetaSchedule {
+ /// Linear interpolation.
+ Linear,
+ /// Linear interpolation of the square root of beta.
+ ScaledLinear,
+    /// Glide cosine schedule.
+ SquaredcosCapV2,
+}
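+
+// How these variants are typically materialized into a beta tensor (the DDIM
+// scheduler in ddim.rs is expected to follow the diffusers convention):
+//   Linear          -> linspace(beta_start, beta_end, n)
+//   ScaledLinear    -> linspace(beta_start.sqrt(), beta_end.sqrt(), n), squared element-wise
+//   SquaredcosCapV2 -> betas_for_alpha_bar(n, max_beta) defined below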
+
+#[derive(Debug, Clone, Copy)]
+pub enum PredictionType {
+ Epsilon,
+ VPrediction,
+ Sample,
+}
+
+/// Create a beta schedule that discretizes the `alpha_bar` function, which defines the cumulative
+/// product of `(1-beta)` over time for `t` in `[0, 1]`.
+///
+/// The internal `alpha_bar` closure maps a timestep to the cumulative product of `(1-beta)`
+/// up to that point of the diffusion process.
+pub(crate) fn betas_for_alpha_bar(num_diffusion_timesteps: usize, max_beta: f64) -> Result<Tensor> {
+    let alpha_bar = |time_step: f64| {
+        f64::cos((time_step + 0.008) / 1.008 * std::f64::consts::FRAC_PI_2).powi(2)
+    };
+    let mut betas = Vec::with_capacity(num_diffusion_timesteps);
+    for i in 0..num_diffusion_timesteps {
+        // Use floating point division here: with integer division t1 and t2 would both
+        // collapse to 0 for every step except the last one.
+        let t1 = i as f64 / num_diffusion_timesteps as f64;
+        let t2 = (i + 1) as f64 / num_diffusion_timesteps as f64;
+        betas.push((1.0 - alpha_bar(t2) / alpha_bar(t1)).min(max_beta));
+ }
+ let betas_len = betas.len();
+ Tensor::from_vec(betas, betas_len, &candle::Device::Cpu)
+}
diff --git a/candle-examples/examples/stable-diffusion/stable_diffusion.rs b/candle-examples/examples/stable-diffusion/stable_diffusion.rs
new file mode 100644
index 00000000..c250ed56
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/stable_diffusion.rs
@@ -0,0 +1,212 @@
+#![allow(dead_code)]
+use crate::schedulers::PredictionType;
+use crate::{clip, ddim, unet_2d, vae};
+use candle::{DType, Device, Result};
+use candle_nn as nn;
+
+#[derive(Clone, Debug)]
+pub struct StableDiffusionConfig {
+ pub width: usize,
+ pub height: usize,
+ pub clip: clip::Config,
+ autoencoder: vae::AutoEncoderKLConfig,
+ unet: unet_2d::UNet2DConditionModelConfig,
+ scheduler: ddim::DDIMSchedulerConfig,
+}
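+
+// Typical construction and use, mirroring the flow in main.rs above:
+//   let cfg = StableDiffusionConfig::v1_5(sliced_attention_size, height, width);
+//   let scheduler = cfg.build_scheduler(n_steps)?;
+//   let clip = cfg.build_clip_transformer(&clip_weights, &device)?;
+//   let vae = cfg.build_vae(&vae_weights, &device)?;
+//   let unet = cfg.build_unet(&unet_weights, &device, 4)?;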
+
+impl StableDiffusionConfig {
+ pub fn v1_5(
+ sliced_attention_size: Option<usize>,
+ height: Option<usize>,
+ width: Option<usize>,
+ ) -> Self {
+ let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
+ out_channels,
+ use_cross_attn,
+ attention_head_dim,
+ };
+ // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/unet/config.json
+ let unet = unet_2d::UNet2DConditionModelConfig {
+ blocks: vec![
+ bc(320, true, 8),
+ bc(640, true, 8),
+ bc(1280, true, 8),
+ bc(1280, false, 8),
+ ],
+ center_input_sample: false,
+ cross_attention_dim: 768,
+ downsample_padding: 1,
+ flip_sin_to_cos: true,
+ freq_shift: 0.,
+ layers_per_block: 2,
+ mid_block_scale_factor: 1.,
+ norm_eps: 1e-5,
+ norm_num_groups: 32,
+ sliced_attention_size,
+ use_linear_projection: false,
+ };
+ let autoencoder = vae::AutoEncoderKLConfig {
+ block_out_channels: vec![128, 256, 512, 512],
+ layers_per_block: 2,
+ latent_channels: 4,
+ norm_num_groups: 32,
+ };
+ let height = if let Some(height) = height {
+ assert_eq!(height % 8, 0, "heigh has to be divisible by 8");
+ height
+ } else {
+ 512
+ };
+
+ let width = if let Some(width) = width {
+ assert_eq!(width % 8, 0, "width has to be divisible by 8");
+ width
+ } else {
+ 512
+ };
+
+ Self {
+ width,
+ height,
+ clip: clip::Config::v1_5(),
+ autoencoder,
+ scheduler: Default::default(),
+ unet,
+ }
+ }
+
+ fn v2_1_(
+ sliced_attention_size: Option<usize>,
+ height: Option<usize>,
+ width: Option<usize>,
+ prediction_type: PredictionType,
+ ) -> Self {
+ let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
+ out_channels,
+ use_cross_attn,
+ attention_head_dim,
+ };
+ // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/unet/config.json
+ let unet = unet_2d::UNet2DConditionModelConfig {
+ blocks: vec![
+ bc(320, true, 5),
+ bc(640, true, 10),
+ bc(1280, true, 20),
+ bc(1280, false, 20),
+ ],
+ center_input_sample: false,
+ cross_attention_dim: 1024,
+ downsample_padding: 1,
+ flip_sin_to_cos: true,
+ freq_shift: 0.,
+ layers_per_block: 2,
+ mid_block_scale_factor: 1.,
+ norm_eps: 1e-5,
+ norm_num_groups: 32,
+ sliced_attention_size,
+ use_linear_projection: true,
+ };
+ // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/vae/config.json
+ let autoencoder = vae::AutoEncoderKLConfig {
+ block_out_channels: vec![128, 256, 512, 512],
+ layers_per_block: 2,
+ latent_channels: 4,
+ norm_num_groups: 32,
+ };
+ let scheduler = ddim::DDIMSchedulerConfig {
+ prediction_type,
+ ..Default::default()
+ };
+
+ let height = if let Some(height) = height {
+ assert_eq!(height % 8, 0, "heigh has to be divisible by 8");
+ height
+ } else {
+ 768
+ };
+
+ let width = if let Some(width) = width {
+ assert_eq!(width % 8, 0, "width has to be divisible by 8");
+ width
+ } else {
+ 768
+ };
+
+ Self {
+ width,
+ height,
+ clip: clip::Config::v2_1(),
+ autoencoder,
+ scheduler,
+ unet,
+ }
+ }
+
+ pub fn v2_1(
+ sliced_attention_size: Option<usize>,
+ height: Option<usize>,
+ width: Option<usize>,
+ ) -> Self {
+ // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/scheduler/scheduler_config.json
+ Self::v2_1_(
+ sliced_attention_size,
+ height,
+ width,
+ PredictionType::VPrediction,
+ )
+ }
+
+ pub fn v2_1_inpaint(
+ sliced_attention_size: Option<usize>,
+ height: Option<usize>,
+ width: Option<usize>,
+ ) -> Self {
+ // https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/blob/main/scheduler/scheduler_config.json
+ // This uses a PNDM scheduler rather than DDIM but the biggest difference is the prediction
+ // type being "epsilon" by default and not "v_prediction".
+ Self::v2_1_(
+ sliced_attention_size,
+ height,
+ width,
+ PredictionType::Epsilon,
+ )
+ }
+
+ pub fn build_vae(&self, vae_weights: &str, device: &Device) -> Result<vae::AutoEncoderKL> {
+ let weights = unsafe { candle::safetensors::MmapedFile::new(vae_weights)? };
+ let weights = weights.deserialize()?;
+ let vs_ae = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
+ // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
+ let autoencoder = vae::AutoEncoderKL::new(vs_ae, 3, 3, self.autoencoder.clone())?;
+ Ok(autoencoder)
+ }
+
+ pub fn build_unet(
+ &self,
+ unet_weights: &str,
+ device: &Device,
+ in_channels: usize,
+ ) -> Result<unet_2d::UNet2DConditionModel> {
+ let weights = unsafe { candle::safetensors::MmapedFile::new(unet_weights)? };
+ let weights = weights.deserialize()?;
+ let vs_unet = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
+ let unet = unet_2d::UNet2DConditionModel::new(vs_unet, in_channels, 4, self.unet.clone())?;
+ Ok(unet)
+ }
+
+ pub fn build_scheduler(&self, n_steps: usize) -> Result<ddim::DDIMScheduler> {
+ ddim::DDIMScheduler::new(n_steps, self.scheduler)
+ }
+
+ pub fn build_clip_transformer(
+ &self,
+ clip_weights: &str,
+ device: &Device,
+ ) -> Result<clip::ClipTextTransformer> {
+ let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
+ let weights = weights.deserialize()?;
+ let vs = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
+ let text_model = clip::ClipTextTransformer::new(vs, &self.clip)?;
+ Ok(text_model)
+ }
+}
diff --git a/candle-examples/examples/stable-diffusion/unet_2d.rs b/candle-examples/examples/stable-diffusion/unet_2d.rs
new file mode 100644
index 00000000..8ebd1876
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/unet_2d.rs
@@ -0,0 +1,383 @@
+#![allow(dead_code)]
+//! 2D UNet Denoising Models
+//!
+//! The 2D Unet models take as input a noisy sample and the current diffusion
+//! timestep and return a denoised version of the input.
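+//!
+//! A minimal call sketch, assuming a built `unet`, a latent input tensor and CLIP
+//! `text_embeddings` are already available (see main.rs):
+//!
+//! ```ignore
+//! let noise_pred = unet.forward(&latent_model_input, timestep as f64, &text_embeddings)?;
+//! ```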
+use crate::embeddings::{TimestepEmbedding, Timesteps};
+use crate::unet_2d_blocks::*;
+use candle::{DType, Result, Tensor};
+use candle_nn as nn;
+
+#[derive(Debug, Clone, Copy)]
+pub struct BlockConfig {
+ pub out_channels: usize,
+ pub use_cross_attn: bool,
+ pub attention_head_dim: usize,
+}
+
+#[derive(Debug, Clone)]
+pub struct UNet2DConditionModelConfig {
+ pub center_input_sample: bool,
+ pub flip_sin_to_cos: bool,
+ pub freq_shift: f64,
+ pub blocks: Vec<BlockConfig>,
+ pub layers_per_block: usize,
+ pub downsample_padding: usize,
+ pub mid_block_scale_factor: f64,
+ pub norm_num_groups: usize,
+ pub norm_eps: f64,
+ pub cross_attention_dim: usize,
+ pub sliced_attention_size: Option<usize>,
+ pub use_linear_projection: bool,
+}
+
+impl Default for UNet2DConditionModelConfig {
+ fn default() -> Self {
+ Self {
+ center_input_sample: false,
+ flip_sin_to_cos: true,
+ freq_shift: 0.,
+ blocks: vec![
+ BlockConfig {
+ out_channels: 320,
+ use_cross_attn: true,
+ attention_head_dim: 8,
+ },
+ BlockConfig {
+ out_channels: 640,
+ use_cross_attn: true,
+ attention_head_dim: 8,
+ },
+ BlockConfig {
+ out_channels: 1280,
+ use_cross_attn: true,
+ attention_head_dim: 8,
+ },
+ BlockConfig {
+ out_channels: 1280,
+ use_cross_attn: false,
+ attention_head_dim: 8,
+ },
+ ],
+ layers_per_block: 2,
+ downsample_padding: 1,
+ mid_block_scale_factor: 1.,
+ norm_num_groups: 32,
+ norm_eps: 1e-5,
+ cross_attention_dim: 1280,
+ sliced_attention_size: None,
+ use_linear_projection: false,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub(crate) enum UNetDownBlock {
+ Basic(DownBlock2D),
+ CrossAttn(CrossAttnDownBlock2D),
+}
+
+#[derive(Debug)]
+enum UNetUpBlock {
+ Basic(UpBlock2D),
+ CrossAttn(CrossAttnUpBlock2D),
+}
+
+#[derive(Debug)]
+pub struct UNet2DConditionModel {
+ conv_in: nn::Conv2d,
+ time_proj: Timesteps,
+ time_embedding: TimestepEmbedding,
+ down_blocks: Vec<UNetDownBlock>,
+ mid_block: UNetMidBlock2DCrossAttn,
+ up_blocks: Vec<UNetUpBlock>,
+ conv_norm_out: nn::GroupNorm,
+ conv_out: nn::Conv2d,
+ config: UNet2DConditionModelConfig,
+}
+
+impl UNet2DConditionModel {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ config: UNet2DConditionModelConfig,
+ ) -> Result<Self> {
+ let n_blocks = config.blocks.len();
+ let b_channels = config.blocks[0].out_channels;
+ let bl_channels = config.blocks.last().unwrap().out_channels;
+ let bl_attention_head_dim = config.blocks.last().unwrap().attention_head_dim;
+ let time_embed_dim = b_channels * 4;
+ let conv_cfg = nn::Conv2dConfig {
+ stride: 1,
+ padding: 1,
+ };
+ let conv_in = nn::conv2d(in_channels, b_channels, 3, conv_cfg, vs.pp("conv_in"))?;
+
+ let time_proj = Timesteps::new(b_channels, config.flip_sin_to_cos, config.freq_shift);
+ let time_embedding =
+ TimestepEmbedding::new(vs.pp("time_embedding"), b_channels, time_embed_dim)?;
+
+ let vs_db = vs.pp("down_blocks");
+ let down_blocks = (0..n_blocks)
+ .map(|i| {
+ let BlockConfig {
+ out_channels,
+ use_cross_attn,
+ attention_head_dim,
+ } = config.blocks[i];
+
+ // Enable automatic attention slicing if the config sliced_attention_size is set to 0.
+ let sliced_attention_size = match config.sliced_attention_size {
+ Some(0) => Some(attention_head_dim / 2),
+ _ => config.sliced_attention_size,
+ };
+
+ let in_channels = if i > 0 {
+ config.blocks[i - 1].out_channels
+ } else {
+ b_channels
+ };
+ let db_cfg = DownBlock2DConfig {
+ num_layers: config.layers_per_block,
+ resnet_eps: config.norm_eps,
+ resnet_groups: config.norm_num_groups,
+ add_downsample: i < n_blocks - 1,
+ downsample_padding: config.downsample_padding,
+ ..Default::default()
+ };
+ if use_cross_attn {
+ let config = CrossAttnDownBlock2DConfig {
+ downblock: db_cfg,
+ attn_num_head_channels: attention_head_dim,
+ cross_attention_dim: config.cross_attention_dim,
+ sliced_attention_size,
+ use_linear_projection: config.use_linear_projection,
+ };
+ let block = CrossAttnDownBlock2D::new(
+ vs_db.pp(&i.to_string()),
+ in_channels,
+ out_channels,
+ Some(time_embed_dim),
+ config,
+ )?;
+ Ok(UNetDownBlock::CrossAttn(block))
+ } else {
+ let block = DownBlock2D::new(
+ vs_db.pp(&i.to_string()),
+ in_channels,
+ out_channels,
+ Some(time_embed_dim),
+ db_cfg,
+ )?;
+ Ok(UNetDownBlock::Basic(block))
+ }
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ let mid_cfg = UNetMidBlock2DCrossAttnConfig {
+ resnet_eps: config.norm_eps,
+ output_scale_factor: config.mid_block_scale_factor,
+ cross_attn_dim: config.cross_attention_dim,
+ attn_num_head_channels: bl_attention_head_dim,
+ resnet_groups: Some(config.norm_num_groups),
+ use_linear_projection: config.use_linear_projection,
+ ..Default::default()
+ };
+ let mid_block = UNetMidBlock2DCrossAttn::new(
+ vs.pp("mid_block"),
+ bl_channels,
+ Some(time_embed_dim),
+ mid_cfg,
+ )?;
+
+ let vs_ub = vs.pp("up_blocks");
+ let up_blocks = (0..n_blocks)
+ .map(|i| {
+ let BlockConfig {
+ out_channels,
+ use_cross_attn,
+ attention_head_dim,
+ } = config.blocks[n_blocks - 1 - i];
+
+ // Enable automatic attention slicing if the config sliced_attention_size is set to 0.
+ let sliced_attention_size = match config.sliced_attention_size {
+ Some(0) => Some(attention_head_dim / 2),
+ _ => config.sliced_attention_size,
+ };
+
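+                // The up blocks are built in reverse order of the down blocks: `prev_out_channels`
+                // is the output of the previously built (deeper) block, while `in_channels` matches
+                // the corresponding down block for the skip connections.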
+ let prev_out_channels = if i > 0 {
+ config.blocks[n_blocks - i].out_channels
+ } else {
+ bl_channels
+ };
+ let in_channels = {
+ let index = if i == n_blocks - 1 {
+ 0
+ } else {
+ n_blocks - i - 2
+ };
+ config.blocks[index].out_channels
+ };
+ let ub_cfg = UpBlock2DConfig {
+ num_layers: config.layers_per_block + 1,
+ resnet_eps: config.norm_eps,
+ resnet_groups: config.norm_num_groups,
+ add_upsample: i < n_blocks - 1,
+ ..Default::default()
+ };
+ if use_cross_attn {
+ let config = CrossAttnUpBlock2DConfig {
+ upblock: ub_cfg,
+ attn_num_head_channels: attention_head_dim,
+ cross_attention_dim: config.cross_attention_dim,
+ sliced_attention_size,
+ use_linear_projection: config.use_linear_projection,
+ };
+ let block = CrossAttnUpBlock2D::new(
+ vs_ub.pp(&i.to_string()),
+ in_channels,
+ prev_out_channels,
+ out_channels,
+ Some(time_embed_dim),
+ config,
+ )?;
+ Ok(UNetUpBlock::CrossAttn(block))
+ } else {
+ let block = UpBlock2D::new(
+ vs_ub.pp(&i.to_string()),
+ in_channels,
+ prev_out_channels,
+ out_channels,
+ Some(time_embed_dim),
+ ub_cfg,
+ )?;
+ Ok(UNetUpBlock::Basic(block))
+ }
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ let conv_norm_out = nn::group_norm(
+ config.norm_num_groups,
+ b_channels,
+ config.norm_eps,
+ vs.pp("conv_norm_out"),
+ )?;
+ let conv_out = nn::conv2d(b_channels, out_channels, 3, conv_cfg, vs.pp("conv_out"))?;
+ Ok(Self {
+ conv_in,
+ time_proj,
+ time_embedding,
+ down_blocks,
+ mid_block,
+ up_blocks,
+ conv_norm_out,
+ conv_out,
+ config,
+ })
+ }
+}
+
+impl UNet2DConditionModel {
+ pub fn forward(
+ &self,
+ xs: &Tensor,
+ timestep: f64,
+ encoder_hidden_states: &Tensor,
+ ) -> Result<Tensor> {
+ self.forward_with_additional_residuals(xs, timestep, encoder_hidden_states, None, None)
+ }
+
+ pub fn forward_with_additional_residuals(
+ &self,
+ xs: &Tensor,
+ timestep: f64,
+ encoder_hidden_states: &Tensor,
+ down_block_additional_residuals: Option<&[Tensor]>,
+ mid_block_additional_residual: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let (bsize, _channels, height, width) = xs.dims4()?;
+ let device = xs.device();
+ let n_blocks = self.config.blocks.len();
+ let num_upsamplers = n_blocks - 1;
+ let default_overall_up_factor = 2usize.pow(num_upsamplers as u32);
+ let forward_upsample_size =
+ height % default_overall_up_factor != 0 || width % default_overall_up_factor != 0;
+ // 0. center input if necessary
+ let xs = if self.config.center_input_sample {
+ ((xs * 2.0)? - 1.0)?
+ } else {
+ xs.clone()
+ };
+ // 1. time
+ let emb = (Tensor::ones(bsize, DType::F32, device)? * timestep)?;
+ let emb = self.time_proj.forward(&emb)?;
+ let emb = self.time_embedding.forward(&emb)?;
+ // 2. pre-process
+ let xs = self.conv_in.forward(&xs)?;
+ // 3. down
+ let mut down_block_res_xs = vec![xs.clone()];
+ let mut xs = xs;
+ for down_block in self.down_blocks.iter() {
+ let (_xs, res_xs) = match down_block {
+ UNetDownBlock::Basic(b) => b.forward(&xs, Some(&emb))?,
+ UNetDownBlock::CrossAttn(b) => {
+ b.forward(&xs, Some(&emb), Some(encoder_hidden_states))?
+ }
+ };
+ down_block_res_xs.extend(res_xs);
+ xs = _xs;
+ }
+
+ let new_down_block_res_xs =
+ if let Some(down_block_additional_residuals) = down_block_additional_residuals {
+ let mut v = vec![];
+                // A previous version of this code had a bug: the addition was made in place
+                // via `+=`, which ended up modifying the input of the mid block.
+ for (i, residuals) in down_block_additional_residuals.iter().enumerate() {
+ v.push((&down_block_res_xs[i] + residuals)?)
+ }
+ v
+ } else {
+ down_block_res_xs
+ };
+ let mut down_block_res_xs = new_down_block_res_xs;
+
+ // 4. mid
+ let xs = self
+ .mid_block
+ .forward(&xs, Some(&emb), Some(encoder_hidden_states))?;
+ let xs = match mid_block_additional_residual {
+ None => xs,
+ Some(m) => (m + xs)?,
+ };
+ // 5. up
+ let mut xs = xs;
+ let mut upsample_size = None;
+ for (i, up_block) in self.up_blocks.iter().enumerate() {
+ let n_resnets = match up_block {
+ UNetUpBlock::Basic(b) => b.resnets.len(),
+ UNetUpBlock::CrossAttn(b) => b.upblock.resnets.len(),
+ };
+ let res_xs = down_block_res_xs.split_off(down_block_res_xs.len() - n_resnets);
+ if i < n_blocks - 1 && forward_upsample_size {
+ let (_, _, h, w) = down_block_res_xs.last().unwrap().dims4()?;
+ upsample_size = Some((h, w))
+ }
+ xs = match up_block {
+ UNetUpBlock::Basic(b) => b.forward(&xs, &res_xs, Some(&emb), upsample_size)?,
+ UNetUpBlock::CrossAttn(b) => b.forward(
+ &xs,
+ &res_xs,
+ Some(&emb),
+ upsample_size,
+ Some(encoder_hidden_states),
+ )?,
+ };
+ }
+ // 6. post-process
+ let xs = self.conv_norm_out.forward(&xs)?;
+ let xs = nn::ops::silu(&xs)?;
+ self.conv_out.forward(&xs)
+ }
+}
diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
new file mode 100644
index 00000000..82d5fad5
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
@@ -0,0 +1,808 @@
+#![allow(dead_code)]
+//! 2D UNet Building Blocks
+//!
+use crate::attention::{
+ AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
+};
+use crate::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
+use candle::{Result, Tensor, D};
+use candle_nn as nn;
+
+#[derive(Debug)]
+struct Downsample2D {
+ conv: Option<nn::Conv2d>,
+ padding: usize,
+}
+
+impl Downsample2D {
+ fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ use_conv: bool,
+ out_channels: usize,
+ padding: usize,
+ ) -> Result<Self> {
+ let conv = if use_conv {
+ let config = nn::Conv2dConfig { stride: 2, padding };
+ let conv = nn::conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
+ Some(conv)
+ } else {
+ None
+ };
+ Ok(Downsample2D { conv, padding })
+ }
+}
+
+impl Downsample2D {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ match &self.conv {
+ None => xs.avg_pool2d((2, 2), (2, 2)),
+ Some(conv) => {
+ if self.padding == 0 {
+ let xs = xs
+ .pad_with_zeros(D::Minus1, 0, 1)?
+ .pad_with_zeros(D::Minus2, 0, 1)?;
+ conv.forward(&xs)
+ } else {
+ conv.forward(xs)
+ }
+ }
+ }
+ }
+}
+
+// This does not support the conv-transpose mode.
+#[derive(Debug)]
+struct Upsample2D {
+ conv: nn::Conv2d,
+}
+
+impl Upsample2D {
+ fn new(vs: nn::VarBuilder, in_channels: usize, out_channels: usize) -> Result<Self> {
+ let config = nn::Conv2dConfig {
+ padding: 1,
+ ..Default::default()
+ };
+ let conv = nn::conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
+ Ok(Self { conv })
+ }
+}
+
+impl Upsample2D {
+ fn forward(&self, xs: &Tensor, size: Option<(usize, usize)>) -> Result<Tensor> {
+ let xs = match size {
+ None => {
+ let (_bsize, _channels, h, w) = xs.dims4()?;
+ xs.upsample_nearest2d(2 * h, 2 * w)?
+ }
+ Some((h, w)) => xs.upsample_nearest2d(h, w)?,
+ };
+ self.conv.forward(&xs)
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct DownEncoderBlock2DConfig {
+ pub num_layers: usize,
+ pub resnet_eps: f64,
+ pub resnet_groups: usize,
+ pub output_scale_factor: f64,
+ pub add_downsample: bool,
+ pub downsample_padding: usize,
+}
+
+impl Default for DownEncoderBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ num_layers: 1,
+ resnet_eps: 1e-6,
+ resnet_groups: 32,
+ output_scale_factor: 1.,
+ add_downsample: true,
+ downsample_padding: 1,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct DownEncoderBlock2D {
+ resnets: Vec<ResnetBlock2D>,
+ downsampler: Option<Downsample2D>,
+ pub config: DownEncoderBlock2DConfig,
+}
+
+impl DownEncoderBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ config: DownEncoderBlock2DConfig,
+ ) -> Result<Self> {
+ let resnets: Vec<_> = {
+ let vs = vs.pp("resnets");
+ let conv_cfg = ResnetBlock2DConfig {
+ eps: config.resnet_eps,
+ out_channels: Some(out_channels),
+ groups: config.resnet_groups,
+ output_scale_factor: config.output_scale_factor,
+ temb_channels: None,
+ ..Default::default()
+ };
+ (0..(config.num_layers))
+ .map(|i| {
+ let in_channels = if i == 0 { in_channels } else { out_channels };
+ ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg)
+ })
+ .collect::<Result<Vec<_>>>()?
+ };
+ let downsampler = if config.add_downsample {
+ let downsample = Downsample2D::new(
+ vs.pp("downsamplers").pp("0"),
+ out_channels,
+ true,
+ out_channels,
+ config.downsample_padding,
+ )?;
+ Some(downsample)
+ } else {
+ None
+ };
+ Ok(Self {
+ resnets,
+ downsampler,
+ config,
+ })
+ }
+}
+
+impl DownEncoderBlock2D {
+ pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut xs = xs.clone();
+ for resnet in self.resnets.iter() {
+ xs = resnet.forward(&xs, None)?
+ }
+ match &self.downsampler {
+ Some(downsampler) => downsampler.forward(&xs),
+ None => Ok(xs),
+ }
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct UpDecoderBlock2DConfig {
+ pub num_layers: usize,
+ pub resnet_eps: f64,
+ pub resnet_groups: usize,
+ pub output_scale_factor: f64,
+ pub add_upsample: bool,
+}
+
+impl Default for UpDecoderBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ num_layers: 1,
+ resnet_eps: 1e-6,
+ resnet_groups: 32,
+ output_scale_factor: 1.,
+ add_upsample: true,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct UpDecoderBlock2D {
+ resnets: Vec<ResnetBlock2D>,
+ upsampler: Option<Upsample2D>,
+ pub config: UpDecoderBlock2DConfig,
+}
+
+impl UpDecoderBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ config: UpDecoderBlock2DConfig,
+ ) -> Result<Self> {
+ let resnets: Vec<_> = {
+ let vs = vs.pp("resnets");
+ let conv_cfg = ResnetBlock2DConfig {
+ out_channels: Some(out_channels),
+ eps: config.resnet_eps,
+ groups: config.resnet_groups,
+ output_scale_factor: config.output_scale_factor,
+ temb_channels: None,
+ ..Default::default()
+ };
+ (0..(config.num_layers))
+ .map(|i| {
+ let in_channels = if i == 0 { in_channels } else { out_channels };
+ ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg)
+ })
+ .collect::<Result<Vec<_>>>()?
+ };
+ let upsampler = if config.add_upsample {
+ let upsample =
+ Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?;
+ Some(upsample)
+ } else {
+ None
+ };
+ Ok(Self {
+ resnets,
+ upsampler,
+ config,
+ })
+ }
+}
+
+impl UpDecoderBlock2D {
+ pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut xs = xs.clone();
+ for resnet in self.resnets.iter() {
+ xs = resnet.forward(&xs, None)?
+ }
+ match &self.upsampler {
+ Some(upsampler) => upsampler.forward(&xs, None),
+ None => Ok(xs),
+ }
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct UNetMidBlock2DConfig {
+ pub num_layers: usize,
+ pub resnet_eps: f64,
+ pub resnet_groups: Option<usize>,
+ pub attn_num_head_channels: Option<usize>,
+ // attention_type "default"
+ pub output_scale_factor: f64,
+}
+
+impl Default for UNetMidBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ num_layers: 1,
+ resnet_eps: 1e-6,
+ resnet_groups: Some(32),
+ attn_num_head_channels: Some(1),
+ output_scale_factor: 1.,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct UNetMidBlock2D {
+ resnet: ResnetBlock2D,
+ attn_resnets: Vec<(AttentionBlock, ResnetBlock2D)>,
+ pub config: UNetMidBlock2DConfig,
+}
+
+impl UNetMidBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ temb_channels: Option<usize>,
+ config: UNetMidBlock2DConfig,
+ ) -> Result<Self> {
+ let vs_resnets = vs.pp("resnets");
+ let vs_attns = vs.pp("attentions");
+ let resnet_groups = config
+ .resnet_groups
+ .unwrap_or_else(|| usize::min(in_channels / 4, 32));
+ let resnet_cfg = ResnetBlock2DConfig {
+ eps: config.resnet_eps,
+ groups: resnet_groups,
+ output_scale_factor: config.output_scale_factor,
+ temb_channels,
+ ..Default::default()
+ };
+ let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
+ let attn_cfg = AttentionBlockConfig {
+ num_head_channels: config.attn_num_head_channels,
+ num_groups: resnet_groups,
+ rescale_output_factor: config.output_scale_factor,
+ eps: config.resnet_eps,
+ };
+ let mut attn_resnets = vec![];
+ for index in 0..config.num_layers {
+ let attn = AttentionBlock::new(vs_attns.pp(&index.to_string()), in_channels, attn_cfg)?;
+ let resnet = ResnetBlock2D::new(
+ vs_resnets.pp(&(index + 1).to_string()),
+ in_channels,
+ resnet_cfg,
+ )?;
+ attn_resnets.push((attn, resnet))
+ }
+ Ok(Self {
+ resnet,
+ attn_resnets,
+ config,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
+ let mut xs = self.resnet.forward(xs, temb)?;
+ for (attn, resnet) in self.attn_resnets.iter() {
+ xs = resnet.forward(&attn.forward(&xs)?, temb)?
+ }
+ Ok(xs)
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct UNetMidBlock2DCrossAttnConfig {
+ pub num_layers: usize,
+ pub resnet_eps: f64,
+ pub resnet_groups: Option<usize>,
+ pub attn_num_head_channels: usize,
+ // attention_type "default"
+ pub output_scale_factor: f64,
+ pub cross_attn_dim: usize,
+ pub sliced_attention_size: Option<usize>,
+ pub use_linear_projection: bool,
+}
+
+impl Default for UNetMidBlock2DCrossAttnConfig {
+ fn default() -> Self {
+ Self {
+ num_layers: 1,
+ resnet_eps: 1e-6,
+ resnet_groups: Some(32),
+ attn_num_head_channels: 1,
+ output_scale_factor: 1.,
+ cross_attn_dim: 1280,
+ sliced_attention_size: None, // Sliced attention disabled
+ use_linear_projection: false,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct UNetMidBlock2DCrossAttn {
+ resnet: ResnetBlock2D,
+ attn_resnets: Vec<(SpatialTransformer, ResnetBlock2D)>,
+ pub config: UNetMidBlock2DCrossAttnConfig,
+}
+
+impl UNetMidBlock2DCrossAttn {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ temb_channels: Option<usize>,
+ config: UNetMidBlock2DCrossAttnConfig,
+ ) -> Result<Self> {
+ let vs_resnets = vs.pp("resnets");
+ let vs_attns = vs.pp("attentions");
+ let resnet_groups = config
+ .resnet_groups
+ .unwrap_or_else(|| usize::min(in_channels / 4, 32));
+ let resnet_cfg = ResnetBlock2DConfig {
+ eps: config.resnet_eps,
+ groups: resnet_groups,
+ output_scale_factor: config.output_scale_factor,
+ temb_channels,
+ ..Default::default()
+ };
+ let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
+ let n_heads = config.attn_num_head_channels;
+ let attn_cfg = SpatialTransformerConfig {
+ depth: 1,
+ num_groups: resnet_groups,
+ context_dim: Some(config.cross_attn_dim),
+ sliced_attention_size: config.sliced_attention_size,
+ use_linear_projection: config.use_linear_projection,
+ };
+ let mut attn_resnets = vec![];
+ for index in 0..config.num_layers {
+ let attn = SpatialTransformer::new(
+ vs_attns.pp(&index.to_string()),
+ in_channels,
+ n_heads,
+ in_channels / n_heads,
+ attn_cfg,
+ )?;
+ let resnet = ResnetBlock2D::new(
+ vs_resnets.pp(&(index + 1).to_string()),
+ in_channels,
+ resnet_cfg,
+ )?;
+ attn_resnets.push((attn, resnet))
+ }
+ Ok(Self {
+ resnet,
+ attn_resnets,
+ config,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ xs: &Tensor,
+ temb: Option<&Tensor>,
+ encoder_hidden_states: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let mut xs = self.resnet.forward(xs, temb)?;
+ for (attn, resnet) in self.attn_resnets.iter() {
+ xs = resnet.forward(&attn.forward(&xs, encoder_hidden_states)?, temb)?
+ }
+ Ok(xs)
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct DownBlock2DConfig {
+ pub num_layers: usize,
+ pub resnet_eps: f64,
+ // resnet_time_scale_shift: "default"
+ // resnet_act_fn: "swish"
+ pub resnet_groups: usize,
+ pub output_scale_factor: f64,
+ pub add_downsample: bool,
+ pub downsample_padding: usize,
+}
+
+impl Default for DownBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ num_layers: 1,
+ resnet_eps: 1e-6,
+ resnet_groups: 32,
+ output_scale_factor: 1.,
+ add_downsample: true,
+ downsample_padding: 1,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct DownBlock2D {
+ resnets: Vec<ResnetBlock2D>,
+ downsampler: Option<Downsample2D>,
+ pub config: DownBlock2DConfig,
+}
+
+impl DownBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ temb_channels: Option<usize>,
+ config: DownBlock2DConfig,
+ ) -> Result<Self> {
+ let vs_resnets = vs.pp("resnets");
+ let resnet_cfg = ResnetBlock2DConfig {
+ out_channels: Some(out_channels),
+ eps: config.resnet_eps,
+ output_scale_factor: config.output_scale_factor,
+ temb_channels,
+ ..Default::default()
+ };
+ let resnets = (0..config.num_layers)
+ .map(|i| {
+ let in_channels = if i == 0 { in_channels } else { out_channels };
+ ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
+ })
+ .collect::<Result<Vec<_>>>()?;
+ let downsampler = if config.add_downsample {
+ let downsampler = Downsample2D::new(
+ vs.pp("downsamplers").pp("0"),
+ out_channels,
+ true,
+ out_channels,
+ config.downsample_padding,
+ )?;
+ Some(downsampler)
+ } else {
+ None
+ };
+ Ok(Self {
+ resnets,
+ downsampler,
+ config,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<(Tensor, Vec<Tensor>)> {
+ let mut xs = xs.clone();
+ let mut output_states = vec![];
+ for resnet in self.resnets.iter() {
+ xs = resnet.forward(&xs, temb)?;
+ output_states.push(xs.clone());
+ }
+ let xs = match &self.downsampler {
+ Some(downsampler) => {
+ let xs = downsampler.forward(&xs)?;
+ output_states.push(xs.clone());
+ xs
+ }
+ None => xs,
+ };
+ Ok((xs, output_states))
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct CrossAttnDownBlock2DConfig {
+ pub downblock: DownBlock2DConfig,
+ pub attn_num_head_channels: usize,
+ pub cross_attention_dim: usize,
+ // attention_type: "default"
+ pub sliced_attention_size: Option<usize>,
+ pub use_linear_projection: bool,
+}
+
+impl Default for CrossAttnDownBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ downblock: Default::default(),
+ attn_num_head_channels: 1,
+ cross_attention_dim: 1280,
+ sliced_attention_size: None,
+ use_linear_projection: false,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct CrossAttnDownBlock2D {
+ downblock: DownBlock2D,
+ attentions: Vec<SpatialTransformer>,
+ pub config: CrossAttnDownBlock2DConfig,
+}
+
+impl CrossAttnDownBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ temb_channels: Option<usize>,
+ config: CrossAttnDownBlock2DConfig,
+ ) -> Result<Self> {
+ let downblock = DownBlock2D::new(
+ vs.clone(),
+ in_channels,
+ out_channels,
+ temb_channels,
+ config.downblock,
+ )?;
+ let n_heads = config.attn_num_head_channels;
+ let cfg = SpatialTransformerConfig {
+ depth: 1,
+ context_dim: Some(config.cross_attention_dim),
+ num_groups: config.downblock.resnet_groups,
+ sliced_attention_size: config.sliced_attention_size,
+ use_linear_projection: config.use_linear_projection,
+ };
+ let vs_attn = vs.pp("attentions");
+ let attentions = (0..config.downblock.num_layers)
+ .map(|i| {
+ SpatialTransformer::new(
+ vs_attn.pp(&i.to_string()),
+ out_channels,
+ n_heads,
+ out_channels / n_heads,
+ cfg,
+ )
+ })
+ .collect::<Result<Vec<_>>>()?;
+ Ok(Self {
+ downblock,
+ attentions,
+ config,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ xs: &Tensor,
+ temb: Option<&Tensor>,
+ encoder_hidden_states: Option<&Tensor>,
+ ) -> Result<(Tensor, Vec<Tensor>)> {
+ let mut output_states = vec![];
+ let mut xs = xs.clone();
+ for (resnet, attn) in self.downblock.resnets.iter().zip(self.attentions.iter()) {
+ xs = resnet.forward(&xs, temb)?;
+ xs = attn.forward(&xs, encoder_hidden_states)?;
+ output_states.push(xs.clone());
+ }
+ let xs = match &self.downblock.downsampler {
+ Some(downsampler) => {
+ let xs = downsampler.forward(&xs)?;
+ output_states.push(xs.clone());
+ xs
+ }
+ None => xs,
+ };
+ Ok((xs, output_states))
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct UpBlock2DConfig {
+ pub num_layers: usize,
+ pub resnet_eps: f64,
+ // resnet_time_scale_shift: "default"
+ // resnet_act_fn: "swish"
+ pub resnet_groups: usize,
+ pub output_scale_factor: f64,
+ pub add_upsample: bool,
+}
+
+impl Default for UpBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ num_layers: 1,
+ resnet_eps: 1e-6,
+ resnet_groups: 32,
+ output_scale_factor: 1.,
+ add_upsample: true,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct UpBlock2D {
+ pub resnets: Vec<ResnetBlock2D>,
+ upsampler: Option<Upsample2D>,
+ pub config: UpBlock2DConfig,
+}
+
+impl UpBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ prev_output_channels: usize,
+ out_channels: usize,
+ temb_channels: Option<usize>,
+ config: UpBlock2DConfig,
+ ) -> Result<Self> {
+ let vs_resnets = vs.pp("resnets");
+ let resnet_cfg = ResnetBlock2DConfig {
+ out_channels: Some(out_channels),
+ temb_channels,
+ eps: config.resnet_eps,
+ output_scale_factor: config.output_scale_factor,
+ ..Default::default()
+ };
+ let resnets = (0..config.num_layers)
+ .map(|i| {
+ let res_skip_channels = if i == config.num_layers - 1 {
+ in_channels
+ } else {
+ out_channels
+ };
+ let resnet_in_channels = if i == 0 {
+ prev_output_channels
+ } else {
+ out_channels
+ };
+ let in_channels = resnet_in_channels + res_skip_channels;
+ ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
+ })
+ .collect::<Result<Vec<_>>>()?;
+ let upsampler = if config.add_upsample {
+ let upsampler =
+ Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?;
+ Some(upsampler)
+ } else {
+ None
+ };
+ Ok(Self {
+ resnets,
+ upsampler,
+ config,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ xs: &Tensor,
+ res_xs: &[Tensor],
+ temb: Option<&Tensor>,
+ upsample_size: Option<(usize, usize)>,
+ ) -> Result<Tensor> {
+ let mut xs = xs.clone();
+ for (index, resnet) in self.resnets.iter().enumerate() {
+ xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
+ xs = resnet.forward(&xs, temb)?;
+ }
+ match &self.upsampler {
+ Some(upsampler) => upsampler.forward(&xs, upsample_size),
+ None => Ok(xs),
+ }
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct CrossAttnUpBlock2DConfig {
+ pub upblock: UpBlock2DConfig,
+ pub attn_num_head_channels: usize,
+ pub cross_attention_dim: usize,
+ // attention_type: "default"
+ pub sliced_attention_size: Option<usize>,
+ pub use_linear_projection: bool,
+}
+
+impl Default for CrossAttnUpBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ upblock: Default::default(),
+ attn_num_head_channels: 1,
+ cross_attention_dim: 1280,
+ sliced_attention_size: None,
+ use_linear_projection: false,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct CrossAttnUpBlock2D {
+ pub upblock: UpBlock2D,
+ pub attentions: Vec<SpatialTransformer>,
+ pub config: CrossAttnUpBlock2DConfig,
+}
+
+impl CrossAttnUpBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ prev_output_channels: usize,
+ out_channels: usize,
+ temb_channels: Option<usize>,
+ config: CrossAttnUpBlock2DConfig,
+ ) -> Result<Self> {
+ let upblock = UpBlock2D::new(
+ vs.clone(),
+ in_channels,
+ prev_output_channels,
+ out_channels,
+ temb_channels,
+ config.upblock,
+ )?;
+ let n_heads = config.attn_num_head_channels;
+ let cfg = SpatialTransformerConfig {
+ depth: 1,
+ context_dim: Some(config.cross_attention_dim),
+ num_groups: config.upblock.resnet_groups,
+ sliced_attention_size: config.sliced_attention_size,
+ use_linear_projection: config.use_linear_projection,
+ };
+ let vs_attn = vs.pp("attentions");
+ let attentions = (0..config.upblock.num_layers)
+ .map(|i| {
+ SpatialTransformer::new(
+ vs_attn.pp(&i.to_string()),
+ out_channels,
+ n_heads,
+ out_channels / n_heads,
+ cfg,
+ )
+ })
+ .collect::<Result<Vec<_>>>()?;
+ Ok(Self {
+ upblock,
+ attentions,
+ config,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ xs: &Tensor,
+ res_xs: &[Tensor],
+ temb: Option<&Tensor>,
+ upsample_size: Option<(usize, usize)>,
+ encoder_hidden_states: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let mut xs = xs.clone();
+ for (index, resnet) in self.upblock.resnets.iter().enumerate() {
+ xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
+ xs = resnet.forward(&xs, temb)?;
+ xs = self.attentions[index].forward(&xs, encoder_hidden_states)?;
+ }
+ match &self.upblock.upsampler {
+ Some(upsampler) => upsampler.forward(&xs, upsample_size),
+ None => Ok(xs),
+ }
+ }
+}
diff --git a/candle-examples/examples/stable-diffusion/utils.rs b/candle-examples/examples/stable-diffusion/utils.rs
new file mode 100644
index 00000000..ef4dd956
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/utils.rs
@@ -0,0 +1,31 @@
+use candle::{Device, Result, Tensor};
+
+pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
+    if steps <= 1 {
+ candle::bail!("cannot use linspace with steps {steps} <= 1")
+ }
+ let delta = (stop - start) / (steps - 1) as f64;
+ let vs = (0..steps)
+ .map(|step| start + step as f64 * delta)
+ .collect::<Vec<_>>();
+ Tensor::from_vec(vs, steps, &Device::Cpu)
+}
+
+/// Saves an image to disk using the image crate; the input is expected to have
+/// shape `(c, width, height)`.
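+/// The tensor is expected to already hold `u8` values in the `0..=255` range.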
+pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
+ let p = p.as_ref();
+ let (channel, width, height) = img.dims3()?;
+ if channel != 3 {
+ candle::bail!("save_image expects an input of shape (3, width, height)")
+ }
+ let img = img.transpose(0, 1)?.t()?.flatten_all()?;
+ let pixels = img.to_vec1::<u8>()?;
+ let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
+ match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
+ Some(image) => image,
+ None => candle::bail!("error saving image {p:?}"),
+ };
+ image.save(p).map_err(candle::Error::wrap)?;
+ Ok(())
+}
diff --git a/candle-examples/examples/stable-diffusion/vae.rs b/candle-examples/examples/stable-diffusion/vae.rs
new file mode 100644
index 00000000..7a10d932
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion/vae.rs
@@ -0,0 +1,378 @@
+#![allow(dead_code)]
+//! # Variational Auto-Encoder (VAE) Models.
+//!
+//! Auto-encoder models compress their input to a (usually smaller) latent space
+//! before expanding it back to its original shape; the latent values thus form a
+//! compressed representation of the original input.
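+//!
+//! A minimal round-trip sketch, assuming an `AutoEncoderKL` named `vae` and an image
+//! tensor `img` of shape `(batch, 3, height, width)` are already available:
+//!
+//! ```ignore
+//! let latents = vae.encode(&img)?.sample()?;
+//! let reconstructed = vae.decode(&latents)?;
+//! ```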
+use crate::unet_2d_blocks::{
+ DownEncoderBlock2D, DownEncoderBlock2DConfig, UNetMidBlock2D, UNetMidBlock2DConfig,
+ UpDecoderBlock2D, UpDecoderBlock2DConfig,
+};
+use candle::{Result, Tensor};
+use candle_nn as nn;
+
+#[derive(Debug, Clone)]
+struct EncoderConfig {
+ // down_block_types: DownEncoderBlock2D
+ block_out_channels: Vec<usize>,
+ layers_per_block: usize,
+ norm_num_groups: usize,
+ double_z: bool,
+}
+
+impl Default for EncoderConfig {
+ fn default() -> Self {
+ Self {
+ block_out_channels: vec![64],
+ layers_per_block: 2,
+ norm_num_groups: 32,
+ double_z: true,
+ }
+ }
+}
+
+#[derive(Debug)]
+struct Encoder {
+ conv_in: nn::Conv2d,
+ down_blocks: Vec<DownEncoderBlock2D>,
+ mid_block: UNetMidBlock2D,
+ conv_norm_out: nn::GroupNorm,
+ conv_out: nn::Conv2d,
+ #[allow(dead_code)]
+ config: EncoderConfig,
+}
+
+impl Encoder {
+ fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ config: EncoderConfig,
+ ) -> Result<Self> {
+ let conv_cfg = nn::Conv2dConfig {
+ stride: 1,
+ padding: 1,
+ };
+ let conv_in = nn::conv2d(
+ in_channels,
+ config.block_out_channels[0],
+ 3,
+ conv_cfg,
+ vs.pp("conv_in"),
+ )?;
+ let mut down_blocks = vec![];
+ let vs_down_blocks = vs.pp("down_blocks");
+ for index in 0..config.block_out_channels.len() {
+ let out_channels = config.block_out_channels[index];
+ let in_channels = if index > 0 {
+ config.block_out_channels[index - 1]
+ } else {
+ config.block_out_channels[0]
+ };
+ let is_final = index + 1 == config.block_out_channels.len();
+ let cfg = DownEncoderBlock2DConfig {
+ num_layers: config.layers_per_block,
+ resnet_eps: 1e-6,
+ resnet_groups: config.norm_num_groups,
+ add_downsample: !is_final,
+ downsample_padding: 0,
+ ..Default::default()
+ };
+ let down_block = DownEncoderBlock2D::new(
+ vs_down_blocks.pp(&index.to_string()),
+ in_channels,
+ out_channels,
+ cfg,
+ )?;
+ down_blocks.push(down_block)
+ }
+ let last_block_out_channels = *config.block_out_channels.last().unwrap();
+ let mid_cfg = UNetMidBlock2DConfig {
+ resnet_eps: 1e-6,
+ output_scale_factor: 1.,
+ attn_num_head_channels: None,
+ resnet_groups: Some(config.norm_num_groups),
+ ..Default::default()
+ };
+ let mid_block =
+ UNetMidBlock2D::new(vs.pp("mid_block"), last_block_out_channels, None, mid_cfg)?;
+ let conv_norm_out = nn::group_norm(
+ config.norm_num_groups,
+ last_block_out_channels,
+ 1e-6,
+ vs.pp("conv_norm_out"),
+ )?;
+ let conv_out_channels = if config.double_z {
+ 2 * out_channels
+ } else {
+ out_channels
+ };
+ let conv_cfg = nn::Conv2dConfig {
+ padding: 1,
+ ..Default::default()
+ };
+ let conv_out = nn::conv2d(
+ last_block_out_channels,
+ conv_out_channels,
+ 3,
+ conv_cfg,
+ vs.pp("conv_out"),
+ )?;
+ Ok(Self {
+ conv_in,
+ down_blocks,
+ mid_block,
+ conv_norm_out,
+ conv_out,
+ config,
+ })
+ }
+}
+
+impl Encoder {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut xs = self.conv_in.forward(xs)?;
+ for down_block in self.down_blocks.iter() {
+ xs = down_block.forward(&xs)?
+ }
+ let xs = self.mid_block.forward(&xs, None)?;
+ let xs = self.conv_norm_out.forward(&xs)?;
+ let xs = nn::ops::silu(&xs)?;
+ self.conv_out.forward(&xs)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct DecoderConfig {
+ // up_block_types: UpDecoderBlock2D
+ block_out_channels: Vec<usize>,
+ layers_per_block: usize,
+ norm_num_groups: usize,
+}
+
+impl Default for DecoderConfig {
+ fn default() -> Self {
+ Self {
+ block_out_channels: vec![64],
+ layers_per_block: 2,
+ norm_num_groups: 32,
+ }
+ }
+}
+
+#[derive(Debug)]
+struct Decoder {
+ conv_in: nn::Conv2d,
+ up_blocks: Vec<UpDecoderBlock2D>,
+ mid_block: UNetMidBlock2D,
+ conv_norm_out: nn::GroupNorm,
+ conv_out: nn::Conv2d,
+ #[allow(dead_code)]
+ config: DecoderConfig,
+}
+
+impl Decoder {
+ fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ config: DecoderConfig,
+ ) -> Result<Self> {
+ let n_block_out_channels = config.block_out_channels.len();
+ let last_block_out_channels = *config.block_out_channels.last().unwrap();
+ let conv_cfg = nn::Conv2dConfig {
+ stride: 1,
+ padding: 1,
+ };
+ let conv_in = nn::conv2d(
+ in_channels,
+ last_block_out_channels,
+ 3,
+ conv_cfg,
+ vs.pp("conv_in"),
+ )?;
+ let mid_cfg = UNetMidBlock2DConfig {
+ resnet_eps: 1e-6,
+ output_scale_factor: 1.,
+ attn_num_head_channels: None,
+ resnet_groups: Some(config.norm_num_groups),
+ ..Default::default()
+ };
+ let mid_block =
+ UNetMidBlock2D::new(vs.pp("mid_block"), last_block_out_channels, None, mid_cfg)?;
+ let mut up_blocks = vec![];
+ let vs_up_blocks = vs.pp("up_blocks");
+ let reversed_block_out_channels: Vec<_> =
+ config.block_out_channels.iter().copied().rev().collect();
+ for index in 0..n_block_out_channels {
+ let out_channels = reversed_block_out_channels[index];
+ let in_channels = if index > 0 {
+ reversed_block_out_channels[index - 1]
+ } else {
+ reversed_block_out_channels[0]
+ };
+ let is_final = index + 1 == n_block_out_channels;
+ let cfg = UpDecoderBlock2DConfig {
+ num_layers: config.layers_per_block + 1,
+ resnet_eps: 1e-6,
+ resnet_groups: config.norm_num_groups,
+ add_upsample: !is_final,
+ ..Default::default()
+ };
+ let up_block = UpDecoderBlock2D::new(
+ vs_up_blocks.pp(&index.to_string()),
+ in_channels,
+ out_channels,
+ cfg,
+ )?;
+ up_blocks.push(up_block)
+ }
+ let conv_norm_out = nn::group_norm(
+ config.norm_num_groups,
+ config.block_out_channels[0],
+ 1e-6,
+ vs.pp("conv_norm_out"),
+ )?;
+ let conv_cfg = nn::Conv2dConfig {
+ padding: 1,
+ ..Default::default()
+ };
+ let conv_out = nn::conv2d(
+ config.block_out_channels[0],
+ out_channels,
+ 3,
+ conv_cfg,
+ vs.pp("conv_out"),
+ )?;
+ Ok(Self {
+ conv_in,
+ up_blocks,
+ mid_block,
+ conv_norm_out,
+ conv_out,
+ config,
+ })
+ }
+}
+
+impl Decoder {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut xs = self.mid_block.forward(&self.conv_in.forward(xs)?, None)?;
+ for up_block in self.up_blocks.iter() {
+ xs = up_block.forward(&xs)?
+ }
+ let xs = self.conv_norm_out.forward(&xs)?;
+ let xs = nn::ops::silu(&xs)?;
+ self.conv_out.forward(&xs)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct AutoEncoderKLConfig {
+ pub block_out_channels: Vec<usize>,
+ pub layers_per_block: usize,
+ pub latent_channels: usize,
+ pub norm_num_groups: usize,
+}
+
+impl Default for AutoEncoderKLConfig {
+ fn default() -> Self {
+ Self {
+ block_out_channels: vec![64],
+ layers_per_block: 1,
+ latent_channels: 4,
+ norm_num_groups: 32,
+ }
+ }
+}
+
+pub struct DiagonalGaussianDistribution {
+ mean: Tensor,
+ std: Tensor,
+}
+
+impl DiagonalGaussianDistribution {
+ pub fn new(parameters: &Tensor) -> Result<Self> {
+ let mut parameters = parameters.chunk(2, 1)?.into_iter();
+ let mean = parameters.next().unwrap();
+ let logvar = parameters.next().unwrap();
+ let std = (logvar * 0.5)?.exp()?;
+ Ok(DiagonalGaussianDistribution { mean, std })
+ }
+
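+    /// Draws a sample via the reparameterization trick: `mean + std * eps` with `eps ~ N(0, 1)`.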
+ pub fn sample(&self) -> Result<Tensor> {
+ let sample = Tensor::randn(0., 1f32, self.mean.shape(), self.mean.device());
+ &self.mean + &self.std * sample
+ }
+}
+
+// https://github.com/huggingface/diffusers/blob/970e30606c2944e3286f56e8eb6d3dc6d1eb85f7/src/diffusers/models/vae.py#L485
+// This implementation is specific to the config used in stable-diffusion-v1-5
+// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
+#[derive(Debug)]
+pub struct AutoEncoderKL {
+ encoder: Encoder,
+ decoder: Decoder,
+ quant_conv: nn::Conv2d,
+ post_quant_conv: nn::Conv2d,
+ pub config: AutoEncoderKLConfig,
+}
+
+impl AutoEncoderKL {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ config: AutoEncoderKLConfig,
+ ) -> Result<Self> {
+ let latent_channels = config.latent_channels;
+ let encoder_cfg = EncoderConfig {
+ block_out_channels: config.block_out_channels.clone(),
+ layers_per_block: config.layers_per_block,
+ norm_num_groups: config.norm_num_groups,
+ double_z: true,
+ };
+ let encoder = Encoder::new(vs.pp("encoder"), in_channels, latent_channels, encoder_cfg)?;
+ let decoder_cfg = DecoderConfig {
+ block_out_channels: config.block_out_channels.clone(),
+ layers_per_block: config.layers_per_block,
+ norm_num_groups: config.norm_num_groups,
+ };
+ let decoder = Decoder::new(vs.pp("decoder"), latent_channels, out_channels, decoder_cfg)?;
+ let conv_cfg = Default::default();
+ let quant_conv = nn::conv2d(
+ 2 * latent_channels,
+ 2 * latent_channels,
+ 1,
+ conv_cfg,
+ vs.pp("quant_conv"),
+ )?;
+ let post_quant_conv = nn::conv2d(
+ latent_channels,
+ latent_channels,
+ 1,
+ conv_cfg,
+ vs.pp("post_quant_conv"),
+ )?;
+ Ok(Self {
+ encoder,
+ decoder,
+ quant_conv,
+ post_quant_conv,
+ config,
+ })
+ }
+
+ /// Returns the distribution in the latent space.
+ pub fn encode(&self, xs: &Tensor) -> Result<DiagonalGaussianDistribution> {
+ let xs = self.encoder.forward(xs)?;
+ let parameters = self.quant_conv.forward(&xs)?;
+ DiagonalGaussianDistribution::new(&parameters)
+ }
+
+ /// Takes as input some sampled values.
+ pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = self.post_quant_conv.forward(xs)?;
+ self.decoder.forward(&xs)
+ }
+}