Diffstat (limited to 'candle-transformers/src')
-rw-r--r--  candle-transformers/src/generation/mod.rs  79
-rw-r--r--  candle-transformers/src/lib.rs  1
-rw-r--r--  candle-transformers/src/models/bert.rs  568
-rw-r--r--  candle-transformers/src/models/bigcode.rs  359
-rw-r--r--  candle-transformers/src/models/dinov2.rs  279
-rw-r--r--  candle-transformers/src/models/efficientnet.rs  331
-rw-r--r--  candle-transformers/src/models/falcon.rs  484
-rw-r--r--  candle-transformers/src/models/llama.rs  446
-rw-r--r--  candle-transformers/src/models/mod.rs  14
-rw-r--r--  candle-transformers/src/models/quantized_llama.rs  371
-rw-r--r--  candle-transformers/src/models/quantized_t5.rs  884
-rw-r--r--  candle-transformers/src/models/segment_anything/image_encoder.rs  483
-rw-r--r--  candle-transformers/src/models/segment_anything/mask_decoder.rs  239
-rw-r--r--  candle-transformers/src/models/segment_anything/mod.rs  100
-rw-r--r--  candle-transformers/src/models/segment_anything/prompt_encoder.rs  239
-rw-r--r--  candle-transformers/src/models/segment_anything/sam.rs  433
-rw-r--r--  candle-transformers/src/models/segment_anything/tiny_vit.rs  633
-rw-r--r--  candle-transformers/src/models/segment_anything/transformer.rs  221
-rw-r--r--  candle-transformers/src/models/stable_diffusion/attention.rs  547
-rw-r--r--  candle-transformers/src/models/stable_diffusion/clip.rs  389
-rw-r--r--  candle-transformers/src/models/stable_diffusion/ddim.rs  180
-rw-r--r--  candle-transformers/src/models/stable_diffusion/ddpm.rs  205
-rw-r--r--  candle-transformers/src/models/stable_diffusion/embeddings.rs  65
-rw-r--r--  candle-transformers/src/models/stable_diffusion/mod.rs  303
-rw-r--r--  candle-transformers/src/models/stable_diffusion/resnet.rs  138
-rw-r--r--  candle-transformers/src/models/stable_diffusion/schedulers.rs  45
-rw-r--r--  candle-transformers/src/models/stable_diffusion/unet_2d.rs  401
-rw-r--r--  candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs  868
-rw-r--r--  candle-transformers/src/models/stable_diffusion/utils.rs  39
-rw-r--r--  candle-transformers/src/models/stable_diffusion/vae.rs  380
-rw-r--r--  candle-transformers/src/models/t5.rs  841
-rw-r--r--  candle-transformers/src/models/whisper/audio.rs  210
-rw-r--r--  candle-transformers/src/models/whisper/mod.rs  26
-rw-r--r--  candle-transformers/src/models/whisper/model.rs  416
-rw-r--r--  candle-transformers/src/models/wuerstchen/attention_processor.rs  118
-rw-r--r--  candle-transformers/src/models/wuerstchen/common.rs  203
-rw-r--r--  candle-transformers/src/models/wuerstchen/ddpm.rs  103
-rw-r--r--  candle-transformers/src/models/wuerstchen/diffnext.rs  396
-rw-r--r--  candle-transformers/src/models/wuerstchen/mod.rs  6
-rw-r--r--  candle-transformers/src/models/wuerstchen/paella_vq.rs  211
-rw-r--r--  candle-transformers/src/models/wuerstchen/prior.rs  103
-rw-r--r--  candle-transformers/src/object_detection.rs  52
42 files changed, 12392 insertions, 17 deletions
diff --git a/candle-transformers/src/generation/mod.rs b/candle-transformers/src/generation/mod.rs
index b1d20168..b1a567c3 100644
--- a/candle-transformers/src/generation/mod.rs
+++ b/candle-transformers/src/generation/mod.rs
@@ -1,35 +1,82 @@
-use candle::{DType, Error, Result, Tensor, D};
+use candle::{DType, Error, Result, Tensor};
use rand::{distributions::Distribution, SeedableRng};
pub struct LogitsProcessor {
rng: rand::rngs::StdRng,
temperature: Option<f64>,
+ top_p: Option<f64>,
}
impl LogitsProcessor {
- pub fn new(seed: u64, temperature: Option<f64>) -> Self {
+ pub fn new(seed: u64, temperature: Option<f64>, top_p: Option<f64>) -> Self {
+ let temperature = if temperature.map_or(true, |v| v < 1e-7) {
+ None
+ } else {
+ temperature
+ };
Self {
rng: rand::rngs::StdRng::seed_from_u64(seed),
temperature,
+ top_p,
+ }
+ }
+
+ fn sample_argmax(&mut self, logits: Tensor) -> Result<u32> {
+ let logits_v: Vec<f32> = logits.to_vec1()?;
+ let next_token = logits_v
+ .iter()
+ .enumerate()
+ .max_by(|(_, u), (_, v)| u.total_cmp(v))
+ .map(|(i, _)| i as u32)
+ .unwrap();
+ Ok(next_token)
+ }
+
+ fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {
+ let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
+ let next_token = distr.sample(&mut self.rng) as u32;
+ Ok(next_token)
+ }
+
+ fn sample_topp(&mut self, prs: &mut Vec<f32>, top_p: f32) -> Result<u32> {
+ // top-p sampling (or "nucleus sampling") samples from the smallest set of
+ // tokens that exceed probability top_p. This way we never sample tokens that
+ // have very low probabilities and are less likely to go "off the rails".
+ let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
+
+ // Sort by descending probability.
+ argsort_indices.sort_by(|&i, &j| prs[j].partial_cmp(&prs[i]).unwrap());
+
+ // Clamp smaller probabilities to zero.
+ let mut cumsum = 0.;
+ for index in &argsort_indices {
+ if cumsum >= top_p {
+ prs[*index] = 0.0;
+ } else {
+ cumsum += prs[*index];
+ }
}
+ // Sample with clamped probabilities.
+ self.sample_multinomial(prs)
}
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
let logits = logits.to_dtype(DType::F32)?;
- let temperature = self.temperature.unwrap_or(0.);
- let next_token = if temperature > 0. {
- let prs = candle_nn::ops::softmax(&(&logits / temperature)?, D::Minus1)?;
- let prs: Vec<f32> = prs.to_vec1()?;
- let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
- distr.sample(&mut self.rng) as u32
- } else {
- let logits_v: Vec<f32> = logits.to_vec1()?;
- logits_v
- .iter()
- .enumerate()
- .max_by(|(_, u), (_, v)| u.total_cmp(v))
- .map(|(i, _)| i as u32)
- .unwrap()
+ let next_token = match self.temperature {
+ None => self.sample_argmax(logits)?,
+ Some(temperature) => {
+ let logits = &(&logits / temperature)?;
+ let prs = candle_nn::ops::softmax_last_dim(logits)?;
+ let mut prs: Vec<f32> = prs.to_vec1()?;
+ let top_p = self.top_p.unwrap_or(1.);
+ if top_p <= 0.0 || top_p >= 1.0 {
+ // simply sample from the predicted probability distribution
+ self.sample_multinomial(&prs)?
+ } else {
+ // top-p (nucleus) sampling, clamping the least likely tokens to zero
+ self.sample_topp(&mut prs, top_p as f32)?
+ }
+ }
};
Ok(next_token)
}
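
Usage note (not part of the diff): the constructor now takes a third top_p argument, so existing callers of LogitsProcessor::new need updating. Below is a minimal sketch of the new call, assuming the in-repo crate names candle and candle_transformers; the function name sample_next_token is only illustrative.

use candle::{Device, Result, Tensor};
use candle_transformers::generation::LogitsProcessor;

fn sample_next_token() -> Result<u32> {
    // temperature = 0.8 with nucleus sampling at top_p = 0.95;
    // passing None for the temperature falls back to greedy argmax decoding.
    let mut processor = LogitsProcessor::new(42, Some(0.8), Some(0.95));
    let logits = Tensor::new(&[0.1f32, 2.0, -1.0, 0.5], &Device::Cpu)?;
    processor.sample(&logits)
}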
diff --git a/candle-transformers/src/lib.rs b/candle-transformers/src/lib.rs
index a8890dc8..b83e5056 100644
--- a/candle-transformers/src/lib.rs
+++ b/candle-transformers/src/lib.rs
@@ -1,4 +1,5 @@
pub mod generation;
pub mod models;
+pub mod object_detection;
pub mod pipelines;
pub mod utils;
diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs
new file mode 100644
index 00000000..3f164a3a
--- /dev/null
+++ b/candle-transformers/src/models/bert.rs
@@ -0,0 +1,568 @@
+use candle::{DType, Device, Result, Tensor};
+use candle_nn::{Embedding, Module, VarBuilder};
+use serde::Deserialize;
+
+pub const DTYPE: DType = DType::F32;
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
+#[serde(rename_all = "lowercase")]
+enum HiddenAct {
+ Gelu,
+ Relu,
+}
+
+struct HiddenActLayer {
+ act: HiddenAct,
+ span: tracing::Span,
+}
+
+impl HiddenActLayer {
+ fn new(act: HiddenAct) -> Self {
+ let span = tracing::span!(tracing::Level::TRACE, "hidden-act");
+ Self { act, span }
+ }
+
+ fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
+ let _enter = self.span.enter();
+ match self.act {
+ // TODO: The all-MiniLM-L6-v2 model uses "gelu" whereas this is "gelu_new"; this explains
+ // some small numerical differences.
+ // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
+ HiddenAct::Gelu => xs.gelu(),
+ HiddenAct::Relu => xs.relu(),
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct Linear {
+ weight: Tensor,
+ bias: Option<Tensor>,
+ span: tracing::Span,
+}
+
+impl Linear {
+ pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
+ let span = tracing::span!(tracing::Level::TRACE, "linear");
+ Self { weight, bias, span }
+ }
+
+ pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
+ let _enter = self.span.enter();
+ let w = match x.dims() {
+ &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
+ _ => self.weight.t()?,
+ };
+ let x = x.matmul(&w)?;
+ match &self.bias {
+ None => Ok(x),
+ Some(bias) => x.broadcast_add(bias),
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct LayerNorm {
+ weight: Tensor,
+ bias: Tensor,
+ eps: f64,
+ span: tracing::Span,
+}
+
+impl LayerNorm {
+ pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
+ let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
+ Self {
+ weight,
+ bias,
+ eps,
+ span,
+ }
+ }
+
+ pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let x_dtype = x.dtype();
+ let internal_dtype = match x_dtype {
+ DType::F16 | DType::BF16 => DType::F32,
+ d => d,
+ };
+ let (_bsize, _seq_len, hidden_size) = x.dims3()?;
+ let x = x.to_dtype(internal_dtype)?;
+ let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
+ let x = x.broadcast_sub(&mean_x)?;
+ let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
+ let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
+ let x = x_normed
+ .to_dtype(x_dtype)?
+ .broadcast_mul(&self.weight)?
+ .broadcast_add(&self.bias)?;
+ Ok(x)
+ }
+}
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
+#[serde(rename_all = "lowercase")]
+enum PositionEmbeddingType {
+ #[default]
+ Absolute,
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+pub struct Config {
+ vocab_size: usize,
+ hidden_size: usize,
+ num_hidden_layers: usize,
+ num_attention_heads: usize,
+ intermediate_size: usize,
+ hidden_act: HiddenAct,
+ hidden_dropout_prob: f64,
+ max_position_embeddings: usize,
+ type_vocab_size: usize,
+ initializer_range: f64,
+ layer_norm_eps: f64,
+ pad_token_id: usize,
+ #[serde(default)]
+ position_embedding_type: PositionEmbeddingType,
+ #[serde(default)]
+ use_cache: bool,
+ classifier_dropout: Option<f64>,
+ model_type: Option<String>,
+}
+
+impl Default for Config {
+ fn default() -> Self {
+ Self {
+ vocab_size: 30522,
+ hidden_size: 768,
+ num_hidden_layers: 12,
+ num_attention_heads: 12,
+ intermediate_size: 3072,
+ hidden_act: HiddenAct::Gelu,
+ hidden_dropout_prob: 0.1,
+ max_position_embeddings: 512,
+ type_vocab_size: 2,
+ initializer_range: 0.02,
+ layer_norm_eps: 1e-12,
+ pad_token_id: 0,
+ position_embedding_type: PositionEmbeddingType::Absolute,
+ use_cache: true,
+ classifier_dropout: None,
+ model_type: Some("bert".to_string()),
+ }
+ }
+}
+
+impl Config {
+ fn _all_mini_lm_l6_v2() -> Self {
+ // https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/blob/main/config.json
+ Self {
+ vocab_size: 30522,
+ hidden_size: 384,
+ num_hidden_layers: 6,
+ num_attention_heads: 12,
+ intermediate_size: 1536,
+ hidden_act: HiddenAct::Gelu,
+ hidden_dropout_prob: 0.1,
+ max_position_embeddings: 512,
+ type_vocab_size: 2,
+ initializer_range: 0.02,
+ layer_norm_eps: 1e-12,
+ pad_token_id: 0,
+ position_embedding_type: PositionEmbeddingType::Absolute,
+ use_cache: true,
+ classifier_dropout: None,
+ model_type: Some("bert".to_string()),
+ }
+ }
+}
+
+fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
+ let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
+ Ok(Embedding::new(embeddings, hidden_size))
+}
+
+fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
+ let weight = vb.get((size2, size1), "weight")?;
+ let bias = vb.get(size2, "bias")?;
+ Ok(Linear::new(weight, Some(bias)))
+}
+
+struct Dropout {
+ #[allow(dead_code)]
+ pr: f64,
+}
+
+impl Dropout {
+ fn new(pr: f64) -> Self {
+ Self { pr }
+ }
+
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ // TODO
+ Ok(x.clone())
+ }
+}
+
+fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
+ let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
+ (Ok(weight), Ok(bias)) => (weight, bias),
+ (Err(err), _) | (_, Err(err)) => {
+ if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
+ (weight, bias)
+ } else {
+ return Err(err);
+ }
+ }
+ };
+ Ok(LayerNorm::new(weight, bias, eps))
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180
+struct BertEmbeddings {
+ word_embeddings: Embedding,
+ position_embeddings: Option<Embedding>,
+ token_type_embeddings: Embedding,
+ layer_norm: LayerNorm,
+ dropout: Dropout,
+ span: tracing::Span,
+}
+
+impl BertEmbeddings {
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let word_embeddings = embedding(
+ config.vocab_size,
+ config.hidden_size,
+ vb.pp("word_embeddings"),
+ )?;
+ let position_embeddings = embedding(
+ config.max_position_embeddings,
+ config.hidden_size,
+ vb.pp("position_embeddings"),
+ )?;
+ let token_type_embeddings = embedding(
+ config.type_vocab_size,
+ config.hidden_size,
+ vb.pp("token_type_embeddings"),
+ )?;
+ let layer_norm = layer_norm(
+ config.hidden_size,
+ config.layer_norm_eps,
+ vb.pp("LayerNorm"),
+ )?;
+ Ok(Self {
+ word_embeddings,
+ position_embeddings: Some(position_embeddings),
+ token_type_embeddings,
+ layer_norm,
+ dropout: Dropout::new(config.hidden_dropout_prob),
+ span: tracing::span!(tracing::Level::TRACE, "embeddings"),
+ })
+ }
+
+ fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let (_bsize, seq_len) = input_ids.dims2()?;
+ let input_embeddings = self.word_embeddings.forward(input_ids)?;
+ let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
+ let mut embeddings = (&input_embeddings + token_type_embeddings)?;
+ if let Some(position_embeddings) = &self.position_embeddings {
+ // TODO: Proper absolute positions?
+ let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
+ let position_ids = Tensor::new(&position_ids[..], input_ids.device())?;
+ embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?
+ }
+ let embeddings = self.layer_norm.forward(&embeddings)?;
+ let embeddings = self.dropout.forward(&embeddings)?;
+ Ok(embeddings)
+ }
+}
+
+struct BertSelfAttention {
+ query: Linear,
+ key: Linear,
+ value: Linear,
+ dropout: Dropout,
+ num_attention_heads: usize,
+ attention_head_size: usize,
+ span: tracing::Span,
+ span_softmax: tracing::Span,
+}
+
+impl BertSelfAttention {
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let attention_head_size = config.hidden_size / config.num_attention_heads;
+ let all_head_size = config.num_attention_heads * attention_head_size;
+ let dropout = Dropout::new(config.hidden_dropout_prob);
+ let hidden_size = config.hidden_size;
+ let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
+ let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
+ let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
+ Ok(Self {
+ query,
+ key,
+ value,
+ dropout,
+ num_attention_heads: config.num_attention_heads,
+ attention_head_size,
+ span: tracing::span!(tracing::Level::TRACE, "self-attn"),
+ span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
+ })
+ }
+
+ fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut new_x_shape = xs.dims().to_vec();
+ new_x_shape.pop();
+ new_x_shape.push(self.num_attention_heads);
+ new_x_shape.push(self.attention_head_size);
+ let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
+ xs.contiguous()
+ }
+
+ fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let query_layer = self.query.forward(hidden_states)?;
+ let key_layer = self.key.forward(hidden_states)?;
+ let value_layer = self.value.forward(hidden_states)?;
+
+ let query_layer = self.transpose_for_scores(&query_layer)?;
+ let key_layer = self.transpose_for_scores(&key_layer)?;
+ let value_layer = self.transpose_for_scores(&value_layer)?;
+
+ let attention_scores = query_layer.matmul(&key_layer.t()?)?;
+ let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
+ let attention_probs = {
+ let _enter_sm = self.span_softmax.enter();
+ candle_nn::ops::softmax(&attention_scores, candle::D::Minus1)?
+ };
+ let attention_probs = self.dropout.forward(&attention_probs)?;
+
+ let context_layer = attention_probs.matmul(&value_layer)?;
+ let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
+ let context_layer = context_layer.flatten_from(candle::D::Minus2)?;
+ Ok(context_layer)
+ }
+}
+
+struct BertSelfOutput {
+ dense: Linear,
+ layer_norm: LayerNorm,
+ dropout: Dropout,
+ span: tracing::Span,
+}
+
+impl BertSelfOutput {
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
+ let layer_norm = layer_norm(
+ config.hidden_size,
+ config.layer_norm_eps,
+ vb.pp("LayerNorm"),
+ )?;
+ let dropout = Dropout::new(config.hidden_dropout_prob);
+ Ok(Self {
+ dense,
+ layer_norm,
+ dropout,
+ span: tracing::span!(tracing::Level::TRACE, "self-out"),
+ })
+ }
+
+ fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let hidden_states = self.dense.forward(hidden_states)?;
+ let hidden_states = self.dropout.forward(&hidden_states)?;
+ self.layer_norm.forward(&(hidden_states + input_tensor)?)
+ }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L392
+struct BertAttention {
+ self_attention: BertSelfAttention,
+ self_output: BertSelfOutput,
+ span: tracing::Span,
+}
+
+impl BertAttention {
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let self_attention = BertSelfAttention::load(vb.pp("self"), config)?;
+ let self_output = BertSelfOutput::load(vb.pp("output"), config)?;
+ Ok(Self {
+ self_attention,
+ self_output,
+ span: tracing::span!(tracing::Level::TRACE, "attn"),
+ })
+ }
+
+ fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let self_outputs = self.self_attention.forward(hidden_states)?;
+ let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
+ Ok(attention_output)
+ }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L441
+struct BertIntermediate {
+ dense: Linear,
+ intermediate_act: HiddenActLayer,
+ span: tracing::Span,
+}
+
+impl BertIntermediate {
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let dense = linear(config.hidden_size, config.intermediate_size, vb.pp("dense"))?;
+ Ok(Self {
+ dense,
+ intermediate_act: HiddenActLayer::new(config.hidden_act),
+ span: tracing::span!(tracing::Level::TRACE, "inter"),
+ })
+ }
+
+ fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let hidden_states = self.dense.forward(hidden_states)?;
+ let ys = self.intermediate_act.forward(&hidden_states)?;
+ Ok(ys)
+ }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L456
+struct BertOutput {
+ dense: Linear,
+ layer_norm: LayerNorm,
+ dropout: Dropout,
+ span: tracing::Span,
+}
+
+impl BertOutput {
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let dense = linear(config.intermediate_size, config.hidden_size, vb.pp("dense"))?;
+ let layer_norm = layer_norm(
+ config.hidden_size,
+ config.layer_norm_eps,
+ vb.pp("LayerNorm"),
+ )?;
+ let dropout = Dropout::new(config.hidden_dropout_prob);
+ Ok(Self {
+ dense,
+ layer_norm,
+ dropout,
+ span: tracing::span!(tracing::Level::TRACE, "out"),
+ })
+ }
+
+ fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let hidden_states = self.dense.forward(hidden_states)?;
+ let hidden_states = self.dropout.forward(&hidden_states)?;
+ self.layer_norm.forward(&(hidden_states + input_tensor)?)
+ }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L470
+struct BertLayer {
+ attention: BertAttention,
+ intermediate: BertIntermediate,
+ output: BertOutput,
+ span: tracing::Span,
+}
+
+impl BertLayer {
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let attention = BertAttention::load(vb.pp("attention"), config)?;
+ let intermediate = BertIntermediate::load(vb.pp("intermediate"), config)?;
+ let output = BertOutput::load(vb.pp("output"), config)?;
+ Ok(Self {
+ attention,
+ intermediate,
+ output,
+ span: tracing::span!(tracing::Level::TRACE, "layer"),
+ })
+ }
+
+ fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let attention_output = self.attention.forward(hidden_states)?;
+ // TODO: Support cross-attention?
+ // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
+ // TODO: Support something similar to `apply_chunking_to_forward`?
+ let intermediate_output = self.intermediate.forward(&attention_output)?;
+ let layer_output = self
+ .output
+ .forward(&intermediate_output, &attention_output)?;
+ Ok(layer_output)
+ }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
+struct BertEncoder {
+ layers: Vec<BertLayer>,
+ span: tracing::Span,
+}
+
+impl BertEncoder {
+ fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let layers = (0..config.num_hidden_layers)
+ .map(|index| BertLayer::load(vb.pp(&format!("layer.{index}")), config))
+ .collect::<Result<Vec<_>>>()?;
+ let span = tracing::span!(tracing::Level::TRACE, "encoder");
+ Ok(BertEncoder { layers, span })
+ }
+
+ fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let mut hidden_states = hidden_states.clone();
+ // Use a loop rather than a fold as it's easier to modify when adding debug/...
+ for layer in self.layers.iter() {
+ hidden_states = layer.forward(&hidden_states)?
+ }
+ Ok(hidden_states)
+ }
+}
+
+// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874
+pub struct BertModel {
+ embeddings: BertEmbeddings,
+ encoder: BertEncoder,
+ pub device: Device,
+ span: tracing::Span,
+}
+
+impl BertModel {
+ pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+ let (embeddings, encoder) = match (
+ BertEmbeddings::load(vb.pp("embeddings"), config),
+ BertEncoder::load(vb.pp("encoder"), config),
+ ) {
+ (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
+ (Err(err), _) | (_, Err(err)) => {
+ if let Some(model_type) = &config.model_type {
+ if let (Ok(embeddings), Ok(encoder)) = (
+ BertEmbeddings::load(vb.pp(&format!("{model_type}.embeddings")), config),
+ BertEncoder::load(vb.pp(&format!("{model_type}.encoder")), config),
+ ) {
+ (embeddings, encoder)
+ } else {
+ return Err(err);
+ }
+ } else {
+ return Err(err);
+ }
+ }
+ };
+ Ok(Self {
+ embeddings,
+ encoder,
+ device: vb.device().clone(),
+ span: tracing::span!(tracing::Level::TRACE, "model"),
+ })
+ }
+
+ pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;
+ let sequence_output = self.encoder.forward(&embedding_output)?;
+ Ok(sequence_output)
+ }
+}
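
Usage note (not part of the diff): a minimal sketch of driving BertModel end to end, using zero-initialized weights purely to exercise the load/forward signatures; real usage would build the VarBuilder from downloaded safetensors instead. The module path models::bert and the helper name bert_forward_sketch are assumptions.

use candle::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config, DTYPE};

fn bert_forward_sketch() -> Result<Tensor> {
    let device = Device::Cpu;
    // Zero weights are enough to check shapes; swap in a VarBuilder built from
    // real safetensors for meaningful embeddings.
    let vb = VarBuilder::zeros(DTYPE, &device);
    let model = BertModel::load(vb, &Config::default())?;
    let input_ids = Tensor::zeros((1, 8), DType::U32, &device)?;
    let token_type_ids = input_ids.zeros_like()?;
    // Returns the hidden states with shape (batch, seq_len, hidden_size).
    model.forward(&input_ids, &token_type_ids)
}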
diff --git a/candle-transformers/src/models/bigcode.rs b/candle-transformers/src/models/bigcode.rs
new file mode 100644
index 00000000..1e63956b
--- /dev/null
+++ b/candle-transformers/src/models/bigcode.rs
@@ -0,0 +1,359 @@
+use candle::{DType, Device, IndexOp, Result, Tensor, D};
+use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
+
+fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
+ let weight = vb.get((size2, size1), "weight")?;
+ let bias = if bias {
+ Some(vb.get(size2, "bias")?)
+ } else {
+ None
+ };
+ Ok(Linear::new(weight, bias))
+}
+
+fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
+ let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
+ Ok(Embedding::new(embeddings, hidden_size))
+}
+
+fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
+ let weight = vb.get(size, "weight")?;
+ let bias = vb.get(size, "bias")?;
+ Ok(LayerNorm::new(weight, bias, eps))
+}
+
+fn make_causal_mask(t: usize, device: &Device) -> Result<Tensor> {
+ let mask: Vec<_> = (0..t)
+ .flat_map(|i| (0..t).map(move |j| u8::from(j <= i)))
+ .collect();
+ let mask = Tensor::from_slice(&mask, (t, t), device)?;
+ Ok(mask)
+}
+
+#[derive(Debug)]
+pub struct Config {
+ pub vocab_size: usize,
+ // max_position_embeddings aka n_positions
+ pub max_position_embeddings: usize,
+ // num_hidden_layers aka n_layer
+ pub num_hidden_layers: usize,
+ // hidden_size aka n_embd
+ pub hidden_size: usize,
+ pub layer_norm_epsilon: f64,
+ pub n_inner: Option<usize>,
+ // num_attention_heads aka n_head
+ pub num_attention_heads: usize,
+ pub multi_query: bool,
+ pub use_cache: bool,
+}
+
+impl Config {
+ #[allow(dead_code)]
+ pub fn starcoder_1b() -> Self {
+ Self {
+ vocab_size: 49152,
+ max_position_embeddings: 8192,
+ num_hidden_layers: 24,
+ hidden_size: 2048,
+ layer_norm_epsilon: 1e-5,
+ n_inner: Some(8192),
+ num_attention_heads: 16,
+ multi_query: true,
+ use_cache: true,
+ }
+ }
+
+ #[allow(dead_code)]
+ pub fn starcoder_3b() -> Self {
+ Self {
+ vocab_size: 49152,
+ max_position_embeddings: 8192,
+ num_hidden_layers: 36,
+ hidden_size: 2816,
+ layer_norm_epsilon: 1e-5,
+ n_inner: Some(11264),
+ num_attention_heads: 22,
+ multi_query: true,
+ use_cache: true,
+ }
+ }
+
+ #[allow(dead_code)]
+ pub fn starcoder_7b() -> Self {
+ Self {
+ vocab_size: 49152,
+ max_position_embeddings: 8192,
+ num_hidden_layers: 42,
+ hidden_size: 4096,
+ layer_norm_epsilon: 1e-5,
+ n_inner: Some(16384),
+ num_attention_heads: 32,
+ multi_query: true,
+ use_cache: true,
+ }
+ }
+
+ #[allow(dead_code)]
+ pub fn starcoder() -> Self {
+ Self {
+ vocab_size: 49152,
+ max_position_embeddings: 8192,
+ num_hidden_layers: 40,
+ hidden_size: 6144,
+ layer_norm_epsilon: 1e-5,
+ n_inner: Some(24576),
+ num_attention_heads: 48,
+ multi_query: true,
+ use_cache: true,
+ }
+ }
+}
+
+struct Attention {
+ c_attn: Linear,
+ c_proj: Linear,
+ kv_cache: Option<Tensor>,
+ use_cache: bool,
+ embed_dim: usize,
+ kv_dim: usize,
+ num_heads: usize,
+ head_dim: usize,
+ multi_query: bool,
+}
+
+impl Attention {
+ pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let hidden_size = cfg.hidden_size;
+ let head_dim = hidden_size / cfg.num_attention_heads;
+ let kv_heads = if cfg.multi_query {
+ 1
+ } else {
+ cfg.num_attention_heads
+ };
+ let kv_dim = kv_heads * head_dim;
+ let c_attn = linear(hidden_size, hidden_size + 2 * kv_dim, true, vb.pp("c_attn"))?;
+ let c_proj = linear(hidden_size, hidden_size, true, vb.pp("c_proj"))?;
+ Ok(Self {
+ c_proj,
+ c_attn,
+ embed_dim: hidden_size,
+ kv_cache: None,
+ use_cache: cfg.use_cache,
+ kv_dim,
+ head_dim,
+ num_heads: cfg.num_attention_heads,
+ multi_query: cfg.multi_query,
+ })
+ }
+
+ fn attn(
+ &self,
+ query: &Tensor,
+ key: &Tensor,
+ value: &Tensor,
+ attention_mask: &Tensor,
+ ) -> Result<Tensor> {
+ if query.dtype() != DType::F32 {
+ // If we start supporting f16 models, we may need the upcasting scaling bits.
+ // https://github.com/huggingface/transformers/blob/a0042379269bea9182c1f87e6b2eee4ba4c8cce8/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L133
+ candle::bail!("upcasting is not supported {:?}", query.dtype())
+ }
+ let scale_factor = 1f64 / (self.head_dim as f64).sqrt();
+ let initial_query_shape = query.shape();
+ let key_len = key.dim(D::Minus1)?;
+ let (query, key, attn_shape, attn_view) = if self.multi_query {
+ let (b_sz, query_len, _) = query.dims3()?;
+ let query = query.reshape((b_sz, query_len * self.num_heads, self.head_dim))?;
+ let attn_shape = (b_sz, query_len, self.num_heads, key_len);
+ let attn_view = (b_sz, query_len * self.num_heads, key_len);
+ (query, key.clone(), attn_shape, attn_view)
+ } else {
+ let (b_sz, _num_heads, query_len, _head_dim) = query.dims4()?;
+ let query = query.reshape((b_sz, query_len * self.num_heads, self.head_dim))?;
+ let key = key.reshape((b_sz * self.num_heads, self.head_dim, key_len))?;
+ let attn_shape = (b_sz, self.num_heads, query_len, key_len);
+ let attn_view = (b_sz * self.num_heads, query_len, key_len);
+ (query, key, attn_shape, attn_view)
+ };
+
+ let attn_weights =
+ (query.matmul(&key.contiguous()?)? * scale_factor)?.reshape(attn_shape)?;
+ let attention_mask = attention_mask.broadcast_as(attn_shape)?;
+ let mask_value =
+ Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?;
+ let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?;
+ let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
+ let value = value.contiguous()?;
+ let attn_output = if self.multi_query {
+ attn_weights
+ .reshape(attn_view)?
+ .matmul(&value)?
+ .reshape(initial_query_shape)?
+ } else {
+ attn_weights.matmul(&value)?
+ };
+ Ok(attn_output)
+ }
+
+ fn forward(&mut self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
+ let qkv = self.c_attn.forward(hidden_states)?;
+ let (query, key_value) = if self.multi_query {
+ let query = qkv.i((.., .., ..self.embed_dim))?;
+ let key_value = qkv.i((.., .., self.embed_dim..self.embed_dim + 2 * self.kv_dim))?;
+ (query, key_value)
+ } else {
+ let mut dims = qkv.dims().to_vec();
+ dims.pop();
+ dims.push(self.embed_dim);
+ dims.push(self.head_dim * 3);
+ let qkv = qkv.reshape(dims)?.transpose(1, 2)?;
+ let query = qkv.i((.., .., .., ..self.head_dim))?;
+ let key_value = qkv.i((.., .., .., self.head_dim..3 * self.head_dim))?;
+ (query, key_value)
+ };
+ let mut key_value = key_value;
+ if self.use_cache {
+ if let Some(kv_cache) = &self.kv_cache {
+ // TODO: we could trim the tensors to MAX_SEQ_LEN so that this would work for
+ // arbitrarily large sizes.
+ key_value = Tensor::cat(&[kv_cache, &key_value], D::Minus2)?.contiguous()?;
+ }
+ self.kv_cache = Some(key_value.clone())
+ }
+
+ let key = key_value.narrow(D::Minus1, 0, self.head_dim)?;
+ let value = key_value.narrow(D::Minus1, self.head_dim, self.head_dim)?;
+ let attn_output = self.attn(&query, &key.t()?, &value, attention_mask)?;
+ let attn_output = if self.multi_query {
+ attn_output
+ } else {
+ attn_output
+ .transpose(1, 2)?
+ .reshape(hidden_states.shape())?
+ };
+ let attn_output = self.c_proj.forward(&attn_output)?;
+ Ok(attn_output)
+ }
+}
+
+struct Mlp {
+ c_fc: Linear,
+ c_proj: Linear,
+}
+
+impl Mlp {
+ fn load(inner_dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let c_fc = linear(cfg.hidden_size, inner_dim, true, vb.pp("c_fc"))?;
+ let c_proj = linear(inner_dim, cfg.hidden_size, true, vb.pp("c_proj"))?;
+ Ok(Self { c_fc, c_proj })
+ }
+
+ fn forward(&mut self, hidden_states: &Tensor) -> Result<Tensor> {
+ let hidden_states = self.c_fc.forward(hidden_states)?.gelu()?;
+ let hidden_states = self.c_proj.forward(&hidden_states)?;
+ Ok(hidden_states)
+ }
+}
+
+// TODO: Add cross-attention?
+struct Block {
+ ln_1: LayerNorm,
+ attn: Attention,
+ ln_2: LayerNorm,
+ mlp: Mlp,
+}
+
+impl Block {
+ fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let hidden_size = cfg.hidden_size;
+ let inner_dim = cfg.n_inner.unwrap_or(4 * hidden_size);
+ let ln_1 = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb.pp("ln_1"))?;
+ let attn = Attention::load(vb.pp("attn"), cfg)?;
+ let ln_2 = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb.pp("ln_2"))?;
+ let mlp = Mlp::load(inner_dim, vb.pp("mlp"), cfg)?;
+ Ok(Self {
+ ln_1,
+ attn,
+ ln_2,
+ mlp,
+ })
+ }
+
+ fn forward(&mut self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
+ let residual = hidden_states;
+ let hidden_states = self.ln_1.forward(hidden_states)?;
+ let attn_outputs = self.attn.forward(&hidden_states, attention_mask)?;
+ let hidden_states = (&attn_outputs + residual)?;
+ let residual = &hidden_states;
+ let hidden_states = self.ln_2.forward(&hidden_states)?;
+ let hidden_states = self.mlp.forward(&hidden_states)?;
+ let hidden_states = (&hidden_states + residual)?;
+ Ok(hidden_states)
+ }
+}
+
+pub struct GPTBigCode {
+ wte: Embedding,
+ wpe: Embedding,
+ blocks: Vec<Block>,
+ ln_f: LayerNorm,
+ lm_head: Linear,
+ bias: Tensor,
+ config: Config,
+}
+
+impl GPTBigCode {
+ pub fn config(&self) -> &Config {
+ &self.config
+ }
+
+ pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> {
+ let hidden_size = cfg.hidden_size;
+ let vb_t = vb.pp("transformer");
+ let wte = embedding(cfg.vocab_size, hidden_size, vb_t.pp("wte"))?;
+ let wpe = embedding(cfg.max_position_embeddings, hidden_size, vb_t.pp("wpe"))?;
+ let blocks = (0..cfg.num_hidden_layers)
+ .map(|i| Block::load(vb_t.pp(&format!("h.{i}")), &cfg))
+ .collect::<Result<Vec<_>>>()?;
+ let ln_f = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb_t.pp("ln_f"))?;
+ let lm_head = linear(hidden_size, cfg.vocab_size, false, vb_t.pp("wte"))?;
+ let bias = make_causal_mask(cfg.max_position_embeddings, vb.device())?;
+ Ok(Self {
+ wte,
+ wpe,
+ blocks,
+ lm_head,
+ ln_f,
+ bias,
+ config: cfg,
+ })
+ }
+
+ pub fn forward(&mut self, input_ids: &Tensor, past_len: usize) -> Result<Tensor> {
+ let dev = input_ids.device();
+ let (b_sz, seq_len) = input_ids.dims2()?;
+
+ let key_len = past_len + seq_len;
+ let attention_mask = self.bias.i((past_len..key_len, ..key_len))?.unsqueeze(0)?;
+ // MQA models: (batch_size, query_length, n_heads, key_length)
+ // MHA models: (batch_size, n_heads, query_length, key_length)
+ let seq_len_dim = if self.config.multi_query { 2 } else { 1 };
+ let attention_mask = attention_mask.unsqueeze(seq_len_dim)?;
+
+ let position_ids = Tensor::arange(past_len as u32, (past_len + seq_len) as u32, dev)?;
+ let position_ids = position_ids.unsqueeze(0)?.broadcast_as((b_sz, seq_len))?;
+ let input_embeds = self.wte.forward(input_ids)?;
+ let position_embeds = self.wpe.forward(&position_ids)?;
+
+ let mut hidden_states = (&input_embeds + &position_embeds)?;
+ for block in self.blocks.iter_mut() {
+ hidden_states = block.forward(&hidden_states, &attention_mask)?;
+ }
+ let hidden_states = self.ln_f.forward(&hidden_states)?;
+ let hidden_states = hidden_states
+ .reshape((b_sz, seq_len, self.config.hidden_size))?
+ .narrow(1, seq_len - 1, 1)?;
+ let logits = self.lm_head.forward(&hidden_states)?.squeeze(1)?;
+ Ok(logits)
+ }
+}
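
Usage note (not part of the diff): GPTBigCode keeps a kv cache inside each Attention block, so forward takes the running past_len. A sketch under the same zero-weight assumption as above; bigcode_forward_sketch is an illustrative name.

use candle::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bigcode::{Config, GPTBigCode};

fn bigcode_forward_sketch() -> Result<Tensor> {
    let device = Device::Cpu;
    let vb = VarBuilder::zeros(DType::F32, &device);
    let mut model = GPTBigCode::load(vb, Config::starcoder_1b())?;
    // The prompt goes in with past_len = 0; later single-token steps pass the
    // number of tokens already processed so the cached keys/values are reused.
    let prompt = Tensor::zeros((1, 4), DType::U32, &device)?;
    model.forward(&prompt, 0)
}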
diff --git a/candle-transformers/src/models/dinov2.rs b/candle-transformers/src/models/dinov2.rs
new file mode 100644
index 00000000..0edc8494
--- /dev/null
+++ b/candle-transformers/src/models/dinov2.rs
@@ -0,0 +1,279 @@
+use candle::{IndexOp, Result, Tensor, D};
+use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+
+const IMG_SIZE: usize = 518;
+const PATCH_SIZE: usize = 14;
+const NUM_CLASSES: usize = 1000;
+
+fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
+ if bias {
+ candle_nn::linear(in_dim, out_dim, vb)
+ } else {
+ candle_nn::linear_no_bias(in_dim, out_dim, vb)
+ }
+}
+
+#[derive(Debug)]
+struct Attention {
+ qkv: Linear,
+ proj: Linear,
+ num_heads: usize,
+ scale: f64,
+}
+
+impl Attention {
+ fn new(
+ vb: VarBuilder,
+ dim: usize,
+ num_heads: usize,
+ qkv_bias: bool,
+ proj_bias: bool,
+ ) -> Result<Self> {
+ let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
+ let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?;
+ let scale = 1. / ((dim / num_heads) as f64).sqrt();
+ Ok(Self {
+ qkv,
+ proj,
+ num_heads,
+ scale,
+ })
+ }
+}
+
+impl Module for Attention {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let (b, n, c) = xs.dims3()?;
+ let qkv = self
+ .qkv
+ .forward(xs)?
+ .reshape((b, n, 3, self.num_heads, c / self.num_heads))?
+ .transpose(1, 2)? // 02134
+ .transpose(0, 1)? // 20134
+ .transpose(2, 3)?; // 20314
+ let q = (qkv.i(0)? * self.scale)?;
+ let k = qkv.i(1)?;
+ let v = qkv.i(2)?;
+ let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?;
+ let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?;
+ self.proj.forward(&attn)
+ }
+}
+
+#[derive(Debug)]
+struct LayerScale {
+ gamma: Tensor,
+}
+
+impl LayerScale {
+ fn new(vb: VarBuilder, dim: usize) -> Result<Self> {
+ let gamma = vb.get(dim, "gamma")?;
+ Ok(Self { gamma })
+ }
+}
+
+impl Module for LayerScale {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ xs.broadcast_mul(&self.gamma)
+ }
+}
+
+#[derive(Debug)]
+struct Mlp {
+ fc1: Linear,
+ fc2: Linear,
+}
+
+impl Mlp {
+ fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> {
+ let out_features = in_features;
+ let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?;
+ let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?;
+ Ok(Self { fc1, fc2 })
+ }
+}
+
+impl Module for Mlp {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = self.fc1.forward(xs)?.gelu()?;
+ self.fc2.forward(&xs)
+ }
+}
+
+#[derive(Debug)]
+struct Block {
+ norm1: LayerNorm,
+ attn: Attention,
+ ls1: LayerScale,
+ norm2: LayerNorm,
+ mlp: Mlp,
+ ls2: LayerScale,
+}
+
+impl Block {
+ fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
+ let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
+ let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?;
+ let ls1 = LayerScale::new(vb.pp("ls1"), dim)?;
+ let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
+ let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?;
+ let ls2 = LayerScale::new(vb.pp("ls2"), dim)?;
+ Ok(Self {
+ norm1,
+ attn,
+ ls1,
+ norm2,
+ mlp,
+ ls2,
+ })
+ }
+}
+
+impl Module for Block {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let residual = xs;
+ let xs = self
+ .ls1
+ .forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?;
+ let xs = (xs + residual)?;
+ let residual = &xs;
+ let xs = self
+ .ls2
+ .forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?;
+ xs + residual
+ }
+}
+
+#[derive(Debug)]
+struct PatchEmbed {
+ proj: candle_nn::Conv2d,
+ patch_size: (usize, usize),
+ num_patches: usize,
+}
+
+impl PatchEmbed {
+ fn new(
+ vb: VarBuilder,
+ img_size: usize,
+ patch_size: usize,
+ in_chans: usize,
+ embed_dim: usize,
+ ) -> Result<Self> {
+ let config = candle_nn::Conv2dConfig {
+ stride: patch_size,
+ ..Default::default()
+ };
+ let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?;
+ let num_patches = (img_size / patch_size) * (img_size / patch_size);
+ Ok(Self {
+ proj,
+ patch_size: (patch_size, patch_size),
+ num_patches,
+ })
+ }
+}
+
+impl Module for PatchEmbed {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let (_b, _c, h, w) = xs.dims4()?;
+ let (patch_h, patch_w) = self.patch_size;
+ if (h % patch_h) != 0 {
+ candle::bail!("image height {h} is not a multiple of patch height {patch_h}")
+ }
+ if (w % patch_w) != 0 {
+ candle::bail!("image width {w} is not a multiple of patch width {patch_w}")
+ }
+ let xs = self.proj.forward(xs)?;
+ let (b, c, h, w) = xs.dims4()?;
+ // flatten embeddings.
+ xs.reshape((b, c, h * w))?.transpose(1, 2)
+ }
+}
+
+#[derive(Debug)]
+pub struct DinoVisionTransformer {
+ patch_embed: PatchEmbed,
+ cls_token: Tensor,
+ pos_embed: Tensor,
+ blocks: Vec<Block>,
+ norm: LayerNorm,
+ head: Linear,
+}
+
+impl DinoVisionTransformer {
+ pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> {
+ let patch_embed =
+ PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?;
+ let cls_token = vb.get((1, 1, embed_dim), "cls_token")?;
+ let num_tokens = 1;
+ let pos_embed = vb.get(
+ (1, patch_embed.num_patches + num_tokens, embed_dim),
+ "pos_embed",
+ )?;
+ let head = linear(vb.pp("head"), 2 * embed_dim, NUM_CLASSES, true)?;
+ let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?;
+ let vb_b = vb.pp("blocks");
+ let blocks = (0..depth)
+ .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
+ .collect::<Result<Vec<_>>>()?;
+ Ok(Self {
+ patch_embed,
+ cls_token,
+ pos_embed,
+ blocks,
+ norm,
+ head,
+ })
+ }
+
+ fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> {
+ let npatch = xs.dim(1)? - 1;
+ let n = self.pos_embed.dim(1)? - 1;
+ let sqrt_n = (n as f64).sqrt();
+ if npatch == n && w == h {
+ return Ok(xs.clone());
+ }
+ let class_pos_embed = self.pos_embed.i((.., ..1))?;
+ let patch_pos_embed = self.pos_embed.i((.., 1..))?;
+ let dim = xs.dim(D::Minus1)?;
+ let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1);
+ let patch_pos_embed = patch_pos_embed
+ .reshape((1, sqrt_n as usize, sqrt_n as usize, dim))?
+ .transpose(2, 3)?
+ .transpose(1, 2)?;
+ // This uses bicubic interpolation in the original implementation.
+ let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?;
+ let el_count = patch_pos_embed.shape().elem_count();
+ let patch_pos_embed =
+ patch_pos_embed
+ .transpose(1, 2)?
+ .transpose(2, 3)?
+ .reshape((1, el_count / dim, dim))?;
+ Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1)
+ }
+
+ fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> {
+ let (_b, _nc, w, h) = xs.dims4()?;
+ let xs = self.patch_embed.forward(xs)?;
+ let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
+ &xs + &self.interpolate_pos_encoding(&xs, w, h)?
+ }
+}
+
+impl Module for DinoVisionTransformer {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut xs = self.prepare_tokens_with_mask(xs)?;
+ for blk in self.blocks.iter() {
+ xs = blk.forward(&xs)?
+ }
+ let xs = self.norm.forward(&xs)?;
+ let xs_norm_clstoken = xs.i((.., 0))?;
+ let xs_norm_patchtokens = xs.i((.., 1..))?.mean(1)?;
+ let xs = Tensor::cat(&[xs_norm_clstoken, xs_norm_patchtokens], D::Minus1)?;
+ self.head.forward(&xs)
+ }
+}
+
+pub fn vit_small(vb: VarBuilder) -> Result<DinoVisionTransformer> {
+ DinoVisionTransformer::new(vb, 12, 384, 6)
+}
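
Usage note (not part of the diff): the vit_small helper builds the ViT-S/14 variant; inputs are expected at the hard-coded 518x518 resolution and the linear head yields 1000-way logits. A zero-weight sketch, with dinov2_forward_sketch as an illustrative name:

use candle::{DType, Device, Result, Tensor};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::dinov2;

fn dinov2_forward_sketch() -> Result<Tensor> {
    let device = Device::Cpu;
    let vb = VarBuilder::zeros(DType::F32, &device);
    let model = dinov2::vit_small(vb)?;
    // One RGB image at IMG_SIZE x IMG_SIZE (518 x 518, patch size 14).
    let image = Tensor::zeros((1, 3, 518, 518), DType::F32, &device)?;
    // Output shape: (1, NUM_CLASSES) = (1, 1000).
    model.forward(&image)
}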
diff --git a/candle-transformers/src/models/efficientnet.rs b/candle-transformers/src/models/efficientnet.rs
new file mode 100644
index 00000000..ab51c76d
--- /dev/null
+++ b/candle-transformers/src/models/efficientnet.rs
@@ -0,0 +1,331 @@
+use candle::{Result, Tensor, D};
+use candle_nn as nn;
+use nn::{Module, VarBuilder};
+
+// Based on the Python version from torchvision.
+// https://github.com/pytorch/vision/blob/0d75d9e5516f446c9c0ef93bd4ed9fea13992d06/torchvision/models/efficientnet.py#L47
+#[derive(Debug, Clone, Copy)]
+pub struct MBConvConfig {
+ expand_ratio: f64,
+ kernel: usize,
+ stride: usize,
+ input_channels: usize,
+ out_channels: usize,
+ num_layers: usize,
+}
+
+fn make_divisible(v: f64, divisor: usize) -> usize {
+ let min_value = divisor;
+ let new_v = usize::max(
+ min_value,
+ (v + divisor as f64 * 0.5) as usize / divisor * divisor,
+ );
+ if (new_v as f64) < 0.9 * v {
+ new_v + divisor
+ } else {
+ new_v
+ }
+}
+
+fn bneck_confs(width_mult: f64, depth_mult: f64) -> Vec<MBConvConfig> {
+ let bneck_conf = |e, k, s, i, o, n| {
+ let input_channels = make_divisible(i as f64 * width_mult, 8);
+ let out_channels = make_divisible(o as f64 * width_mult, 8);
+ let num_layers = (n as f64 * depth_mult).ceil() as usize;
+ MBConvConfig {
+ expand_ratio: e,
+ kernel: k,
+ stride: s,
+ input_channels,
+ out_channels,
+ num_layers,
+ }
+ };
+ vec![
+ bneck_conf(1., 3, 1, 32, 16, 1),
+ bneck_conf(6., 3, 2, 16, 24, 2),
+ bneck_conf(6., 5, 2, 24, 40, 2),
+ bneck_conf(6., 3, 2, 40, 80, 3),
+ bneck_conf(6., 5, 1, 80, 112, 3),
+ bneck_conf(6., 5, 2, 112, 192, 4),
+ bneck_conf(6., 3, 1, 192, 320, 1),
+ ]
+}
+
+impl MBConvConfig {
+ pub fn b0() -> Vec<Self> {
+ bneck_confs(1.0, 1.0)
+ }
+ pub fn b1() -> Vec<Self> {
+ bneck_confs(1.0, 1.1)
+ }
+ pub fn b2() -> Vec<Self> {
+ bneck_confs(1.1, 1.2)
+ }
+ pub fn b3() -> Vec<Self> {
+ bneck_confs(1.2, 1.4)
+ }
+ pub fn b4() -> Vec<Self> {
+ bneck_confs(1.4, 1.8)
+ }
+ pub fn b5() -> Vec<Self> {
+ bneck_confs(1.6, 2.2)
+ }
+ pub fn b6() -> Vec<Self> {
+ bneck_confs(1.8, 2.6)
+ }
+ pub fn b7() -> Vec<Self> {
+ bneck_confs(2.0, 3.1)
+ }
+}
+
+/// Conv2D with same padding.
+#[derive(Debug)]
+struct Conv2DSame {
+ conv2d: nn::Conv2d,
+ s: usize,
+ k: usize,
+}
+
+impl Conv2DSame {
+ fn new(
+ vb: VarBuilder,
+ i: usize,
+ o: usize,
+ k: usize,
+ stride: usize,
+ groups: usize,
+ bias: bool,
+ ) -> Result<Self> {
+ let conv_config = nn::Conv2dConfig {
+ stride,
+ groups,
+ ..Default::default()
+ };
+ let conv2d = if bias {
+ nn::conv2d(i, o, k, conv_config, vb)?
+ } else {
+ nn::conv2d_no_bias(i, o, k, conv_config, vb)?
+ };
+ Ok(Self {
+ conv2d,
+ s: stride,
+ k,
+ })
+ }
+}
+
+impl Module for Conv2DSame {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let s = self.s;
+ let k = self.k;
+ let (_, _, ih, iw) = xs.dims4()?;
+ let oh = (ih + s - 1) / s;
+ let ow = (iw + s - 1) / s;
+ let pad_h = usize::max((oh - 1) * s + k - ih, 0);
+ let pad_w = usize::max((ow - 1) * s + k - iw, 0);
+ if pad_h > 0 || pad_w > 0 {
+ let xs = xs.pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?;
+ let xs = xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?;
+ self.conv2d.forward(&xs)
+ } else {
+ self.conv2d.forward(xs)
+ }
+ }
+}
+
+#[derive(Debug)]
+struct ConvNormActivation {
+ conv2d: Conv2DSame,
+ bn2d: nn::BatchNorm,
+ activation: bool,
+}
+
+impl ConvNormActivation {
+ fn new(
+ vb: VarBuilder,
+ i: usize,
+ o: usize,
+ k: usize,
+ stride: usize,
+ groups: usize,
+ ) -> Result<Self> {
+ let conv2d = Conv2DSame::new(vb.pp("0"), i, o, k, stride, groups, false)?;
+ let bn2d = nn::batch_norm(o, 1e-3, vb.pp("1"))?;
+ Ok(Self {
+ conv2d,
+ bn2d,
+ activation: true,
+ })
+ }
+
+ fn no_activation(self) -> Self {
+ Self {
+ activation: false,
+ ..self
+ }
+ }
+}
+
+impl Module for ConvNormActivation {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = self.conv2d.forward(xs)?;
+ let xs = self.bn2d.forward(&xs)?;
+ if self.activation {
+ swish(&xs)
+ } else {
+ Ok(xs)
+ }
+ }
+}
+
+#[derive(Debug)]
+struct SqueezeExcitation {
+ fc1: Conv2DSame,
+ fc2: Conv2DSame,
+}
+
+impl SqueezeExcitation {
+ fn new(vb: VarBuilder, in_channels: usize, squeeze_channels: usize) -> Result<Self> {
+ let fc1 = Conv2DSame::new(vb.pp("fc1"), in_channels, squeeze_channels, 1, 1, 1, true)?;
+ let fc2 = Conv2DSame::new(vb.pp("fc2"), squeeze_channels, in_channels, 1, 1, 1, true)?;
+ Ok(Self { fc1, fc2 })
+ }
+}
+
+impl Module for SqueezeExcitation {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let residual = xs;
+ // equivalent to adaptive_avg_pool2d([1, 1])
+ let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
+ let xs = self.fc1.forward(&xs)?;
+ let xs = swish(&xs)?;
+ let xs = self.fc2.forward(&xs)?;
+ let xs = nn::ops::sigmoid(&xs)?;
+ residual.broadcast_mul(&xs)
+ }
+}
+
+#[derive(Debug)]
+struct MBConv {
+ expand_cna: Option<ConvNormActivation>,
+ depthwise_cna: ConvNormActivation,
+ squeeze_excitation: SqueezeExcitation,
+ project_cna: ConvNormActivation,
+ config: MBConvConfig,
+}
+
+impl MBConv {
+ fn new(vb: VarBuilder, c: MBConvConfig) -> Result<Self> {
+ let vb = vb.pp("block");
+ let exp = make_divisible(c.input_channels as f64 * c.expand_ratio, 8);
+ let expand_cna = if exp != c.input_channels {
+ Some(ConvNormActivation::new(
+ vb.pp("0"),
+ c.input_channels,
+ exp,
+ 1,
+ 1,
+ 1,
+ )?)
+ } else {
+ None
+ };
+ let start_index = if expand_cna.is_some() { 1 } else { 0 };
+ let depthwise_cna =
+ ConvNormActivation::new(vb.pp(start_index), exp, exp, c.kernel, c.stride, exp)?;
+ let squeeze_channels = usize::max(1, c.input_channels / 4);
+ let squeeze_excitation =
+ SqueezeExcitation::new(vb.pp(start_index + 1), exp, squeeze_channels)?;
+ let project_cna =
+ ConvNormActivation::new(vb.pp(start_index + 2), exp, c.out_channels, 1, 1, 1)?
+ .no_activation();
+ Ok(Self {
+ expand_cna,
+ depthwise_cna,
+ squeeze_excitation,
+ project_cna,
+ config: c,
+ })
+ }
+}
+
+impl Module for MBConv {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let use_res_connect =
+ self.config.stride == 1 && self.config.input_channels == self.config.out_channels;
+ let ys = match &self.expand_cna {
+ Some(expand_cna) => expand_cna.forward(xs)?,
+ None => xs.clone(),
+ };
+ let ys = self.depthwise_cna.forward(&ys)?;
+ let ys = self.squeeze_excitation.forward(&ys)?;
+ let ys = self.project_cna.forward(&ys)?;
+ if use_res_connect {
+ ys + xs
+ } else {
+ Ok(ys)
+ }
+ }
+}
+
+fn swish(s: &Tensor) -> Result<Tensor> {
+ s * nn::ops::sigmoid(s)?
+}
+
+#[derive(Debug)]
+pub struct EfficientNet {
+ init_cna: ConvNormActivation,
+ blocks: Vec<MBConv>,
+ final_cna: ConvNormActivation,
+ classifier: nn::Linear,
+}
+
+impl EfficientNet {
+ pub fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> {
+ let f_p = p.pp("features");
+ let first_in_c = configs[0].input_channels;
+ let last_out_c = configs.last().unwrap().out_channels;
+ let final_out_c = 4 * last_out_c;
+ let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?;
+ let nconfigs = configs.len();
+ let mut blocks = vec![];
+ for (index, cnf) in configs.into_iter().enumerate() {
+ let f_p = f_p.pp(index + 1);
+ for r_index in 0..cnf.num_layers {
+ let cnf = if r_index == 0 {
+ cnf
+ } else {
+ MBConvConfig {
+ input_channels: cnf.out_channels,
+ stride: 1,
+ ..cnf
+ }
+ };
+ blocks.push(MBConv::new(f_p.pp(r_index), cnf)?)
+ }
+ }
+ let final_cna =
+ ConvNormActivation::new(f_p.pp(nconfigs + 1), last_out_c, final_out_c, 1, 1, 1)?;
+ let classifier = nn::linear(final_out_c, nclasses, p.pp("classifier.1"))?;
+ Ok(Self {
+ init_cna,
+ blocks,
+ final_cna,
+ classifier,
+ })
+ }
+}
+
+impl Module for EfficientNet {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut xs = self.init_cna.forward(xs)?;
+ for block in self.blocks.iter() {
+ xs = block.forward(&xs)?
+ }
+ let xs = self.final_cna.forward(&xs)?;
+ // Equivalent to adaptive_avg_pool2d([1, 1]) -> squeeze(-1) -> squeeze(-1)
+ let xs = xs.mean(D::Minus1)?.mean(D::Minus1)?;
+ self.classifier.forward(&xs)
+ }
+}
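
Usage note (not part of the diff): EfficientNet::new takes one of the MBConvConfig::b0()..b7() width/depth presets plus the number of output classes. A zero-weight sketch for the B0 variant; efficientnet_forward_sketch is an illustrative name.

use candle::{DType, Device, Result, Tensor};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::efficientnet::{EfficientNet, MBConvConfig};

fn efficientnet_forward_sketch() -> Result<Tensor> {
    let device = Device::Cpu;
    let vb = VarBuilder::zeros(DType::F32, &device);
    let model = EfficientNet::new(vb, MBConvConfig::b0(), 1000)?;
    // The Conv2DSame blocks pad as needed, so any reasonable square input works.
    let image = Tensor::zeros((1, 3, 224, 224), DType::F32, &device)?;
    model.forward(&image)
}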
diff --git a/candle-transformers/src/models/falcon.rs b/candle-transformers/src/models/falcon.rs
new file mode 100644
index 00000000..6ede136a
--- /dev/null
+++ b/candle-transformers/src/models/falcon.rs
@@ -0,0 +1,484 @@
+use candle::{DType, Device, Result, Tensor, D};
+use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
+
+const MAX_SEQ_LEN: usize = 5000;
+
+fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
+ let weight = vb.get((size2, size1), "weight")?;
+ let bias = if bias {
+ Some(vb.get(size2, "bias")?)
+ } else {
+ None
+ };
+ Ok(Linear::new(weight, bias))
+}
+
+fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
+ let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
+ (Ok(weight), Ok(bias)) => (weight, bias),
+ (Err(err), _) | (_, Err(err)) => {
+ if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
+ (weight, bias)
+ } else {
+ return Err(err);
+ }
+ }
+ };
+ Ok(LayerNorm::new(weight, bias, eps))
+}
+
+fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
+ let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
+ Ok(Embedding::new(embeddings, hidden_size))
+}
+
+// https://raw.githubusercontent.com/huggingface/transformers/030c863aaa0165e98352b61697430bf69bf33755/src/transformers/models/falcon/configuration_falcon.py
+#[derive(Debug)]
+pub struct Config {
+ pub vocab_size: usize,
+ pub hidden_size: usize,
+ pub num_hidden_layers: usize,
+ pub num_attention_heads: usize,
+ pub layer_norm_epsilon: f64,
+ pub initializer_range: f64,
+ pub use_cache: bool,
+ pub bos_token_id: u32,
+ pub eos_token_id: u32,
+ pub hidden_dropout: f64,
+ pub attention_dropout: f64,
+ pub n_head_kv: Option<usize>,
+ pub alibi: bool,
+ pub new_decoder_architecture: bool,
+ pub multi_query: bool,
+ pub parallel_attn: bool,
+ pub bias: bool,
+}
+
+impl Default for Config {
+ fn default() -> Self {
+ Self {
+ vocab_size: 65024,
+ hidden_size: 4544,
+ num_hidden_layers: 32,
+ num_attention_heads: 71,
+ layer_norm_epsilon: 1e-5,
+ initializer_range: 0.02,
+ use_cache: true,
+ bos_token_id: 11,
+ eos_token_id: 11,
+ hidden_dropout: 0.0,
+ attention_dropout: 0.0,
+ n_head_kv: None,
+ alibi: false,
+ new_decoder_architecture: false,
+ multi_query: true,
+ parallel_attn: true,
+ bias: false,
+ }
+ }
+}
+
+impl Config {
+ pub fn validate(&self) -> Result<()> {
+ if self.alibi {
+ candle::bail!("alibi is not supported");
+ }
+ if self.new_decoder_architecture {
+ candle::bail!("new_decoder_architecture is not supported");
+ }
+ if self.n_head_kv.is_some() {
+ candle::bail!("n_head_kv is not supported");
+ }
+ Ok(())
+ }
+
+ // https://huggingface.co/tiiuae/falcon-7b/blob/main/config.json
+ pub fn falcon7b() -> Self {
+ // This is currently on par with the defaults; the defaults come from the Python default
+ // arguments for the config initialization, whereas the following come from the json config.
+ Self {
+ vocab_size: 65024,
+ hidden_size: 4544,
+ num_hidden_layers: 32,
+ num_attention_heads: 71,
+ layer_norm_epsilon: 1e-5,
+ initializer_range: 0.02,
+ use_cache: true,
+ bos_token_id: 11,
+ eos_token_id: 11,
+ hidden_dropout: 0.,
+ attention_dropout: 0.,
+ n_head_kv: None,
+ alibi: false,
+ new_decoder_architecture: false,
+ multi_query: true,
+ parallel_attn: true,
+ bias: false,
+ }
+ }
+
+ fn head_dim(&self) -> usize {
+ self.hidden_size / self.num_attention_heads
+ }
+
+ fn rotary(&self) -> bool {
+ !self.alibi
+ }
+}
+
+fn rotate_half(x: &Tensor) -> Result<Tensor> {
+ let l = x.dim(D::Minus1)?;
+ let x1 = x.narrow(D::Minus1, 0, l / 2)?;
+ let x2 = x.narrow(D::Minus1, l / 2, l - l / 2)?;
+ let x21 = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
+ Ok(x21)
+}
+
+#[derive(Debug)]
+struct FalconRotaryEmbedding {
+ inv_freq: Tensor,
+ cache: Option<(usize, Tensor, Tensor)>,
+}
+
+impl FalconRotaryEmbedding {
+ fn load(device: &Device, cfg: &Config) -> Result<Self> {
+ let head_dim = cfg.head_dim();
+ let inv_freq: Vec<_> = (0..head_dim)
+ .step_by(2)
+ .map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32))
+ .collect();
+ Ok(Self {
+ inv_freq: Tensor::new(inv_freq.as_slice(), device)?,
+ cache: None,
+ })
+ }
+
+ fn cos_sin(
+ &mut self,
+ seq_len: usize,
+ device: &Device,
+ dtype: DType,
+ ) -> Result<(Tensor, Tensor)> {
+ match &self.cache {
+ Some((s, cos, sin)) if *s == seq_len => {
+ return Ok((cos.clone(), sin.clone()));
+ }
+ _ => {}
+ }
+ let t = Tensor::arange(0, seq_len as u32, device)?.to_dtype(dtype)?;
+ let inv_freq = self.inv_freq.to_dtype(dtype)?;
+ let freqs = t.unsqueeze(1)?.matmul(&inv_freq.unsqueeze(0)?)?;
+ let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
+ let cos = emb.cos()?;
+ let sin = emb.sin()?;
+ self.cache = Some((seq_len, cos.clone(), sin.clone()));
+ Ok((cos, sin))
+ }
+
+ fn forward(
+ &mut self,
+ query: &Tensor,
+ key: &Tensor,
+ past_kv_len: usize,
+ ) -> Result<(Tensor, Tensor)> {
+ let (_batch, seq_len, _head_dim) = query.dims3()?;
+ let (cos, sin) = self.cos_sin(MAX_SEQ_LEN, query.device(), query.dtype())?;
+ let cos = cos.narrow(0, past_kv_len, seq_len)?;
+ let sin = sin.narrow(0, past_kv_len, seq_len)?;
+ let qs = (query.broadcast_mul(&cos)? + &rotate_half(query)?.broadcast_mul(&sin)?)?;
+ let ks = (key.broadcast_mul(&cos)? + &rotate_half(key)?.broadcast_mul(&sin)?)?;
+ Ok((qs, ks))
+ }
+}
+
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+ let shape = mask.shape();
+ let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+ let m = mask.where_cond(&on_true, on_false)?;
+ Ok(m)
+}
+
+#[derive(Debug)]
+struct FalconAttention {
+ query_key_value: Linear,
+ dense: Linear,
+ maybe_rotary: Option<FalconRotaryEmbedding>,
+ kv_cache: Option<(Tensor, Tensor)>,
+ inv_norm_factor: f64,
+ multi_query: bool,
+ use_cache: bool,
+ num_heads: usize,
+ head_dim: usize,
+ n_head_kv: usize,
+}
+
+impl FalconAttention {
+ fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let maybe_rotary = if cfg.rotary() {
+ let rotary = FalconRotaryEmbedding::load(vb.device(), cfg)?;
+ Some(rotary)
+ } else {
+ None
+ };
+ let head_dim = cfg.head_dim();
+ let hidden_size = cfg.hidden_size;
+ let qkv_out_dim = if cfg.multi_query {
+ hidden_size + 2 * head_dim
+ } else {
+ 3 * hidden_size
+ };
+ let query_key_value = linear(hidden_size, qkv_out_dim, cfg.bias, vb.pp("query_key_value"))?;
+ let dense = linear(hidden_size, hidden_size, cfg.bias, vb.pp("dense"))?;
+ Ok(Self {
+ query_key_value,
+ dense,
+ maybe_rotary,
+ kv_cache: None,
+ inv_norm_factor: 1. / (head_dim as f64).sqrt(),
+ multi_query: cfg.multi_query,
+ use_cache: cfg.use_cache,
+ num_heads: cfg.num_attention_heads,
+ n_head_kv: cfg.n_head_kv.unwrap_or(1),
+ head_dim,
+ })
+ }
+
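+    // Splits the fused QKV projection into query, key and value tensors. With multi-query
+    // attention the projection packs `num_heads` query heads followed by a single key head
+    // and a single value head.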
+ fn split_heads(&self, fused_qkv: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
+ let (b_sz, seq_len, _) = fused_qkv.dims3()?;
+ if !self.multi_query {
+ let fused_qkv = fused_qkv.reshape((b_sz, seq_len, self.num_heads, 3, self.head_dim))?;
+ let q = fused_qkv.narrow(D::Minus2, 0, 1)?.squeeze(D::Minus2)?;
+ let k = fused_qkv.narrow(D::Minus2, 1, 1)?.squeeze(D::Minus2)?;
+ let v = fused_qkv.narrow(D::Minus2, 2, 1)?.squeeze(D::Minus2)?;
+ Ok((q, k, v))
+ } else {
+ let fused_qkv =
+ fused_qkv.reshape((b_sz, seq_len, self.num_heads + 2, self.head_dim))?;
+ let d = fused_qkv.dim(D::Minus2)?;
+ let q = fused_qkv.narrow(D::Minus2, 0, d - 2)?;
+ let k = fused_qkv.narrow(D::Minus2, d - 2, 1)?;
+ let v = fused_qkv.narrow(D::Minus2, d - 1, 1)?;
+ Ok((q, k, v))
+ }
+ }
+
+ fn forward(&mut self, x: &Tensor, mask: &Tensor, past_kv_len: usize) -> Result<Tensor> {
+ let fused_qkv = self.query_key_value.forward(x)?;
+ let head_dim = self.head_dim;
+ let (query, key, value) = self.split_heads(&fused_qkv)?;
+ let (b_sz, seq_len, _, _) = query.dims4()?;
+ let query = query
+ .transpose(1, 2)?
+ .reshape((b_sz * self.num_heads, seq_len, head_dim))?;
+ let key = key
+ .transpose(1, 2)?
+ .reshape((b_sz * self.n_head_kv, seq_len, head_dim))?;
+ let value = value
+ .transpose(1, 2)?
+ .reshape((b_sz * self.n_head_kv, seq_len, head_dim))?;
+ let (query, key) = if let Some(r) = &mut self.maybe_rotary {
+ r.forward(&query, &key, past_kv_len)?
+ } else {
+ (query, key)
+ };
+ let (mut key, mut value) = (key, value);
+ let mask = masked_fill(&mask.to_dtype(DType::F32)?, mask, -1e9)?.to_dtype(query.dtype())?;
+ if self.use_cache {
+ if let Some((cache_k, cache_v)) = &self.kv_cache {
+ // TODO: we could trim the tensors to MAX_SEQ_LEN so that this would work for
+ // arbitrarily large sizes.
+ key = Tensor::cat(&[cache_k, &key], 1)?.contiguous()?;
+ value = Tensor::cat(&[cache_v, &value], 1)?.contiguous()?;
+ }
+ self.kv_cache = Some((key.clone(), value.clone()))
+ }
+ let query = query.reshape((b_sz * self.num_heads, seq_len, head_dim))?;
+ let all_len = past_kv_len + seq_len;
+ let key = key.reshape((b_sz * self.n_head_kv, all_len, head_dim))?;
+ let value = value.reshape((b_sz * self.n_head_kv, all_len, head_dim))?;
+
+ let (key, value) = if self.n_head_kv == 1 {
+ (
+ key.broadcast_as((b_sz * self.num_heads, all_len, head_dim))?,
+ value.broadcast_as((b_sz * self.num_heads, all_len, head_dim))?,
+ )
+ } else {
+ (key, value)
+ };
+
+        // Only the non-alibi case with standard (non-flash) attention is handled here.
+ let attention_scores = (query.matmul(&key.t()?)? * self.inv_norm_factor)?;
+ let attention_scores = candle_nn::ops::softmax(
+ &attention_scores
+ .broadcast_add(&mask.squeeze(1)?)?
+ .to_dtype(DType::F32)?,
+ D::Minus1,
+ )?
+ .to_dtype(x.dtype())?;
+ let attn_output = attention_scores
+ .matmul(&value)?
+ .reshape((b_sz, self.num_heads, seq_len, head_dim))?
+ .transpose(1, 2)?
+ .reshape((b_sz, seq_len, self.num_heads * head_dim))?;
+ let attn_output = self.dense.forward(&attn_output)?;
+ Ok(attn_output)
+ }
+}
+
+#[derive(Debug)]
+struct FalconMlp {
+ dense_h_to_4h: Linear,
+ dense_4h_to_h: Linear,
+}
+
+impl FalconMlp {
+ fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let h = cfg.hidden_size;
+ let b = cfg.bias;
+ let dense_h_to_4h = linear(h, 4 * h, b, vb.pp("dense_h_to_4h"))?;
+ let dense_4h_to_h = linear(4 * h, h, b, vb.pp("dense_4h_to_h"))?;
+ Ok(Self {
+ dense_h_to_4h,
+ dense_4h_to_h,
+ })
+ }
+
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let x = self.dense_h_to_4h.forward(x)?.gelu()?;
+ let x = self.dense_4h_to_h.forward(&x)?;
+ Ok(x)
+ }
+}
+
+#[derive(Debug)]
+struct FalconDecoderLayer {
+ inp_layernorm: LayerNorm,
+ self_attention: FalconAttention,
+ post_attention_layernorm: Option<LayerNorm>,
+ mlp: FalconMlp,
+ parallel_attn: bool,
+}
+
+impl FalconDecoderLayer {
+ fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let mlp = FalconMlp::load(vb.pp("mlp"), cfg)?;
+ let inp_layernorm = layer_norm(
+ cfg.hidden_size,
+ cfg.layer_norm_epsilon,
+ vb.pp("input_layernorm"),
+ )?;
+ let self_attention = FalconAttention::load(vb.pp("self_attention"), cfg)?;
+ let post_attention_layernorm = if cfg.parallel_attn {
+ None
+ } else {
+ let ln = layer_norm(
+ cfg.hidden_size,
+ cfg.layer_norm_epsilon,
+ vb.pp("post_attention_layernorm"),
+ )?;
+ Some(ln)
+ };
+ Ok(Self {
+ inp_layernorm,
+ self_attention,
+ post_attention_layernorm,
+ mlp,
+ parallel_attn: cfg.parallel_attn,
+ })
+ }
+
+ fn forward(&mut self, x: &Tensor, mask: &Tensor, past_kv_len: usize) -> Result<Tensor> {
+ let residual = x.clone();
+ let ln_attn = self.inp_layernorm.forward(x)?;
+ let attn_output = self.self_attention.forward(&ln_attn, mask, past_kv_len)?;
+ let (residual, ln_mlp) = match &self.post_attention_layernorm {
+ None => (residual, ln_attn),
+ Some(pal) => {
+ // This should include some dropout.
+ let residual = (&attn_output + &residual)?;
+ let ln_mlp = pal.forward(&residual)?;
+ (residual, ln_mlp)
+ }
+ };
+ let mlp_output = self.mlp.forward(&ln_mlp)?;
+
+ let mlp_output = if self.parallel_attn {
+ (mlp_output + attn_output)?
+ } else {
+ mlp_output
+ };
+ let output = (mlp_output + residual)?;
+ Ok(output)
+ }
+}
+
+#[derive(Debug)]
+pub struct Falcon {
+ word_embeddings: Embedding,
+ blocks: Vec<FalconDecoderLayer>,
+ ln_f: LayerNorm,
+ lm_head: Linear,
+ config: Config,
+}
+
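+// Builds a (t, t) mask where 1 marks the future positions (j > i) to be masked out.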
+fn make_causal_mask(t: usize) -> Result<Tensor> {
+ let mask: Vec<_> = (0..t)
+ .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
+ .collect();
+ let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
+ Ok(mask)
+}
+
+fn prepare_attn_mask(b_sz: usize, seq_len: usize) -> Result<Tensor> {
+ // let mask = Tensor::ones((b_sz, seq_len), DType::U32, &Device::Cpu)?;
+ let mask = make_causal_mask(seq_len)?;
+ let mask = mask.broadcast_as((b_sz, 1, seq_len, seq_len))?;
+ Ok(mask)
+}
+
+impl Falcon {
+ pub fn config(&self) -> &Config {
+ &self.config
+ }
+
+ pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> {
+ let word_embeddings = embedding(
+ cfg.vocab_size,
+ cfg.hidden_size,
+ vb.pp("transformer.word_embeddings"),
+ )?;
+ let blocks = (0..cfg.num_hidden_layers)
+ .map(|i| FalconDecoderLayer::load(vb.pp(&format!("transformer.h.{i}")), &cfg))
+ .collect::<Result<Vec<_>>>()?;
+ let ln_f = layer_norm(
+ cfg.hidden_size,
+ cfg.layer_norm_epsilon,
+ vb.pp("transformer.ln_f"),
+ )?;
+ let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, vb.pp("lm_head"))?;
+ Ok(Self {
+ word_embeddings,
+ blocks,
+ ln_f,
+ lm_head,
+ config: cfg,
+ })
+ }
+
+ pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+ let (b_sz, seq_len) = input_ids.dims2()?;
+ let mut hidden_state = self.word_embeddings.forward(input_ids)?;
+ let past_kv_len = match &self.blocks[0].self_attention.kv_cache {
+ Some((k, _)) => k.dim(1)?,
+ None => 0,
+ };
+ let causal_mask = prepare_attn_mask(b_sz, seq_len)?.to_device(input_ids.device())?;
+ for block in self.blocks.iter_mut() {
+ hidden_state = block.forward(&hidden_state, &causal_mask, past_kv_len)?;
+ }
+ let hidden_state = self.ln_f.forward(&hidden_state)?;
+ let hidden_state = hidden_state.narrow(1, seq_len - 1, 1)?;
+ let logits = self.lm_head.forward(&hidden_state)?.squeeze(1)?;
+ Ok(logits)
+ }
+}
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
new file mode 100644
index 00000000..eed4df5e
--- /dev/null
+++ b/candle-transformers/src/models/llama.rs
@@ -0,0 +1,446 @@
+use candle::{DType, Device, IndexOp, Result, Tensor, D};
+use candle_nn::{Embedding, Module, VarBuilder};
+use serde::Deserialize;
+use std::collections::HashMap;
+use std::sync::{Arc, Mutex};
+
+pub const MAX_SEQ_LEN: usize = 4096;
+
+#[derive(Deserialize)]
+pub struct LlamaConfig {
+ pub hidden_size: usize,
+ pub intermediate_size: usize,
+ pub vocab_size: usize,
+ pub num_hidden_layers: usize,
+ pub num_attention_heads: usize,
+ pub num_key_value_heads: Option<usize>,
+ pub rms_norm_eps: f64,
+ #[serde(default = "default_rope")]
+ pub rope_theta: f32,
+}
+
+fn default_rope() -> f32 {
+ 10_000.0
+}
+
+impl LlamaConfig {
+ pub fn into_config(self, use_flash_attn: bool) -> Config {
+ Config {
+ hidden_size: self.hidden_size,
+ intermediate_size: self.intermediate_size,
+ vocab_size: self.vocab_size,
+ num_hidden_layers: self.num_hidden_layers,
+ num_attention_heads: self.num_attention_heads,
+ num_key_value_heads: self.num_key_value_heads.unwrap_or(self.num_attention_heads),
+ rms_norm_eps: self.rms_norm_eps,
+ rope_theta: self.rope_theta,
+ use_flash_attn,
+ }
+ }
+}
+
+pub struct Config {
+ pub hidden_size: usize,
+ pub intermediate_size: usize,
+ pub vocab_size: usize,
+ pub num_hidden_layers: usize,
+ pub num_attention_heads: usize,
+ pub num_key_value_heads: usize,
+ pub use_flash_attn: bool,
+ pub rms_norm_eps: f64,
+ pub rope_theta: f32,
+}
+
+impl Config {
+ pub fn config_7b_v1(use_flash_attn: bool) -> Self {
+ Self {
+ hidden_size: 4096,
+ intermediate_size: 11008,
+ vocab_size: 32000,
+ num_hidden_layers: 32,
+ num_attention_heads: 32,
+ num_key_value_heads: 32,
+ use_flash_attn,
+ rms_norm_eps: 1e-6,
+ rope_theta: 10_000.0,
+ }
+ }
+
+ pub fn config_7b_v2(use_flash_attn: bool) -> Self {
+ Self {
+ hidden_size: 4096,
+ intermediate_size: 11008,
+ vocab_size: 32000,
+ num_hidden_layers: 32,
+ num_attention_heads: 32,
+ num_key_value_heads: 32,
+ use_flash_attn,
+ rms_norm_eps: 1e-5,
+ rope_theta: 10_000.0,
+ }
+ }
+}
+
+// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
+// model.
+#[derive(Debug)]
+pub struct Linear {
+ inner: candle_nn::Linear,
+ span: tracing::Span,
+}
+
+impl Linear {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(x)
+ }
+}
+
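+// Per-model state shared across layers: cached causal masks, per-layer KV caches and the
+// precomputed RoPE cos/sin tables.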
+#[derive(Clone)]
+pub struct Cache {
+ masks: Arc<Mutex<HashMap<usize, Tensor>>>,
+ pub use_kv_cache: bool,
+ #[allow(clippy::type_complexity)]
+ kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
+ cos: Tensor,
+ sin: Tensor,
+ device: Device,
+}
+
+impl Cache {
+ pub fn new(use_kv_cache: bool, dtype: DType, config: &Config, device: &Device) -> Result<Self> {
+ // precompute freqs_cis
+ let n_elem = config.hidden_size / config.num_attention_heads;
+ let theta: Vec<_> = (0..n_elem)
+ .step_by(2)
+ .map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32))
+ .collect();
+ let theta = Tensor::new(theta.as_slice(), device)?;
+ let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
+ .to_dtype(DType::F32)?
+ .reshape((MAX_SEQ_LEN, 1))?
+ .matmul(&theta.reshape((1, theta.elem_count()))?)?;
+ // This is different from the paper, see:
+ // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
+ let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
+ let cos = idx_theta.cos()?.to_dtype(dtype)?;
+ let sin = idx_theta.sin()?.to_dtype(dtype)?;
+ Ok(Self {
+ masks: Arc::new(Mutex::new(HashMap::new())),
+ use_kv_cache,
+ kvs: Arc::new(Mutex::new(vec![None; config.num_hidden_layers])),
+ device: device.clone(),
+ cos,
+ sin,
+ })
+ }
+
+ fn mask(&self, t: usize) -> Result<Tensor> {
+ let mut masks = self.masks.lock().unwrap();
+ if let Some(mask) = masks.get(&t) {
+ Ok(mask.clone())
+ } else {
+ let mask: Vec<_> = (0..t)
+ .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
+ .collect();
+ let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
+ masks.insert(t, mask.clone());
+ Ok(mask)
+ }
+ }
+}
+
+fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
+ let span = tracing::span!(tracing::Level::TRACE, "linear");
+ let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
+ Ok(Linear { inner, span })
+}
+
+fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
+ let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
+ Ok(Embedding::new(embeddings, cfg.hidden_size))
+}
+
+struct RmsNorm {
+ inner: candle_nn::RmsNorm,
+ span: tracing::Span,
+}
+
+impl RmsNorm {
+ fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
+ let inner = candle_nn::rms_norm(size, eps, vb)?;
+ Ok(Self { inner, span })
+ }
+
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(x)
+ }
+}
+
+struct CausalSelfAttention {
+ q_proj: Linear,
+ k_proj: Linear,
+ v_proj: Linear,
+ o_proj: Linear,
+ num_attention_heads: usize,
+ num_key_value_heads: usize,
+ head_dim: usize,
+ cache: Cache,
+ use_flash_attn: bool,
+ span: tracing::Span,
+ span_rot: tracing::Span,
+}
+
+#[cfg(feature = "flash-attn")]
+fn flash_attn(
+ q: &Tensor,
+ k: &Tensor,
+ v: &Tensor,
+ softmax_scale: f32,
+ causal: bool,
+) -> Result<Tensor> {
+ candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
+}
+
+#[cfg(not(feature = "flash-attn"))]
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
+ unimplemented!("compile with '--features flash-attn'")
+}
+
+impl CausalSelfAttention {
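+    // Applies rotary position embeddings using the "half rotation" layout (rotating the first
+    // and second halves of head_dim), matching the HF transformers implementation rather than
+    // the interleaved GGML layout used in the quantized model.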
+ fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
+ let _enter = self.span_rot.enter();
+ let (b_sz, _, seq_len, hidden_size) = x.dims4()?;
+ let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
+ let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
+ let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
+ let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
+ let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
+ let x2 = x.narrow(D::Minus1, hidden_size / 2, hidden_size / 2)?;
+ let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
+ let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
+ Ok(rope)
+ }
+
+ fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let (b_sz, seq_len, hidden_size) = x.dims3()?;
+ let q = self.q_proj.forward(x)?;
+ let k = self.k_proj.forward(x)?;
+ let v = self.v_proj.forward(x)?;
+
+ let q = q
+ .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
+ .transpose(1, 2)?;
+ let k = k
+ .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
+ .transpose(1, 2)?;
+ let mut v = v
+ .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
+ .transpose(1, 2)?;
+
+ let q = self.apply_rotary_emb(&q, index_pos)?;
+ let mut k = self.apply_rotary_emb(&k, index_pos)?;
+
+ if self.cache.use_kv_cache {
+ let mut cache = self.cache.kvs.lock().unwrap();
+ if let Some((cache_k, cache_v)) = &cache[block_idx] {
+ k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
+ v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
+                // The sequence dimension is dim 2 after the transpose above, so trim along it.
+                let k_seq_len = k.dims()[2];
+                if k_seq_len > MAX_SEQ_LEN {
+                    k = k
+                        .narrow(D::Minus2, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
+                        .contiguous()?
+                }
+                let v_seq_len = v.dims()[2];
+                if v_seq_len > MAX_SEQ_LEN {
+                    v = v
+                        .narrow(D::Minus2, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
+                        .contiguous()?
+                }
+ }
+ cache[block_idx] = Some((k.clone(), v.clone()))
+ }
+
+ let k = self.repeat_kv(k)?;
+ let v = self.repeat_kv(v)?;
+
+ let y = if self.use_flash_attn {
+ // flash-attn expects (b_sz, seq_len, nheads, head_dim)
+ let q = q.transpose(1, 2)?;
+ let k = k.transpose(1, 2)?;
+ let v = v.transpose(1, 2)?;
+ let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
+ flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?.transpose(1, 2)?
+ } else {
+ let in_dtype = q.dtype();
+ let q = q.to_dtype(DType::F32)?;
+ let k = k.to_dtype(DType::F32)?;
+ let v = v.to_dtype(DType::F32)?;
+ let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
+ let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
+ let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
+ let att = candle_nn::ops::softmax(&att, D::Minus1)?;
+            // Convert to contiguous as matmul doesn't support strided value tensors for now.
+ att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?
+ };
+ let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
+ let y = self.o_proj.forward(&y)?;
+ Ok(y)
+ }
+
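+    // Repeats the key/value heads so that grouped-query attention can reuse the standard
+    // attention matmul; a no-op when the KV head count equals the attention head count.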
+ fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
+ let n_rep = self.num_attention_heads / self.num_key_value_heads;
+ if n_rep == 1 {
+ Ok(x)
+ } else {
+ let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
+ let x = x
+ .unsqueeze(2)?
+ .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
+ .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
+ Ok(x)
+ }
+ }
+
+ fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "attn");
+ let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
+ let size_in = cfg.hidden_size;
+ let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
+ let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
+ let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?;
+ let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?;
+ let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?;
+ let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?;
+ Ok(Self {
+ q_proj,
+ k_proj,
+ v_proj,
+ o_proj,
+ num_attention_heads: cfg.num_attention_heads,
+ num_key_value_heads: cfg.num_key_value_heads,
+ head_dim: cfg.hidden_size / cfg.num_attention_heads,
+ cache: cache.clone(),
+ use_flash_attn: cfg.use_flash_attn,
+ span,
+ span_rot,
+ })
+ }
+}
+
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+ let shape = mask.shape();
+ let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+ let m = mask.where_cond(&on_true, on_false)?;
+ Ok(m)
+}
+
+struct Mlp {
+ c_fc1: Linear,
+ c_fc2: Linear,
+ c_proj: Linear,
+ span: tracing::Span,
+}
+
+impl Mlp {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let x = (candle_nn::ops::silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
+ self.c_proj.forward(&x)
+ }
+
+ fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "mlp");
+ let h_size = cfg.hidden_size;
+ let i_size = cfg.intermediate_size;
+ let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
+ let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?;
+ let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?;
+ Ok(Self {
+ c_fc1,
+ c_fc2,
+ c_proj,
+ span,
+ })
+ }
+}
+
+struct Block {
+ rms_1: RmsNorm,
+ attn: CausalSelfAttention,
+ rms_2: RmsNorm,
+ mlp: Mlp,
+ span: tracing::Span,
+}
+
+impl Block {
+ fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let residual = x;
+ let x = self.rms_1.forward(x)?;
+ let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
+ let residual = &x;
+ let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
+ Ok(x)
+ }
+
+ fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "block");
+ let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
+ let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
+ let rms_1 = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
+ let rms_2 = RmsNorm::load(
+ cfg.hidden_size,
+ cfg.rms_norm_eps,
+ vb.pp("post_attention_layernorm"),
+ )?;
+ Ok(Self {
+ rms_1,
+ attn,
+ rms_2,
+ mlp,
+ span,
+ })
+ }
+}
+
+pub struct Llama {
+ wte: Embedding,
+ blocks: Vec<Block>,
+ ln_f: RmsNorm,
+ lm_head: Linear,
+}
+
+impl Llama {
+ pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
+ let (_b_sz, seq_len) = x.dims2()?;
+ let mut x = self.wte.forward(x)?;
+ for (block_idx, block) in self.blocks.iter().enumerate() {
+ x = block.forward(&x, index_pos, block_idx)?;
+ }
+ let x = self.ln_f.forward(&x)?;
+ let x = x.i((.., seq_len - 1, ..))?;
+ let logits = self.lm_head.forward(&x)?;
+ logits.to_dtype(DType::F32)
+ }
+
+ pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
+ let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
+ let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
+ let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
+ let blocks: Vec<_> = (0..cfg.num_hidden_layers)
+ .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap())
+ .collect();
+
+ Ok(Self {
+ wte,
+ blocks,
+ ln_f,
+ lm_head,
+ })
+ }
+}
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 8b137891..d783a2c6 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -1 +1,13 @@
-
+pub mod bert;
+pub mod bigcode;
+pub mod dinov2;
+pub mod efficientnet;
+pub mod falcon;
+pub mod llama;
+pub mod quantized_llama;
+pub mod quantized_t5;
+pub mod segment_anything;
+pub mod stable_diffusion;
+pub mod t5;
+pub mod whisper;
+pub mod wuerstchen;
diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs
new file mode 100644
index 00000000..2988b0fb
--- /dev/null
+++ b/candle-transformers/src/models/quantized_llama.rs
@@ -0,0 +1,371 @@
+use std::collections::HashMap;
+
+use candle::quantized::QTensor;
+use candle::quantized::{ggml_file, gguf_file};
+use candle::{DType, Device, IndexOp, Result, Tensor, D};
+use candle_nn::{Embedding, Module};
+
+pub const MAX_SEQ_LEN: usize = 4096;
+
+struct RmsNorm {
+ inner: candle_nn::LayerNorm,
+ span: tracing::Span,
+}
+
+impl RmsNorm {
+ fn new(scale: QTensor, eps: f32) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
+ let scale = scale.dequantize(&Device::Cpu)?;
+ let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64);
+ Ok(Self { inner, span })
+ }
+
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(x)
+ }
+}
+
+// QMatMul wrapper adding some tracing.
+struct QMatMul {
+ inner: candle::quantized::QMatMul,
+ span: tracing::Span,
+}
+
+impl QMatMul {
+ fn from_qtensor(qtensor: QTensor) -> Self {
+ let inner = candle::quantized::QMatMul::from_qtensor(qtensor);
+ let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
+ Self { inner, span }
+ }
+
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(xs)
+ }
+}
+
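+// Quantized weights for a single transformer layer together with the shared RoPE tables and
+// this layer's KV cache.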
+struct LayerWeights {
+ attention_wq: QMatMul,
+ attention_wk: QMatMul,
+ attention_wv: QMatMul,
+ attention_wo: QMatMul,
+ attention_norm: RmsNorm,
+ feed_forward_w1: QMatMul,
+ feed_forward_w2: QMatMul,
+ feed_forward_w3: QMatMul,
+ ffn_norm: RmsNorm,
+ n_head: usize,
+ n_kv_head: usize,
+ head_dim: usize,
+ cos: Tensor,
+ sin: Tensor,
+ kv_cache: Option<(Tensor, Tensor)>,
+ span_attn: tracing::Span,
+ span_rot: tracing::Span,
+ span_mlp: tracing::Span,
+}
+
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+ let shape = mask.shape();
+ let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+ let m = mask.where_cond(&on_true, on_false)?;
+ Ok(m)
+}
+
+impl LayerWeights {
+ fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
+ let _enter = self.span_rot.enter();
+ let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
+ let cos = self
+ .cos
+ .narrow(0, index_pos, seq_len)?
+ .reshape((seq_len, n_embd / 2, 1))?;
+ let sin = self
+ .sin
+ .narrow(0, index_pos, seq_len)?
+ .reshape((seq_len, n_embd / 2, 1))?;
+ let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
+ let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
+ // This mimics the llama.cpp behavior.
+ // https://github.com/ggerganov/llama.cpp/blob/1f0bccb27929e261744c979bc75114955da49e98/ggml.c#L12104-L12105
+        // The x0 and x1 values are interleaved along the n_embd (= head_dim) dimension.
+ // The resulting y0 and y1 are also interleaved with:
+ // y0 = x0*cos - x1*sin
+ // y1 = x0*sin + x1*cos
+ let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
+ let x0 = x.narrow(D::Minus1, 0, 1)?;
+ let x1 = x.narrow(D::Minus1, 1, 1)?;
+ let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
+ let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
+ let rope = Tensor::cat(&[y0, y1], D::Minus1)?;
+ let rope = rope.flatten_from(D::Minus2)?;
+ Ok(rope)
+ }
+
+ fn forward_attn(&mut self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> {
+ let _enter = self.span_attn.enter();
+ let (b_sz, seq_len, n_embd) = x.dims3()?;
+ let q = self.attention_wq.forward(x)?;
+ let k = self.attention_wk.forward(x)?;
+ let v = self.attention_wv.forward(x)?;
+
+ let q = q
+ .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
+ .transpose(1, 2)?;
+ let k = k
+ .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
+ .transpose(1, 2)?;
+ let v = v
+ .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
+ .transpose(1, 2)?;
+
+ let q = self.apply_rotary_emb(&q, index_pos)?;
+ let k = self.apply_rotary_emb(&k, index_pos)?;
+
+ let (k, v) = match &self.kv_cache {
+ None => (k, v),
+ Some((k_cache, v_cache)) => {
+ if index_pos == 0 {
+ (k, v)
+ } else {
+ let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?;
+ let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?;
+ (k, v)
+ }
+ }
+ };
+ self.kv_cache = Some((k.clone(), v.clone()));
+
+ // Support for MQA, useful for 70B models.
+ let k = self.repeat_kv(k)?;
+ let v = self.repeat_kv(v)?;
+
+ let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
+ let mask = mask.broadcast_as(att.shape())?;
+ let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
+ let att = candle_nn::ops::softmax_last_dim(&att)?;
+        // Convert to contiguous as matmul doesn't support strided value tensors for now.
+ let y = att.matmul(&v.contiguous()?)?;
+ let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
+ let y = self.attention_wo.forward(&y)?;
+ Ok(y)
+ }
+
+ fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
+ let n_rep = self.n_head / self.n_kv_head;
+ if n_rep == 1 {
+ Ok(x)
+ } else {
+ let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
+ let x = x
+ .unsqueeze(2)?
+ .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
+ .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
+ Ok(x)
+ }
+ }
+}
+
+pub struct ModelWeights {
+ tok_embeddings: Embedding,
+ layers: Vec<LayerWeights>,
+ norm: RmsNorm,
+ output: QMatMul,
+ masks: HashMap<usize, Tensor>,
+ span: tracing::Span,
+ span_output: tracing::Span,
+}
+
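+// Precomputes the RoPE cos/sin tables for every position up to MAX_SEQ_LEN.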
+fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tensor)> {
+ let theta: Vec<_> = (0..head_dim)
+ .step_by(2)
+ .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
+ .collect();
+ let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?;
+ let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)?
+ .to_dtype(DType::F32)?
+ .reshape((MAX_SEQ_LEN, 1))?
+ .matmul(&theta.reshape((1, theta.elem_count()))?)?;
+ let cos = idx_theta.cos()?;
+ let sin = idx_theta.sin()?;
+ Ok((cos, sin))
+}
+
+impl ModelWeights {
+ pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
+ let cpu = &Device::Cpu;
+ let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
+ let (cos, sin) = precomput_freqs_cis(head_dim, 10000.)?;
+ let tok_embeddings = ct.remove("tok_embeddings.weight")?;
+ let tok_embeddings = tok_embeddings.dequantize(cpu)?;
+ let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
+ let output = ct.remove("output.weight")?;
+ let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
+ for layer_idx in 0..ct.hparams.n_layer {
+ let prefix = format!("layers.{layer_idx}");
+ let attention_wq = ct.remove(&format!("{prefix}.attention.wq.weight"))?;
+ let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
+ let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
+ let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
+ let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
+ let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
+ let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
+ let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?;
+ let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?;
+ let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
+ let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
+ let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
+ layers.push(LayerWeights {
+ attention_wq: QMatMul::from_qtensor(attention_wq),
+ attention_wk: QMatMul::from_qtensor(attention_wk),
+ attention_wv: QMatMul::from_qtensor(attention_wv),
+ attention_wo: QMatMul::from_qtensor(attention_wo),
+ attention_norm: RmsNorm::new(attention_norm, 1e-5)?,
+ feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1),
+ feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2),
+ feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
+ ffn_norm: RmsNorm::new(ffn_norm, 1e-5)?,
+ n_head: ct.hparams.n_head as usize,
+ n_kv_head: ct.hparams.n_head as usize / gqa,
+ head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
+ cos: cos.clone(),
+ sin: sin.clone(),
+ kv_cache: None,
+ span_attn,
+ span_rot,
+ span_mlp,
+ })
+ }
+ let span = tracing::span!(tracing::Level::TRACE, "model");
+ let span_output = tracing::span!(tracing::Level::TRACE, "output");
+ Ok(Self {
+ tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize),
+ layers,
+ norm,
+ output: QMatMul::from_qtensor(output),
+ masks: HashMap::new(),
+ span,
+ span_output,
+ })
+ }
+
+ pub fn from_gguf<R: std::io::Seek + std::io::Read>(
+ ct: gguf_file::Content,
+ reader: &mut R,
+ ) -> Result<Self> {
+ let cpu = &Device::Cpu;
+ let md_get = |s: &str| match ct.metadata.get(s) {
+ None => candle::bail!("cannot find {s} in metadata"),
+ Some(v) => Ok(v),
+ };
+
+ // Parameter extraction from metadata.
+ let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
+ let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
+ let block_count = md_get("llama.block_count")?.to_u32()? as usize;
+ let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
+ let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
+        // Strangely, this value is generally 1e-6 in GGUF files whereas it used to default to 1e-5.
+ let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()?;
+
+ let rope_freq_base = md_get("llama.rope.freq_base")
+ .and_then(|m| m.to_f32())
+ .unwrap_or(10000f32);
+ let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?;
+
+ let tok_embeddings = ct.tensor(reader, "token_embd.weight")?;
+ let tok_embeddings = tok_embeddings.dequantize(cpu)?;
+ let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?, rms_norm_eps)?;
+ let output = ct.tensor(reader, "output.weight")?;
+ let mut layers = Vec::with_capacity(block_count);
+ for layer_idx in 0..block_count {
+ let prefix = format!("blk.{layer_idx}");
+ let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"))?;
+ let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?;
+ let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?;
+ let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?;
+ let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?;
+ let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?;
+ let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?;
+ let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?;
+ let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?;
+ let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
+ let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
+ let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
+ layers.push(LayerWeights {
+ attention_wq: QMatMul::from_qtensor(attention_wq),
+ attention_wk: QMatMul::from_qtensor(attention_wk),
+ attention_wv: QMatMul::from_qtensor(attention_wv),
+ attention_wo: QMatMul::from_qtensor(attention_wo),
+ attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?,
+ feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1),
+ feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2),
+ feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3),
+ ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?,
+ n_head: head_count,
+ n_kv_head: head_count_kv,
+ head_dim: embedding_length / head_count,
+ cos: cos.clone(),
+ sin: sin.clone(),
+ kv_cache: None,
+ span_attn,
+ span_rot,
+ span_mlp,
+ })
+ }
+ let span = tracing::span!(tracing::Level::TRACE, "model");
+ let span_output = tracing::span!(tracing::Level::TRACE, "output");
+ Ok(Self {
+ tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
+ layers,
+ norm,
+ output: QMatMul::from_qtensor(output),
+ masks: HashMap::new(),
+ span,
+ span_output,
+ })
+ }
+
+ fn mask(&mut self, t: usize) -> Result<Tensor> {
+ if let Some(mask) = self.masks.get(&t) {
+ Ok(mask.clone())
+ } else {
+ let mask: Vec<_> = (0..t)
+ .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
+ .collect();
+ let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
+ self.masks.insert(t, mask.clone());
+ Ok(mask)
+ }
+ }
+
+ pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
+ let (_b_sz, seq_len) = x.dims2()?;
+ let mask = self.mask(seq_len)?;
+ let _enter = self.span.enter();
+ let mut layer_in = self.tok_embeddings.forward(x)?;
+ for layer in self.layers.iter_mut() {
+ let x = layer_in;
+ let residual = &x;
+ let x = layer.attention_norm.forward(&x)?;
+ let attn = layer.forward_attn(&x, &mask, index_pos)?;
+ let x = (attn + residual)?;
+
+ // MLP
+ let _enter = layer.span_mlp.enter();
+ let residual = &x;
+ let x = layer.ffn_norm.forward(&x)?;
+ let w1 = layer.feed_forward_w1.forward(&x)?;
+ let w3 = layer.feed_forward_w3.forward(&x)?;
+ let mlp = layer
+ .feed_forward_w2
+ .forward(&(candle_nn::ops::silu(&w1)? * w3)?)?;
+ layer_in = (mlp + residual)?;
+ }
+ let x = self.norm.forward(&layer_in)?;
+ let x = x.i((.., seq_len - 1, ..))?;
+ let _enter = self.span_output.enter();
+ self.output.forward(&x)
+ }
+}
diff --git a/candle-transformers/src/models/quantized_t5.rs b/candle-transformers/src/models/quantized_t5.rs
new file mode 100644
index 00000000..a10c3b80
--- /dev/null
+++ b/candle-transformers/src/models/quantized_t5.rs
@@ -0,0 +1,884 @@
+// T5 Text Model, quantized version
+// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
+
+use candle::quantized::QTensor;
+use candle::{DType, Device, Module, Result, Shape, Tensor, D};
+use candle_nn::Activation;
+use serde::Deserialize;
+use std::sync::Arc;
+
+// VarBuilder specialized for QTensors
+pub struct VarBuilder {
+ data: Arc<std::collections::HashMap<String, Arc<QTensor>>>,
+ path: Vec<String>,
+ device: Device,
+}
+
+impl VarBuilder {
+ pub fn from_gguf<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
+ let mut file = std::fs::File::open(p)?;
+ let content = candle::quantized::gguf_file::Content::read(&mut file)?;
+ let mut data = std::collections::HashMap::new();
+ for tensor_name in content.tensor_infos.keys() {
+ let tensor = content.tensor(&mut file, tensor_name)?;
+ data.insert(tensor_name.to_string(), Arc::new(tensor));
+ }
+ Ok(Self {
+ data: Arc::new(data),
+ path: Vec::new(),
+ device: Device::Cpu,
+ })
+ }
+
+ fn pp<S: ToString>(&self, s: S) -> Self {
+ let mut path = self.path.clone();
+ path.push(s.to_string());
+ Self {
+ data: self.data.clone(),
+ path,
+ device: self.device.clone(),
+ }
+ }
+
+ fn path(&self, tensor_name: &str) -> String {
+ if self.path.is_empty() {
+ tensor_name.to_string()
+ } else {
+ [&self.path.join("."), tensor_name].join(".")
+ }
+ }
+
+ fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Arc<QTensor>> {
+ let path = self.path(name);
+ match self.data.get(&path) {
+ None => {
+ candle::bail!("cannot find tensor {name}")
+ }
+ Some(qtensor) => {
+ let shape = s.into();
+ if qtensor.shape() != &shape {
+ candle::bail!(
+ "shape mismatch for {name}, got {:?}, expected {shape:?}",
+ qtensor.shape()
+ )
+ }
+ Ok(qtensor.clone())
+ }
+ }
+ }
+}
+
+#[derive(Debug)]
+struct Embedding {
+ inner: candle_nn::Embedding,
+ span: tracing::Span,
+}
+
+impl Embedding {
+ fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
+ let embeddings = vb.get((d1, d2), "weight")?.dequantize(&vb.device)?;
+ let inner = candle_nn::Embedding::new(embeddings, d2);
+ let span = tracing::span!(tracing::Level::TRACE, "embedding");
+ Ok(Self { inner, span })
+ }
+
+ fn embeddings(&self) -> &Tensor {
+ self.inner.embeddings()
+ }
+}
+
+impl Module for Embedding {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(xs)
+ }
+}
+
+// QMatMul wrapper adding some tracing.
+struct QMatMul {
+ inner: candle::quantized::QMatMul,
+ span: tracing::Span,
+}
+
+impl QMatMul {
+ fn new(out_dim: usize, in_dim: usize, vb: VarBuilder) -> Result<Self> {
+ let ws = vb.get((in_dim, out_dim), "weight")?;
+ let inner = candle::quantized::QMatMul::from_arc(ws);
+ let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
+ Ok(Self { inner, span })
+ }
+}
+
+impl Module for QMatMul {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(xs)
+ }
+}
+
+impl std::fmt::Debug for QMatMul {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "QMatMul")
+ }
+}
+
+fn default_relative_attention_max_distance() -> usize {
+ 128
+}
+
+fn default_is_decoder() -> bool {
+ false
+}
+
+fn default_use_cache() -> bool {
+ true
+}
+
+fn default_tie_word_embeddings() -> bool {
+ true
+}
+
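+// Causal mask for decoder self-attention: 1 marks future positions (j > i).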
+fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
+ let mask: Vec<_> = (0..size)
+ .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
+ .collect();
+ Tensor::from_slice(&mask, (size, size), device)
+}
+
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+ let shape = mask.shape();
+ let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+ let m = mask.where_cond(&on_true, on_false)?;
+ Ok(m)
+}
+
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+pub struct Config {
+ vocab_size: usize,
+ d_model: usize,
+ d_kv: usize,
+ d_ff: usize,
+ num_layers: usize,
+ num_decoder_layers: Option<usize>,
+ num_heads: usize,
+ relative_attention_num_buckets: usize,
+ #[serde(default = "default_relative_attention_max_distance")]
+ relative_attention_max_distance: usize,
+ dropout_rate: f64,
+ layer_norm_epsilon: f64,
+ initializer_factor: f64,
+ #[serde(default)]
+ feed_forward_proj: Activation,
+ #[serde(default = "default_tie_word_embeddings")]
+ tie_word_embeddings: bool,
+ #[serde(default = "default_is_decoder")]
+ is_decoder: bool,
+ is_encoder_decoder: bool,
+ #[serde(default = "default_use_cache")]
+ pub use_cache: bool,
+ pub pad_token_id: usize,
+ pub eos_token_id: usize,
+}
+
+impl Default for Config {
+ fn default() -> Self {
+ Self {
+ vocab_size: 32128,
+ d_model: 512,
+ d_kv: 64,
+ d_ff: 2048,
+ num_layers: 6,
+ num_decoder_layers: None,
+ num_heads: 8,
+ relative_attention_num_buckets: 32,
+ relative_attention_max_distance: 128,
+ dropout_rate: 0.1,
+ layer_norm_epsilon: 1e-6,
+ initializer_factor: 1.0,
+ feed_forward_proj: Activation::Relu,
+ tie_word_embeddings: true,
+ is_decoder: false,
+ is_encoder_decoder: true,
+ use_cache: true,
+ pad_token_id: 0,
+ eos_token_id: 1,
+ }
+ }
+}
+
+#[derive(Debug)]
+struct T5LayerNorm {
+ weight: Tensor,
+ variance_epsilon: f64,
+ span: tracing::Span,
+}
+
+impl T5LayerNorm {
+ fn load(h: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+ let weight = vb.get(h, "weight")?.dequantize(&vb.device)?;
+ Ok(Self {
+ weight,
+ variance_epsilon: eps,
+ span: tracing::span!(tracing::Level::TRACE, "layer-norm"),
+ })
+ }
+}
+
+impl Module for T5LayerNorm {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let dtype = xs.dtype();
+ let xs_f32 = xs.to_dtype(DType::F32)?;
+ // variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+ let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
+ let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
+ let xs = xs.to_dtype(dtype)?;
+ let xs = xs.broadcast_mul(&self.weight)?;
+ Ok(xs)
+ }
+}
+
+#[derive(Debug)]
+struct T5DenseActDense {
+ wi: QMatMul,
+ wo: QMatMul,
+ act: Activation,
+ span: tracing::Span,
+}
+
+impl T5DenseActDense {
+ fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let wi = QMatMul::new(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
+ let wo = QMatMul::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
+ Ok(Self {
+ wi,
+ wo,
+ act: Activation::Relu,
+ span: tracing::span!(tracing::Level::TRACE, "dense-act-dense"),
+ })
+ }
+}
+
+impl Module for T5DenseActDense {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let xs = self.wi.forward(xs)?;
+ let xs = self.act.forward(&xs)?;
+ let xs = self.wo.forward(&xs)?;
+ Ok(xs)
+ }
+}
+
+#[derive(Debug)]
+struct T5DenseGatedActDense {
+ wi_0: QMatMul,
+ wi_1: QMatMul,
+ wo: QMatMul,
+ act: Activation,
+ span: tracing::Span,
+}
+
+impl T5DenseGatedActDense {
+ fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let wi_0 = QMatMul::new(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?;
+ let wi_1 = QMatMul::new(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?;
+ let wo = QMatMul::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
+ Ok(Self {
+ wi_0,
+ wi_1,
+ wo,
+ act: Activation::NewGelu,
+ span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"),
+ })
+ }
+}
+
+impl Module for T5DenseGatedActDense {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let hidden_gelu = self.act.forward(&self.wi_0.forward(xs)?)?;
+ let hidden_linear = self.wi_1.forward(xs)?;
+ let xs = hidden_gelu.broadcast_mul(&hidden_linear)?;
+ let xs = self.wo.forward(&xs)?;
+ Ok(xs)
+ }
+}
+
+#[derive(Debug)]
+struct T5LayerFF {
+ dense_act: Option<T5DenseActDense>,
+ gated_dense_act: Option<T5DenseGatedActDense>,
+ layer_norm: T5LayerNorm,
+ span: tracing::Span,
+}
+
+impl T5LayerFF {
+ fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let layer_norm =
+ T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
+ let (dense_act, gated_dense_act) = if cfg.feed_forward_proj == Activation::NewGelu {
+ (
+ None,
+ Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),
+ )
+ } else {
+ (
+ Some(T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?),
+ None,
+ )
+ };
+ Ok(Self {
+ dense_act,
+ gated_dense_act,
+ layer_norm,
+ span: tracing::span!(tracing::Level::TRACE, "layer-ff"),
+ })
+ }
+}
+
+impl Module for T5LayerFF {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let ys = self.layer_norm.forward(xs)?;
+ let ys = match &self.dense_act {
+ Some(dense_act) => dense_act.forward(&ys)?,
+ None => self.gated_dense_act.as_ref().unwrap().forward(&ys)?,
+ };
+ let xs = (xs + ys)?;
+ Ok(xs)
+ }
+}
+
+#[derive(Debug)]
+struct T5Attention {
+ q: QMatMul,
+ k: QMatMul,
+ v: QMatMul,
+ o: QMatMul,
+ n_heads: usize,
+ d_kv: usize,
+ relative_attention_bias: Option<Embedding>,
+ relative_attention_num_buckets: usize,
+ relative_attention_max_distance: usize,
+ inner_dim: usize,
+ use_cache: bool,
+ kv_cache: Option<(Tensor, Tensor)>,
+ span: tracing::Span,
+ span_cache: tracing::Span,
+ span_mm: tracing::Span,
+ span_sm: tracing::Span,
+}
+
+impl T5Attention {
+ fn load(
+ has_relative_attention_bias: bool,
+ decoder: bool,
+ vb: VarBuilder,
+ cfg: &Config,
+ ) -> Result<Self> {
+ let inner_dim = cfg.num_heads * cfg.d_kv;
+ let q = QMatMul::new(cfg.d_model, inner_dim, vb.pp("q"))?;
+ let k = QMatMul::new(cfg.d_model, inner_dim, vb.pp("k"))?;
+ let v = QMatMul::new(cfg.d_model, inner_dim, vb.pp("v"))?;
+ let o = QMatMul::new(inner_dim, cfg.d_model, vb.pp("o"))?;
+ let relative_attention_bias = if has_relative_attention_bias {
+ let emb = Embedding::new(
+ cfg.relative_attention_num_buckets,
+ cfg.num_heads,
+ vb.pp("relative_attention_bias"),
+ )?;
+ Some(emb)
+ } else {
+ None
+ };
+ Ok(Self {
+ q,
+ k,
+ v,
+ o,
+ n_heads: cfg.num_heads,
+ d_kv: cfg.d_kv,
+ relative_attention_bias,
+ relative_attention_num_buckets: cfg.relative_attention_num_buckets,
+ relative_attention_max_distance: cfg.relative_attention_max_distance,
+ inner_dim,
+ use_cache: cfg.use_cache && decoder,
+ kv_cache: None,
+ span: tracing::span!(tracing::Level::TRACE, "attention"),
+ span_cache: tracing::span!(tracing::Level::TRACE, "attention-cache"),
+ span_mm: tracing::span!(tracing::Level::TRACE, "attention-mm"),
+ span_sm: tracing::span!(tracing::Level::TRACE, "attention-sm"),
+ })
+ }
+
+ fn forward(
+ &mut self,
+ xs: &Tensor,
+ position_bias: Option<&Tensor>,
+ key_value_states: Option<&Tensor>,
+ mask: Option<&Tensor>,
+ ) -> Result<(Tensor, Option<Tensor>)> {
+        // Performs self-attention (if key_value_states is None) or cross-attention over the
+        // source sentence (provided by key_value_states).
+ let _enter = self.span.enter();
+ let kv_input = match key_value_states {
+ None => xs,
+ Some(key_value_states) => key_value_states,
+ };
+ let (b_sz, q_len) = (xs.dim(0)?, xs.dim(1)?);
+ let kv_len = kv_input.dim(1)?;
+ let q = self.q.forward(xs)?;
+ let k = self.k.forward(kv_input)?;
+ let v = self.v.forward(kv_input)?;
+ let q = q
+ .reshape((b_sz, q_len, self.n_heads, self.d_kv))?
+ .transpose(1, 2)?
+ .contiguous()?;
+ let mut k = k
+ .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
+ .transpose(1, 2)?
+ .contiguous()?;
+ let mut v = v
+ .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
+ .transpose(1, 2)?
+ .contiguous()?;
+
+ if self.use_cache {
+ let _enter = self.span_cache.enter();
+ if let Some((kv_cache_k, kv_cache_v)) = &self.kv_cache {
+ k = Tensor::cat(&[kv_cache_k, &k], 2)?.contiguous()?;
+ v = Tensor::cat(&[kv_cache_v, &v], 2)?.contiguous()?;
+ };
+ self.kv_cache = Some((k.clone(), v.clone()));
+ };
+ // TODO: Use flash_attn.
+ let scores = {
+ let _enter = self.span_mm.enter();
+ q.matmul(&k.t()?)?
+ };
+ let scores = match mask {
+ None => scores,
+ Some(mask) => masked_fill(
+ &scores,
+ &mask
+ .unsqueeze(0)?
+ .unsqueeze(0)?
+ .repeat((b_sz, self.n_heads))?,
+ f32::NEG_INFINITY,
+ )?,
+ };
+
+ let (scores, position_bias) = match position_bias {
+ Some(position_bias) => (
+ scores.broadcast_add(position_bias)?,
+ Some(position_bias.clone()),
+ ),
+ None => match &self.relative_attention_bias {
+ None => (scores, None),
+ Some(relative_attention_bias) => {
+ // This only handles the bidirectional case.
+ let kv_len = k.dim(2)?;
+ let (q_start, q_end) = match self.use_cache {
+ true => ((kv_len - q_len) as u32, kv_len as u32),
+ false => (0_u32, kv_len as u32),
+ };
+ let num_buckets = self.relative_attention_num_buckets as u32 / 2;
+ let max_exact = num_buckets / 2;
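+                        // T5-style relative position bucketing: small offsets each get their
+                        // own bucket while larger offsets share logarithmically spaced buckets,
+                        // with everything beyond relative_attention_max_distance collapsed into
+                        // the last bucket.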
+ let relative_position = (q_start..q_end)
+ .map(|i| {
+ (0..kv_len as u32)
+ .map(|j| {
+ if i < j {
+ if j - i < max_exact {
+ j - i + num_buckets
+ } else {
+ let b = f32::log(
+ (j - i) as f32 / max_exact as f32,
+ self.relative_attention_max_distance as f32
+ / max_exact as f32,
+ ) * (num_buckets - max_exact) as f32;
+ u32::min(
+ max_exact + num_buckets + b as u32,
+ self.relative_attention_num_buckets as u32 - 1,
+ )
+ }
+ } else if i - j < max_exact {
+ i - j
+ } else {
+ let b = f32::log(
+ (i - j) as f32 / max_exact as f32,
+ self.relative_attention_max_distance as f32
+ / max_exact as f32,
+ ) * (num_buckets - max_exact) as f32;
+ max_exact + b as u32
+ }
+ })
+ .collect::<Vec<u32>>()
+ })
+ .collect::<Vec<Vec<_>>>();
+ let relative_buckets = Tensor::new(relative_position, q.device())?;
+ let position_bias = relative_attention_bias
+ .forward(&relative_buckets)?
+ .permute((2, 0, 1))?
+ .unsqueeze(0)?;
+ (scores.broadcast_add(&position_bias)?, Some(position_bias))
+ // TODO: position_bias_masked?
+ }
+ },
+ };
+
+ let attn_weights = {
+ let _enter = self.span_sm.enter();
+ candle_nn::ops::softmax(&scores, D::Minus1)?
+ };
+ let attn_output = attn_weights.matmul(&v)?;
+ let attn_output = attn_output
+ .transpose(1, 2)?
+ .reshape((b_sz, q_len, self.inner_dim))?;
+ let attn_output = self.o.forward(&attn_output)?;
+ Ok((attn_output, position_bias))
+ }
+
+ fn clear_kv_cache(&mut self) {
+ self.kv_cache = None
+ }
+}
+
+#[derive(Debug)]
+struct T5LayerSelfAttention {
+ self_attention: T5Attention,
+ layer_norm: T5LayerNorm,
+ span: tracing::Span,
+}
+
+impl T5LayerSelfAttention {
+ fn load(h: bool, d: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let self_attention = T5Attention::load(h, d, vb.pp("SelfAttention"), cfg)?;
+ let layer_norm =
+ T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
+ Ok(Self {
+ self_attention,
+ layer_norm,
+ span: tracing::span!(tracing::Level::TRACE, "self-attn"),
+ })
+ }
+
+ fn forward(
+ &mut self,
+ xs: &Tensor,
+ position_bias: Option<&Tensor>,
+ mask: Option<&Tensor>,
+ ) -> Result<(Tensor, Option<Tensor>)> {
+ let _enter = self.span.enter();
+ let normed_xs = self.layer_norm.forward(xs)?;
+ let (ys, position_bias) =
+ self.self_attention
+ .forward(&normed_xs, position_bias, None, mask)?;
+ let ys = (xs + ys)?;
+ Ok((ys, position_bias))
+ }
+
+ fn clear_kv_cache(&mut self) {
+ self.self_attention.clear_kv_cache()
+ }
+}
+
+#[derive(Debug)]
+struct T5LayerCrossAttention {
+ cross_attention: T5Attention,
+ layer_norm: T5LayerNorm,
+ span: tracing::Span,
+}
+
+impl T5LayerCrossAttention {
+ fn load(decoder: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let cross_attention = T5Attention::load(false, decoder, vb.pp("EncDecAttention"), cfg)?;
+ let layer_norm =
+ T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
+ Ok(Self {
+ cross_attention,
+ layer_norm,
+ span: tracing::span!(tracing::Level::TRACE, "cross-attn"),
+ })
+ }
+
+ fn forward(
+ &mut self,
+ hidden_states: &Tensor,
+ position_bias: Option<&Tensor>,
+ key_value_states: &Tensor,
+ ) -> Result<(Tensor, Option<Tensor>)> {
+ let _enter = self.span.enter();
+ let normed_hidden_states = self.layer_norm.forward(hidden_states)?;
+ let (ys, position_bias) = self.cross_attention.forward(
+ &normed_hidden_states,
+ position_bias,
+ Some(key_value_states),
+ None,
+ )?;
+ let ys = (hidden_states + ys)?;
+ Ok((ys, position_bias))
+ }
+
+ fn clear_kv_cache(&mut self) {
+ self.cross_attention.clear_kv_cache()
+ }
+}
+
+#[derive(Debug)]
+struct T5Block {
+ self_attn: T5LayerSelfAttention,
+ cross_attn: Option<T5LayerCrossAttention>,
+ ff: T5LayerFF,
+ span: tracing::Span,
+}
+
+impl T5Block {
+ fn load(
+ has_relative_attention_bias: bool,
+ decoder: bool,
+ vb: VarBuilder,
+ cfg: &Config,
+ ) -> Result<Self> {
+ let vb = vb.pp("layer");
+ let self_attn =
+ T5LayerSelfAttention::load(has_relative_attention_bias, decoder, vb.pp("0"), cfg)?;
+ let cross_attn = if cfg.is_decoder {
+ Some(T5LayerCrossAttention::load(decoder, vb.pp("1"), cfg)?)
+ } else {
+ None
+ };
+ let ff_i = if cross_attn.is_some() { 2 } else { 1 };
+ let ff = T5LayerFF::load(vb.pp(ff_i), cfg)?;
+ Ok(Self {
+ self_attn,
+ cross_attn,
+ ff,
+ span: tracing::span!(tracing::Level::TRACE, "block"),
+ })
+ }
+
+ fn forward(
+ &mut self,
+ xs: &Tensor,
+ position_bias: Option<&Tensor>,
+ encoder_hidden_states: Option<&Tensor>,
+ ) -> Result<(Tensor, Option<Tensor>)> {
+ let _enter = self.span.enter();
+ // TODO: Cache masks
+ let mask = match self.cross_attn.is_some() {
+ true => {
+ let mask_len = xs.dim(1)?;
+                // If the input seq length is 1, no mask is needed; this also helps avoid shape
+                // issues when using the KV cache in the decoder.
+ if mask_len <= 1 {
+ None
+ } else {
+ Some(get_mask(mask_len, xs.device())?)
+ }
+ }
+ false => None,
+ };
+ let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias, mask.as_ref())?;
+ // TODO: clamp for f16?
+ if let Some(cross_attn) = &mut self.cross_attn {
+ (xs, _) = cross_attn.forward(&xs, None, encoder_hidden_states.unwrap())?;
+ // TODO: clamp for f16?
+ }
+ let xs = self.ff.forward(&xs)?;
+ // TODO: clamp for f16?
+ Ok((xs, position_bias))
+ }
+
+ fn clear_kv_cache(&mut self) {
+ self.self_attn.clear_kv_cache();
+ self.cross_attn.iter_mut().for_each(|c| c.clear_kv_cache());
+ }
+}
+
+#[derive(Debug)]
+struct T5Stack {
+ block: Vec<T5Block>,
+ shared: Arc<Embedding>,
+ final_layer_norm: T5LayerNorm,
+ span: tracing::Span,
+}
+
+impl T5Stack {
+ fn load(decoder: bool, vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> {
+ let block = (0..cfg.num_layers)
+ .map(|i| T5Block::load(i == 0, decoder, vb.pp(format!("block.{i}")), cfg))
+ .collect::<Result<Vec<_>>>()?;
+ let final_layer_norm = T5LayerNorm::load(
+ cfg.d_model,
+ cfg.layer_norm_epsilon,
+ vb.pp("final_layer_norm"),
+ )?;
+ Ok(Self {
+ block,
+ shared: shared.clone(),
+ final_layer_norm,
+ span: tracing::span!(tracing::Level::TRACE, "stack"),
+ })
+ }
+
+ fn forward(
+ &mut self,
+ input_ids: &Tensor,
+ encoder_hidden_states: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let input_embeds = self.shared.as_ref().forward(input_ids)?;
+ let mut hidden_states = input_embeds;
+ let mut position_bias = None;
+ for block in self.block.iter_mut() {
+ (hidden_states, position_bias) = block.forward(
+ &hidden_states,
+ position_bias.as_ref(),
+ encoder_hidden_states,
+ )?
+ }
+ self.final_layer_norm.forward(&hidden_states)
+ }
+
+ fn clear_kv_cache(&mut self) {
+ self.block.iter_mut().for_each(|b| b.clear_kv_cache())
+ }
+}
+
+#[derive(Debug)]
+pub struct T5EncoderModel {
+ encoder: T5Stack,
+ device: Device,
+ span: tracing::Span,
+}
+
+impl T5EncoderModel {
+ pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
+ let shared = Arc::new(shared);
+ let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
+ Ok(Self {
+ encoder,
+ device: vb.device.clone(),
+ span: tracing::span!(tracing::Level::TRACE, "encoder"),
+ })
+ }
+
+ pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.encoder.forward(input_ids, None)
+ }
+
+ pub fn device(&self) -> &Device {
+ &self.device
+ }
+
+ pub fn clear_kv_cache(&mut self) {
+ self.encoder.clear_kv_cache()
+ }
+}
+
+#[derive(Debug)]
+pub struct T5ForConditionalGeneration {
+ encoder: T5Stack,
+ decoder: T5Stack,
+ d_model: usize,
+ tie_word_embeddings: bool,
+ lm_head: Option<QMatMul>,
+ shared: Arc<Embedding>,
+ device: Device,
+ span_decode: tracing::Span,
+ span_decode_head: tracing::Span,
+}
+
+impl T5ForConditionalGeneration {
+ pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ assert!(cfg.is_encoder_decoder);
+ let d_model = cfg.d_model;
+ let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
+ let shared = Arc::new(shared);
+
+ let mut encoder_cfg = cfg.clone();
+ encoder_cfg.is_decoder = false;
+ encoder_cfg.use_cache = false;
+ encoder_cfg.is_encoder_decoder = false;
+ let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, &encoder_cfg)?;
+
+ let mut decoder_cfg = cfg.clone();
+ decoder_cfg.is_decoder = true;
+ decoder_cfg.is_encoder_decoder = false;
+ decoder_cfg.num_layers = cfg.num_decoder_layers.unwrap_or(cfg.num_layers);
+ let decoder = T5Stack::load(true, vb.pp("decoder"), &shared, &decoder_cfg)?;
+
+ let tie_word_embeddings = cfg.tie_word_embeddings;
+ let lm_head = if tie_word_embeddings {
+ None
+ } else {
+ Some(QMatMul::new(cfg.d_model, cfg.vocab_size, vb.pp("lm_head"))?)
+ };
+
+ Ok(Self {
+ encoder,
+ decoder,
+ d_model,
+ tie_word_embeddings,
+ lm_head,
+ shared,
+ device: vb.device.clone(),
+ span_decode: tracing::span!(tracing::Level::TRACE, "decode"),
+ span_decode_head: tracing::span!(tracing::Level::TRACE, "decode-head"),
+ })
+ }
+
+ pub fn encode(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+ self.encoder.forward(input_ids, None)
+ }
+
+ pub fn decode(
+ &mut self,
+ decoder_input_ids: &Tensor,
+ encoder_output: &Tensor,
+ ) -> Result<Tensor> {
+ let _enter = self.span_decode.enter();
+ let decoder_output = self
+ .decoder
+ .forward(decoder_input_ids, Some(encoder_output))?;
+
+ let scaling_factor = if self.tie_word_embeddings {
+ // Rescale output before projecting on vocab
+ // See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+ (self.d_model as f64).sqrt()
+ } else {
+ 1.0
+ };
+ let sequence_output = ((decoder_output
+ .narrow(1, decoder_output.dim(1)? - 1, 1)?
+ .squeeze(1)?)
+ * scaling_factor)?;
+ let output = {
+ let _enter = self.span_decode_head.enter();
+ match self.lm_head {
+ None => sequence_output.matmul(&self.shared.embeddings().t()?)?,
+ Some(ref lm_head) => lm_head.forward(&sequence_output)?,
+ }
+ };
+
+ // TODO: Rescale output before projecting on vocab? * (self.model_dim**-0.5)
+ Ok(output)
+ }
+
+ pub fn forward(&mut self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> {
+ let encoder_output = self.encode(input_ids)?;
+ self.decode(decoder_input_ids, &encoder_output)
+ }
+
+ pub fn device(&self) -> &Device {
+ &self.device
+ }
+
+ pub fn clear_kv_cache(&mut self) {
+ self.encoder.clear_kv_cache();
+ self.decoder.clear_kv_cache();
+ }
+}
diff --git a/candle-transformers/src/models/segment_anything/image_encoder.rs b/candle-transformers/src/models/segment_anything/image_encoder.rs
new file mode 100644
index 00000000..0b313830
--- /dev/null
+++ b/candle-transformers/src/models/segment_anything/image_encoder.rs
@@ -0,0 +1,483 @@
+use candle::{DType, IndexOp, Result, Tensor};
+use candle_nn::{layer_norm, LayerNorm, Module, VarBuilder};
+
+#[derive(Debug)]
+struct PatchEmbed {
+ proj: candle_nn::Conv2d,
+ span: tracing::Span,
+}
+
+impl PatchEmbed {
+ fn new(
+ in_chans: usize,
+ embed_dim: usize,
+ k_size: usize,
+ stride: usize,
+ padding: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let cfg = candle_nn::Conv2dConfig {
+ stride,
+ padding,
+ ..Default::default()
+ };
+ let proj = candle_nn::conv2d(in_chans, embed_dim, k_size, cfg, vb.pp("proj"))?;
+ let span = tracing::span!(tracing::Level::TRACE, "patch-embed");
+ Ok(Self { proj, span })
+ }
+}
+
+impl Module for PatchEmbed {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ xs.apply(&self.proj)?.permute((0, 2, 3, 1))
+ }
+}
+
+// A custom op to make add_decomposed_rel_pos faster. Most of the time is spent on the final
+// addition in the case where b = 12, q_h = q_w = 4096, k_h = k_w = 4096
+// (attn.reshape((b, q_h, q_w, k_h, k_w))?
+// + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
+// .reshape((b, q_h * q_w, k_h * k_w))
+// Ideally we would perform this operation in place but this is not supported in candle at the
+// moment. We should also investigate using f16 rather than f32.
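+// Concretely, for each of the b * q_h * q_w rows the CPU kernel below computes
+// dst[row][h][w] = attn[row][h][w] + rel_h[row][h] + rel_w[row][w]
+// over the (k_h, k_w) grid, which is exactly the broadcasted sum above.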
+struct Add3(usize, usize, usize, usize, usize);
+impl candle::CustomOp3 for Add3 {
+ fn name(&self) -> &'static str {
+ "add3"
+ }
+
+ fn cpu_fwd(
+ &self,
+ s1: &candle::CpuStorage,
+ l1: &candle::Layout,
+ s2: &candle::CpuStorage,
+ l2: &candle::Layout,
+ s3: &candle::CpuStorage,
+ l3: &candle::Layout,
+ ) -> Result<(candle::CpuStorage, candle::Shape)> {
+ use rayon::prelude::*;
+
+ let Add3(b, q_h, q_w, k_h, k_w) = *self;
+ let s1 = s1.as_slice::<f32>()?;
+ let s1 = match l1.contiguous_offsets() {
+ None => candle::bail!("input1 has to be contiguous"),
+ Some((o1, o2)) => &s1[o1..o2],
+ };
+ let s2 = s2.as_slice::<f32>()?;
+ let s2 = match l2.contiguous_offsets() {
+ None => candle::bail!("input2 has to be contiguous"),
+ Some((o1, o2)) => &s2[o1..o2],
+ };
+ let s3 = s3.as_slice::<f32>()?;
+ let s3 = match l3.contiguous_offsets() {
+ None => candle::bail!("input3 has to be contiguous"),
+ Some((o1, o2)) => &s3[o1..o2],
+ };
+ let mut dst = vec![0f32; b * q_h * q_w * k_h * k_w];
+ dst.par_chunks_exact_mut(k_h * k_w)
+ .enumerate()
+ .for_each(|(b_idx, dst)| {
+ let s1_idx = b_idx * k_h * k_w;
+ let s2_idx = b_idx * k_h;
+ let s3_idx = b_idx * k_w;
+ for h_idx in 0..k_h {
+ let s1_idx = s1_idx + h_idx * k_w;
+ let s2_idx = s2_idx + h_idx;
+ let dst_idx = h_idx * k_w;
+ for w_idx in 0..k_w {
+ let s1_idx = s1_idx + w_idx;
+ let s3_idx = s3_idx + w_idx;
+ let dst_idx = dst_idx + w_idx;
+ dst[dst_idx] = s1[s1_idx] + s2[s2_idx] + s3[s3_idx]
+ }
+ }
+ });
+ let dst = candle::WithDType::to_cpu_storage_owned(dst);
+ Ok((dst, (b, q_h * q_w, k_h * k_w).into()))
+ }
+}
+
+#[derive(Debug)]
+struct Attention {
+ qkv: super::Linear,
+ proj: super::Linear,
+ num_heads: usize,
+ scale: f64,
+ rel_pos_hw: Option<(Tensor, Tensor)>,
+ span: tracing::Span,
+ span_matmul: tracing::Span,
+ span_rel_pos: tracing::Span,
+ span_softmax: tracing::Span,
+}
+
+impl Attention {
+ fn new(
+ dim: usize,
+ num_heads: usize,
+ qkv_bias: bool,
+ use_rel_pos: bool,
+ input_size: (usize, usize),
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "attention");
+ let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul");
+ let span_rel_pos = tracing::span!(tracing::Level::TRACE, "attn-rel-pos");
+ let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
+ let qkv = super::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
+ let proj = super::linear(vb.pp("proj"), dim, dim, true)?;
+ let head_dim = dim / num_heads;
+ let scale = 1. / (head_dim as f64).sqrt();
+ let rel_pos_hw = if use_rel_pos {
+ let h = vb.get((2 * input_size.0 - 1, head_dim), "rel_pos_h")?;
+ let w = vb.get((2 * input_size.1 - 1, head_dim), "rel_pos_w")?;
+ Some((h, w))
+ } else {
+ None
+ };
+ Ok(Self {
+ qkv,
+ proj,
+ num_heads,
+ scale,
+ rel_pos_hw,
+ span,
+ span_matmul,
+ span_rel_pos,
+ span_softmax,
+ })
+ }
+
+ fn add_decomposed_rel_pos(
+ &self,
+ attn: Tensor,
+ q: &Tensor,
+ (q_h, q_w): (usize, usize),
+ (k_h, k_w): (usize, usize),
+ ) -> Result<Tensor> {
+ match &self.rel_pos_hw {
+ Some((rel_pos_h, rel_pos_w)) => {
+ let r_h = get_rel_pos(q_h, k_h, rel_pos_h)?;
+ let r_w = get_rel_pos(q_w, k_w, rel_pos_w)?;
+ let (b, _, dim) = q.dims3()?;
+ let r_q = q.reshape((b, q_h, q_w, dim))?;
+ // rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
+ let rel_h = r_q.matmul(&r_h.broadcast_left(b)?.t()?.contiguous()?)?;
+ // rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
+ let rel_w = r_q
+ .transpose(1, 2)? // -> bwhc
+ .contiguous()?
+ .matmul(&r_w.broadcast_left(b)?.t()?.contiguous()?)? // bwhc,bwck -> bwhk
+ .transpose(1, 2)?
+ .contiguous()?;
+ if attn.device().is_cpu() {
+ let op = Add3(b, q_h, q_w, k_h, k_w);
+ attn.apply_op3_no_bwd(&rel_h, &rel_w, &op)
+ } else {
+ (attn.reshape((b, q_h, q_w, k_h, k_w))?
+ + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
+ .reshape((b, q_h * q_w, k_h * k_w))
+ }
+ }
+ None => Ok(attn),
+ }
+ }
+}
+
+fn get_rel_pos(q_size: usize, k_size: usize, rel_pos: &Tensor) -> Result<Tensor> {
+ let max_rel_dist = 2 * usize::max(q_size, k_size) - 1;
+ let dev = rel_pos.device();
+ let rel_pos_resized = if rel_pos.dim(0)? != max_rel_dist {
+ todo!("interpolation")
+ } else {
+ rel_pos
+ };
+ let q_coords = Tensor::arange(0u32, q_size as u32, dev)?
+ .reshape((q_size, 1))?
+ .to_dtype(DType::F32)?;
+ let k_coords = Tensor::arange(0u32, k_size as u32, dev)?
+ .reshape((1, k_size))?
+ .to_dtype(DType::F32)?;
+ let q_coords = (q_coords * f64::max(1f64, k_size as f64 / q_size as f64))?;
+ let k_coords = (k_coords * f64::max(1f64, q_size as f64 / k_size as f64))?;
+ let relative_coords = (q_coords.broadcast_sub(&k_coords)?
+ + (k_size as f64 - 1.) * f64::max(1f64, q_size as f64 / k_size as f64))?;
+ let (d1, d2) = relative_coords.dims2()?;
+ let relative_coords = relative_coords.to_dtype(DType::U32)?;
+ rel_pos_resized
+ .index_select(&relative_coords.reshape(d1 * d2)?, 0)?
+ .reshape((d1, d2, ()))
+}
+
+impl Module for Attention {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let (b, h, w, c) = xs.dims4()?;
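+ // (b, h, w, c) -> qkv of shape (3, b * num_heads, h * w, c / num_heads) so that q, k
+ // and v can be indexed out below.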
+ let qkv = self
+ .qkv
+ .forward(&xs.flatten_to(1)?)?
+ .reshape((b, h * w, 3, self.num_heads, c / self.num_heads))?
+ .permute((2, 0, 3, 1, 4))?
+ .reshape((3, b * self.num_heads, h * w, c / self.num_heads))?;
+ let q = qkv.i(0)?;
+ let k = qkv.i(1)?;
+ let v = qkv.i(2)?;
+ let attn = {
+ let _enter = self.span_matmul.enter();
+ (&q * self.scale)?.matmul(&k.t()?)?
+ };
+ let attn = {
+ let _enter = self.span_rel_pos.enter();
+ self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))?
+ };
+ let attn = {
+ let _enter = self.span_softmax.enter();
+ candle_nn::ops::softmax_last_dim(&attn)?
+ };
+ let attn = {
+ let _enter = self.span_matmul.enter();
+ attn.matmul(&v)?
+ };
+ let attn = attn
+ .reshape((b, self.num_heads, h, w, c / self.num_heads))?
+ .permute((0, 2, 3, 1, 4))?
+ .reshape((b, h * w, c))?;
+ self.proj.forward(&attn)?.reshape((b, h, w, c))
+ }
+}
+
+#[derive(Debug)]
+struct Block {
+ norm1: LayerNorm,
+ attn: Attention,
+ norm2: LayerNorm,
+ mlp: super::MlpBlock,
+ window_size: usize,
+ span: tracing::Span,
+}
+
+impl Block {
+ fn new(
+ dim: usize,
+ num_heads: usize,
+ qkv_bias: bool,
+ use_rel_pos: bool,
+ window_size: usize,
+ input_size: (usize, usize),
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let norm1 = layer_norm(dim, 1e-6, vb.pp("norm1"))?;
+ let norm2 = layer_norm(dim, 1e-6, vb.pp("norm2"))?;
+ let input_size_attn = if window_size == 0 {
+ input_size
+ } else {
+ (window_size, window_size)
+ };
+ let attn = Attention::new(
+ dim,
+ num_heads,
+ qkv_bias,
+ use_rel_pos,
+ input_size_attn,
+ vb.pp("attn"),
+ )?;
+ let mlp = super::MlpBlock::new(dim, dim * 4, candle_nn::Activation::Gelu, vb.pp("mlp"))?;
+ let span = tracing::span!(tracing::Level::TRACE, "ie-block");
+ Ok(Self {
+ norm1,
+ attn,
+ norm2,
+ mlp,
+ window_size,
+ span,
+ })
+ }
+}
+
+fn window_partition(xs: Tensor, window_size: usize) -> Result<(Tensor, (usize, usize))> {
+ let (b, h, w, c) = xs.dims4()?;
+ let pad_h = (window_size - h % window_size) % window_size;
+ let pad_w = (window_size - w % window_size) % window_size;
+ let xs = if pad_h > 0 {
+ xs.pad_with_zeros(1, 0, pad_h)?
+ } else {
+ xs
+ };
+ let xs = if pad_w > 0 {
+ xs.pad_with_zeros(2, 0, pad_w)?
+ } else {
+ xs
+ };
+ let (h_p, w_p) = (h + pad_h, w + pad_w);
+ let windows = xs
+ .reshape((
+ b,
+ h_p / window_size,
+ window_size,
+ w_p / window_size,
+ window_size,
+ c,
+ ))?
+ .transpose(2, 3)?
+ .contiguous()?
+ .flatten_to(2)?;
+ Ok((windows, (h_p, w_p)))
+}
+
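+// Reverses window_partition. For example, with h = w = 6 and window_size = 4, the
+// (4 * b, 4, 4, c) windows produced above are reassembled into (b, 8, 8, c) and the zero
+// padding is cropped back to (b, 6, 6, c).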
+fn window_unpartition(
+ windows: Tensor,
+ window_size: usize,
+ (h_p, w_p): (usize, usize),
+ (h, w): (usize, usize),
+) -> Result<Tensor> {
+ let b = windows.dim(0)? / (h_p * w_p / window_size / window_size);
+ let xs = windows
+ .reshape((
+ b,
+ h_p / window_size,
+ w_p / window_size,
+ window_size,
+ window_size,
+ windows.elem_count() / b / h_p / w_p,
+ ))?
+ .transpose(2, 3)?
+ .contiguous()?
+ .reshape((b, h_p, w_p, ()))?;
+ let xs = if h_p > h { xs.narrow(1, 0, h)? } else { xs };
+ let xs = if w_p > w { xs.narrow(2, 0, w)? } else { xs };
+ Ok(xs)
+}
+
+impl Module for Block {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let shortcut = xs;
+ let xs = self.norm1.forward(xs)?;
+ let hw = (xs.dim(1)?, xs.dim(2)?);
+ let (xs, pad_hw) = if self.window_size > 0 {
+ window_partition(xs, self.window_size)?
+ } else {
+ (xs, (0, 0))
+ };
+ let xs = self.attn.forward(&xs)?;
+ let xs = if self.window_size > 0 {
+ window_unpartition(xs, self.window_size, pad_hw, hw)?
+ } else {
+ xs
+ };
+ let xs = (xs + shortcut)?;
+ &xs + xs.apply(&self.norm2)?.apply(&self.mlp)?
+ }
+}
+
+#[derive(Debug)]
+pub struct ImageEncoderViT {
+ patch_embed: PatchEmbed,
+ blocks: Vec<Block>,
+ neck_conv1: candle_nn::Conv2d,
+ neck_ln1: super::LayerNorm2d,
+ neck_conv2: candle_nn::Conv2d,
+ neck_ln2: super::LayerNorm2d,
+ pos_embed: Option<Tensor>,
+ span: tracing::Span,
+}
+
+impl ImageEncoderViT {
+ #[allow(clippy::too_many_arguments)]
+ pub fn new(
+ img_size: usize,
+ patch_size: usize,
+ in_chans: usize,
+ embed_dim: usize,
+ depth: usize,
+ num_heads: usize,
+ out_chans: usize,
+ qkv_bias: bool,
+ use_rel_pos: bool,
+ use_abs_pos: bool,
+ window_size: usize,
+ global_attn_indexes: &[usize],
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let patch_embed = PatchEmbed::new(
+ in_chans,
+ embed_dim,
+ patch_size,
+ patch_size,
+ 0,
+ vb.pp("patch_embed"),
+ )?;
+ let mut blocks = Vec::with_capacity(depth);
+ let vb_b = vb.pp("blocks");
+ for i in 0..depth {
+ let window_size = if global_attn_indexes.contains(&i) {
+ 0
+ } else {
+ window_size
+ };
+ let block = Block::new(
+ embed_dim,
+ num_heads,
+ qkv_bias,
+ use_rel_pos,
+ window_size,
+ (img_size / patch_size, img_size / patch_size),
+ vb_b.pp(i),
+ )?;
+ blocks.push(block)
+ }
+ let neck_conv1 = candle_nn::conv2d_no_bias(
+ embed_dim,
+ out_chans,
+ 1,
+ Default::default(),
+ vb.pp("neck.0"),
+ )?;
+ let neck_ln1 = super::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.1"))?;
+ let cfg = candle_nn::Conv2dConfig {
+ padding: 1,
+ ..Default::default()
+ };
+ let neck_conv2 = candle_nn::conv2d_no_bias(out_chans, out_chans, 3, cfg, vb.pp("neck.2"))?;
+ let neck_ln2 = super::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.3"))?;
+ let pos_embed = if use_abs_pos {
+ let p = vb.get(
+ (1, img_size / patch_size, img_size / patch_size, embed_dim),
+ "pos_embed",
+ )?;
+ Some(p)
+ } else {
+ None
+ };
+ let span = tracing::span!(tracing::Level::TRACE, "image-encoder-vit");
+ Ok(Self {
+ patch_embed,
+ blocks,
+ neck_conv1,
+ neck_ln1,
+ neck_conv2,
+ neck_ln2,
+ pos_embed,
+ span,
+ })
+ }
+}
+
+impl Module for ImageEncoderViT {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let xs = self.patch_embed.forward(xs)?;
+ let mut xs = match &self.pos_embed {
+ Some(pos_embed) => (xs + pos_embed)?,
+ None => xs,
+ };
+ for block in self.blocks.iter() {
+ xs = block.forward(&xs)?
+ }
+ xs.permute((0, 3, 1, 2))?
+ .apply(&self.neck_conv1)?
+ .apply(&self.neck_ln1)?
+ .apply(&self.neck_conv2)?
+ .apply(&self.neck_ln2)
+ }
+}
diff --git a/candle-transformers/src/models/segment_anything/mask_decoder.rs b/candle-transformers/src/models/segment_anything/mask_decoder.rs
new file mode 100644
index 00000000..2a91cd44
--- /dev/null
+++ b/candle-transformers/src/models/segment_anything/mask_decoder.rs
@@ -0,0 +1,239 @@
+use candle::{IndexOp, Result, Tensor};
+use candle_nn::{Module, VarBuilder};
+
+use super::transformer::TwoWayTransformer;
+
+#[derive(Debug)]
+struct MlpMaskDecoder {
+ layers: Vec<super::Linear>,
+ sigmoid_output: bool,
+ span: tracing::Span,
+}
+
+impl MlpMaskDecoder {
+ fn new(
+ input_dim: usize,
+ hidden_dim: usize,
+ output_dim: usize,
+ num_layers: usize,
+ sigmoid_output: bool,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let mut layers = Vec::with_capacity(num_layers);
+ let vb = vb.pp("layers");
+ for i in 0..num_layers {
+ let in_dim = if i == 0 { input_dim } else { hidden_dim };
+ let out_dim = if i + 1 == num_layers {
+ output_dim
+ } else {
+ hidden_dim
+ };
+ let layer = super::linear(vb.pp(i), in_dim, out_dim, true)?;
+ layers.push(layer)
+ }
+ let span = tracing::span!(tracing::Level::TRACE, "mlp-mask-decoder");
+ Ok(Self {
+ layers,
+ sigmoid_output,
+ span,
+ })
+ }
+}
+
+impl Module for MlpMaskDecoder {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let mut xs = xs.clone();
+ for (i, layer) in self.layers.iter().enumerate() {
+ xs = layer.forward(&xs)?;
+ if i + 1 < self.layers.len() {
+ xs = xs.relu()?
+ }
+ }
+ if self.sigmoid_output {
+ candle_nn::ops::sigmoid(&xs)
+ } else {
+ Ok(xs)
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct MaskDecoder {
+ iou_token: candle_nn::Embedding,
+ mask_tokens: candle_nn::Embedding,
+ iou_prediction_head: MlpMaskDecoder,
+ output_upscaling_conv1: candle_nn::ConvTranspose2d,
+ output_upscaling_ln: super::LayerNorm2d,
+ output_upscaling_conv2: candle_nn::ConvTranspose2d,
+ num_mask_tokens: usize,
+ output_hypernetworks_mlps: Vec<MlpMaskDecoder>,
+ transformer: TwoWayTransformer,
+ span: tracing::Span,
+}
+
+impl MaskDecoder {
+ pub fn new(
+ transformer_dim: usize,
+ num_multimask_outputs: usize,
+ iou_head_depth: usize,
+ iou_head_hidden_dim: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let num_mask_tokens = num_multimask_outputs + 1;
+ let iou_prediction_head = MlpMaskDecoder::new(
+ transformer_dim,
+ iou_head_hidden_dim,
+ num_mask_tokens,
+ iou_head_depth,
+ false,
+ vb.pp("iou_prediction_head"),
+ )?;
+ let iou_token = candle_nn::embedding(1, transformer_dim, vb.pp("iou_token"))?;
+ let mask_tokens =
+ candle_nn::embedding(num_mask_tokens, transformer_dim, vb.pp("mask_tokens"))?;
+ let cfg = candle_nn::ConvTranspose2dConfig {
+ stride: 2,
+ ..Default::default()
+ };
+ let output_upscaling_conv1 = candle_nn::conv_transpose2d(
+ transformer_dim,
+ transformer_dim / 4,
+ 2,
+ cfg,
+ vb.pp("output_upscaling.0"),
+ )?;
+ let output_upscaling_ln =
+ super::LayerNorm2d::new(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?;
+ let output_upscaling_conv2 = candle_nn::conv_transpose2d(
+ transformer_dim / 4,
+ transformer_dim / 8,
+ 2,
+ cfg,
+ vb.pp("output_upscaling.3"),
+ )?;
+ let mut output_hypernetworks_mlps = Vec::with_capacity(num_mask_tokens);
+ let vb_o = vb.pp("output_hypernetworks_mlps");
+ for i in 0..num_mask_tokens {
+ let mlp = MlpMaskDecoder::new(
+ transformer_dim,
+ transformer_dim,
+ transformer_dim / 8,
+ 3,
+ false,
+ vb_o.pp(i),
+ )?;
+ output_hypernetworks_mlps.push(mlp)
+ }
+ let transformer = TwoWayTransformer::new(
+ /* depth */ 2,
+ /* embedding_dim */ transformer_dim,
+ /* num_heads */ 8,
+ /* mlp_dim */ 2048,
+ vb.pp("transformer"),
+ )?;
+ let span = tracing::span!(tracing::Level::TRACE, "mask-decoder");
+ Ok(Self {
+ iou_token,
+ mask_tokens,
+ iou_prediction_head,
+ output_upscaling_conv1,
+ output_upscaling_ln,
+ output_upscaling_conv2,
+ num_mask_tokens,
+ output_hypernetworks_mlps,
+ transformer,
+ span,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ image_embeddings: &Tensor,
+ image_pe: &Tensor,
+ sparse_prompt_embeddings: &Tensor,
+ dense_prompt_embeddings: &Tensor,
+ multimask_output: bool,
+ ) -> Result<(Tensor, Tensor)> {
+ let _enter = self.span.enter();
+ let (masks, iou_pred) = self.predict_masks(
+ image_embeddings,
+ image_pe,
+ sparse_prompt_embeddings,
+ dense_prompt_embeddings,
+ )?;
+ let masks = if multimask_output {
+ masks.i((.., 1..))?
+ } else {
+ masks.i((.., 0..1))?
+ };
+ let iou_pred = if multimask_output {
+ iou_pred.i((.., 1..))?
+ } else {
+ iou_pred.i((.., 0..1))?
+ };
+ Ok((masks, iou_pred))
+ }
+
+ fn predict_masks(
+ &self,
+ image_embeddings: &Tensor,
+ image_pe: &Tensor,
+ sparse_prompt_embeddings: &Tensor,
+ dense_prompt_embeddings: &Tensor,
+ ) -> Result<(Tensor, Tensor)> {
+ // Concatenate output tokens.
+ let output_tokens = Tensor::cat(
+ &[self.iou_token.embeddings(), self.mask_tokens.embeddings()],
+ 0,
+ )?;
+ let (d1, d2) = output_tokens.dims2()?;
+ let output_tokens =
+ output_tokens
+ .unsqueeze(0)?
+ .expand((sparse_prompt_embeddings.dim(0)?, d1, d2))?;
+ let tokens = Tensor::cat(&[&output_tokens, sparse_prompt_embeddings], 1)?;
+
+ // Expand per-image data in batch direction to be per mask
+ let src = repeat_interleave(image_embeddings, tokens.dim(0)?, 0)?;
+ let src = src.broadcast_add(dense_prompt_embeddings)?;
+ let pos_src = repeat_interleave(image_pe, tokens.dim(0)?, 0)?;
+ let (b, c, h, w) = src.dims4()?;
+
+ // Run the transformer
+ let (hs, src) = self.transformer.forward(&src, &pos_src, &tokens)?;
+ let iou_token_out = hs.i((.., 0))?;
+ let mask_tokens_out = hs.i((.., 1..1 + self.num_mask_tokens))?;
+
+ // Upscale mask embeddings and predict masks using the mask tokens.
+ let src = src.transpose(1, 2)?.reshape((b, c, h, w))?;
+ let upscaled_embedding = self
+ .output_upscaling_conv1
+ .forward(&src)?
+ .apply(&self.output_upscaling_ln)?
+ .gelu()?
+ .apply(&self.output_upscaling_conv2)?
+ .gelu()?;
+ let mut hyper_in_list = Vec::with_capacity(self.num_mask_tokens);
+ for (i, mlp) in self.output_hypernetworks_mlps.iter().enumerate() {
+ let h = mlp.forward(&mask_tokens_out.i((.., i))?)?;
+ hyper_in_list.push(h)
+ }
+ let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?.contiguous()?;
+ let (b, c, h, w) = upscaled_embedding.dims4()?;
+ let masks = hyper_in.matmul(&upscaled_embedding.reshape((b, c, h * w))?)?;
+ let masks = masks.reshape((b, (), h, w))?;
+
+ // Generate mask quality predictions.
+ let iou_pred = self.iou_prediction_head.forward(&iou_token_out)?;
+ Ok((masks, iou_pred))
+ }
+}
+
+// Equivalent to torch.repeat_interleave
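+// For example, a (2, 3) tensor with repeats = 2 and dim = 0 becomes a (4, 3) tensor whose
+// rows are [r0, r0, r1, r1]: each slice along `dim` is duplicated `repeats` times in a row.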
+fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
+ let img = img.unsqueeze(dim + 1)?;
+ let mut dims = img.dims().to_vec();
+ dims[dim + 1] = repeats;
+ img.broadcast_as(dims)?.flatten(dim, dim + 1)
+}
diff --git a/candle-transformers/src/models/segment_anything/mod.rs b/candle-transformers/src/models/segment_anything/mod.rs
new file mode 100644
index 00000000..c29db70a
--- /dev/null
+++ b/candle-transformers/src/models/segment_anything/mod.rs
@@ -0,0 +1,100 @@
+use candle::{Result, Tensor};
+use candle_nn::{Module, VarBuilder};
+
+pub mod image_encoder;
+pub mod mask_decoder;
+pub mod prompt_encoder;
+pub mod sam;
+pub mod tiny_vit;
+pub mod transformer;
+
+pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
+ let inner = if bias {
+ candle_nn::linear(in_dim, out_dim, vb)?
+ } else {
+ candle_nn::linear_no_bias(in_dim, out_dim, vb)?
+ };
+ let span = tracing::span!(tracing::Level::TRACE, "linear");
+ Ok(Linear { inner, span })
+}
+
+#[derive(Debug)]
+pub struct LayerNorm2d {
+ weight: Tensor,
+ bias: Tensor,
+ num_channels: usize,
+ eps: f64,
+}
+
+impl LayerNorm2d {
+ pub fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+ let weight = vb.get(num_channels, "weight")?;
+ let bias = vb.get(num_channels, "bias")?;
+ Ok(Self {
+ weight,
+ bias,
+ num_channels,
+ eps,
+ })
+ }
+}
+
+impl Module for LayerNorm2d {
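+// Layer norm over the channel dimension (dim 1):
+//   y = (x - mean_c(x)) / sqrt(var_c(x) + eps) * weight + bias
+// with weight and bias broadcast to (1, num_channels, 1, 1).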
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let u = xs.mean_keepdim(1)?;
+ let xs = xs.broadcast_sub(&u)?;
+ let s = xs.sqr()?.mean_keepdim(1)?;
+ let xs = xs.broadcast_div(&(s + self.eps)?.sqrt()?)?;
+ xs.broadcast_mul(&self.weight.reshape((1, self.num_channels, 1, 1))?)?
+ .broadcast_add(&self.bias.reshape((1, self.num_channels, 1, 1))?)
+ }
+}
+
+#[derive(Debug)]
+pub struct MlpBlock {
+ lin1: Linear,
+ lin2: Linear,
+ activation: candle_nn::Activation,
+ span: tracing::Span,
+}
+
+impl MlpBlock {
+ pub fn new(
+ embedding_dim: usize,
+ mlp_dim: usize,
+ activation: candle_nn::Activation,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?;
+ let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?;
+ let span = tracing::span!(tracing::Level::TRACE, "mlp-block");
+ Ok(Self {
+ lin1,
+ lin2,
+ activation,
+ span,
+ })
+ }
+}
+
+impl Module for MlpBlock {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ xs.apply(&self.lin1)?
+ .apply(&self.activation)?
+ .apply(&self.lin2)
+ }
+}
+
+#[derive(Debug)]
+pub struct Linear {
+ inner: candle_nn::Linear,
+ span: tracing::Span,
+}
+
+impl Module for Linear {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(x)
+ }
+}
diff --git a/candle-transformers/src/models/segment_anything/prompt_encoder.rs b/candle-transformers/src/models/segment_anything/prompt_encoder.rs
new file mode 100644
index 00000000..9d0074b1
--- /dev/null
+++ b/candle-transformers/src/models/segment_anything/prompt_encoder.rs
@@ -0,0 +1,239 @@
+use candle::{DType, IndexOp, Result, Tensor, D};
+use candle_nn::VarBuilder;
+
+#[derive(Debug)]
+struct PostionEmbeddingRandom {
+ positional_encoding_gaussian_matrix: Tensor,
+}
+
+impl PostionEmbeddingRandom {
+ fn new(num_pos_feats: usize, vb: VarBuilder) -> Result<Self> {
+ let positional_encoding_gaussian_matrix =
+ vb.get((2, num_pos_feats), "positional_encoding_gaussian_matrix")?;
+ Ok(Self {
+ positional_encoding_gaussian_matrix,
+ })
+ }
+
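+ // Random Fourier features: coordinates in [0, 1] are mapped to [-1, 1], projected
+ // through the gaussian matrix G and encoded as [sin(2*pi*xG), cos(2*pi*xG)].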
+ fn pe_encoding(&self, coords: &Tensor) -> Result<Tensor> {
+ let coords = coords.affine(2., -1.)?;
+ let coords = coords.broadcast_matmul(&self.positional_encoding_gaussian_matrix)?;
+ let coords = (coords * (2. * std::f64::consts::PI))?;
+ Tensor::cat(&[coords.sin()?, coords.cos()?], D::Minus1)
+ }
+
+ fn forward(&self, h: usize, w: usize) -> Result<Tensor> {
+ let device = self.positional_encoding_gaussian_matrix.device();
+ let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
+ let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
+ let x_embed = (x_embed / w as f64)?
+ .reshape((1, ()))?
+ .broadcast_as((h, w))?;
+ let y_embed = (y_embed / h as f64)?
+ .reshape(((), 1))?
+ .broadcast_as((h, w))?;
+ let coords = Tensor::stack(&[&x_embed, &y_embed], D::Minus1)?;
+ self.pe_encoding(&coords)?.permute((2, 0, 1))
+ }
+
+ fn forward_with_coords(
+ &self,
+ coords_input: &Tensor,
+ image_size: (usize, usize),
+ ) -> Result<Tensor> {
+ let coords0 = (coords_input.narrow(D::Minus1, 0, 1)? / image_size.1 as f64)?;
+ let coords1 = (coords_input.narrow(D::Minus1, 1, 1)? / image_size.0 as f64)?;
+ let c = coords_input.dim(D::Minus1)?;
+ let coords_rest = coords_input.narrow(D::Minus1, 2, c - 2)?;
+ let coords = Tensor::cat(&[&coords0, &coords1, &coords_rest], D::Minus1)?;
+ self.pe_encoding(&coords)
+ }
+}
+
+#[derive(Debug)]
+pub struct PromptEncoder {
+ pe_layer: PostionEmbeddingRandom,
+ point_embeddings: Vec<candle_nn::Embedding>,
+ not_a_point_embed: candle_nn::Embedding,
+ mask_downscaling_conv1: candle_nn::Conv2d,
+ mask_downscaling_ln1: super::LayerNorm2d,
+ mask_downscaling_conv2: candle_nn::Conv2d,
+ mask_downscaling_ln2: super::LayerNorm2d,
+ mask_downscaling_conv3: candle_nn::Conv2d,
+ no_mask_embed: candle_nn::Embedding,
+ image_embedding_size: (usize, usize),
+ input_image_size: (usize, usize),
+ embed_dim: usize,
+ span: tracing::Span,
+}
+
+impl PromptEncoder {
+ pub fn new(
+ embed_dim: usize,
+ image_embedding_size: (usize, usize),
+ input_image_size: (usize, usize),
+ mask_in_chans: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let num_points_embeddings = 4;
+ let pe_layer = PostionEmbeddingRandom::new(embed_dim / 2, vb.pp("pe_layer"))?;
+ let not_a_point_embed = candle_nn::embedding(1, embed_dim, vb.pp("not_a_point_embed"))?;
+ let no_mask_embed = candle_nn::embedding(1, embed_dim, vb.pp("no_mask_embed"))?;
+ let cfg = candle_nn::Conv2dConfig {
+ stride: 2,
+ ..Default::default()
+ };
+ let mask_downscaling_conv1 =
+ candle_nn::conv2d(1, mask_in_chans / 4, 2, cfg, vb.pp("mask_downscaling.0"))?;
+ let mask_downscaling_conv2 = candle_nn::conv2d(
+ mask_in_chans / 4,
+ mask_in_chans,
+ 2,
+ cfg,
+ vb.pp("mask_downscaling.3"),
+ )?;
+ let mask_downscaling_conv3 = candle_nn::conv2d(
+ mask_in_chans,
+ embed_dim,
+ 1,
+ Default::default(),
+ vb.pp("mask_downscaling.6"),
+ )?;
+ let mask_downscaling_ln1 =
+ super::LayerNorm2d::new(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?;
+ let mask_downscaling_ln2 =
+ super::LayerNorm2d::new(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?;
+ let mut point_embeddings = Vec::with_capacity(num_points_embeddings);
+ let vb_e = vb.pp("point_embeddings");
+ for i in 0..num_points_embeddings {
+ let emb = candle_nn::embedding(1, embed_dim, vb_e.pp(i))?;
+ point_embeddings.push(emb)
+ }
+ let span = tracing::span!(tracing::Level::TRACE, "prompt-encoder");
+ Ok(Self {
+ pe_layer,
+ point_embeddings,
+ not_a_point_embed,
+ mask_downscaling_conv1,
+ mask_downscaling_ln1,
+ mask_downscaling_conv2,
+ mask_downscaling_ln2,
+ mask_downscaling_conv3,
+ no_mask_embed,
+ image_embedding_size,
+ input_image_size,
+ embed_dim,
+ span,
+ })
+ }
+
+ pub fn get_dense_pe(&self) -> Result<Tensor> {
+ self.pe_layer
+ .forward(self.image_embedding_size.0, self.image_embedding_size.1)?
+ .unsqueeze(0)
+ }
+
+ fn embed_masks(&self, masks: &Tensor) -> Result<Tensor> {
+ masks
+ .apply(&self.mask_downscaling_conv1)?
+ .apply(&self.mask_downscaling_ln1)?
+ .gelu()?
+ .apply(&self.mask_downscaling_conv2)?
+ .apply(&self.mask_downscaling_ln2)?
+ .gelu()?
+ .apply(&self.mask_downscaling_conv3)
+ }
+
+ fn embed_points(&self, points: &Tensor, labels: &Tensor, pad: bool) -> Result<Tensor> {
+ let points = (points + 0.5)?;
+ let dev = points.device();
+ let (points, labels) = if pad {
+ let padding_point = Tensor::zeros((points.dim(0)?, 1, 2), DType::F32, dev)?;
+ let padding_label = (Tensor::ones((labels.dim(0)?, 1), DType::F32, dev)? * (-1f64))?;
+ let points = Tensor::cat(&[&points, &padding_point], 1)?;
+ let labels = Tensor::cat(&[labels, &padding_label], 1)?;
+ (points, labels)
+ } else {
+ (points, labels.clone())
+ };
+ let point_embedding = self
+ .pe_layer
+ .forward_with_coords(&points, self.input_image_size)?;
+ let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?;
+ let zeros = point_embedding.zeros_like()?;
+ let point_embedding = labels.lt(0f32)?.where_cond(
+ &self
+ .not_a_point_embed
+ .embeddings()
+ .broadcast_as(zeros.shape())?,
+ &point_embedding,
+ )?;
+ let labels0 = labels.eq(0f32)?.where_cond(
+ &self.point_embeddings[0]
+ .embeddings()
+ .broadcast_as(zeros.shape())?,
+ &zeros,
+ )?;
+ let point_embedding = (point_embedding + labels0)?;
+ let labels1 = labels.eq(1f32)?.where_cond(
+ &self.point_embeddings[1]
+ .embeddings()
+ .broadcast_as(zeros.shape())?,
+ &zeros,
+ )?;
+ let point_embedding = (point_embedding + labels1)?;
+ Ok(point_embedding)
+ }
+
+ fn embed_boxes(&self, boxes: &Tensor) -> Result<Tensor> {
+ let boxes = (boxes + 0.5)?;
+ let coords = boxes.reshape(((), 2, 2))?;
+ let corner_embedding = self
+ .pe_layer
+ .forward_with_coords(&coords, self.input_image_size)?;
+ let ce1 = corner_embedding.i((.., 0))?;
+ let ce2 = corner_embedding.i((.., 1))?;
+ let ce1 = (ce1 + self.point_embeddings[2].embeddings())?;
+ let ce2 = (ce2 + self.point_embeddings[3].embeddings())?;
+ Tensor::cat(&[&ce1, &ce2], 1)
+ }
+
+ pub fn forward(
+ &self,
+ points: Option<(&Tensor, &Tensor)>,
+ boxes: Option<&Tensor>,
+ masks: Option<&Tensor>,
+ ) -> Result<(Tensor, Tensor)> {
+ let _enter = self.span.enter();
+ let se_points = match points {
+ Some((coords, labels)) => Some(self.embed_points(coords, labels, boxes.is_none())?),
+ None => None,
+ };
+ let se_boxes = match boxes {
+ Some(boxes) => Some(self.embed_boxes(boxes)?),
+ None => None,
+ };
+ let sparse_embeddings = match (se_points, se_boxes) {
+ (Some(se_points), Some(se_boxes)) => Tensor::cat(&[se_points, se_boxes], 1)?,
+ (Some(se_points), None) => se_points,
+ (None, Some(se_boxes)) => se_boxes,
+ (None, None) => {
+ Tensor::zeros((1, 0, self.embed_dim), DType::F32, &candle::Device::Cpu)?
+ }
+ };
+
+ let dense_embeddings = match masks {
+ None => {
+ let emb = self.no_mask_embed.embeddings();
+ emb.reshape((1, (), 1, 1))?.expand((
+ 1,
+ emb.elem_count(),
+ self.image_embedding_size.0,
+ self.image_embedding_size.1,
+ ))?
+ }
+ Some(masks) => self.embed_masks(masks)?,
+ };
+ Ok((sparse_embeddings, dense_embeddings))
+ }
+}
diff --git a/candle-transformers/src/models/segment_anything/sam.rs b/candle-transformers/src/models/segment_anything/sam.rs
new file mode 100644
index 00000000..07e9a759
--- /dev/null
+++ b/candle-transformers/src/models/segment_anything/sam.rs
@@ -0,0 +1,433 @@
+use candle::{DType, IndexOp, Result, Tensor};
+use candle_nn::{Module, VarBuilder};
+
+use super::image_encoder::ImageEncoderViT;
+use super::mask_decoder::MaskDecoder;
+use super::prompt_encoder::PromptEncoder;
+use super::tiny_vit::{tiny_vit_5m, TinyViT};
+
+const PROMPT_EMBED_DIM: usize = 256;
+pub const IMAGE_SIZE: usize = 1024;
+const VIT_PATCH_SIZE: usize = 16;
+const PRED_IOU_THRESH: f32 = 0.88;
+const STABILITY_SCORE_OFFSET: f32 = 1.0;
+const STABILITY_SCORE_THRESHOLD: f32 = 0.95;
+const MODEL_MASK_THRESHOLD: f32 = 0.0;
+const CROP_NMS_THRESH: f32 = 0.7;
+
+#[derive(Debug)]
+enum ImageEncoder {
+ Original(ImageEncoderViT),
+ TinyViT(TinyViT),
+}
+
+impl Module for ImageEncoder {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ match self {
+ Self::Original(vit) => vit.forward(xs),
+ Self::TinyViT(vit) => vit.forward(xs),
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct Sam {
+ image_encoder: ImageEncoder,
+ prompt_encoder: PromptEncoder,
+ mask_decoder: MaskDecoder,
+ pixel_mean: Tensor,
+ pixel_std: Tensor,
+}
+
+impl Sam {
+ pub fn new(
+ encoder_embed_dim: usize,
+ encoder_depth: usize,
+ encoder_num_heads: usize,
+ encoder_global_attn_indexes: &[usize],
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE;
+
+ let image_encoder = ImageEncoderViT::new(
+ IMAGE_SIZE,
+ VIT_PATCH_SIZE,
+ 3,
+ encoder_embed_dim,
+ encoder_depth,
+ encoder_num_heads,
+ PROMPT_EMBED_DIM,
+ /* qkv_bias */ true,
+ /* use_rel_pos */ true,
+ /* use_abs_pos */ true,
+ /* window_size */ 14,
+ /* global_attn_indexes */ encoder_global_attn_indexes,
+ vb.pp("image_encoder"),
+ )?;
+ let prompt_encoder = PromptEncoder::new(
+ PROMPT_EMBED_DIM,
+ (image_embedding_size, image_embedding_size),
+ (IMAGE_SIZE, IMAGE_SIZE),
+ 16,
+ vb.pp("prompt_encoder"),
+ )?;
+ let mask_decoder = MaskDecoder::new(
+ PROMPT_EMBED_DIM,
+ /* num_multimask_outputs */ 3,
+ /* iou_head_depth */ 3,
+ /* iou_head_hidden_dim */ 256,
+ vb.pp("mask_decoder"),
+ )?;
+ let pixel_mean =
+ Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?;
+ let pixel_std =
+ Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
+ Ok(Self {
+ image_encoder: ImageEncoder::Original(image_encoder),
+ prompt_encoder,
+ mask_decoder,
+ pixel_std,
+ pixel_mean,
+ })
+ }
+
+ pub fn new_tiny(vb: VarBuilder) -> Result<Self> {
+ let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE;
+
+ let image_encoder = tiny_vit_5m(vb.pp("image_encoder"))?;
+ let prompt_encoder = PromptEncoder::new(
+ PROMPT_EMBED_DIM,
+ (image_embedding_size, image_embedding_size),
+ (IMAGE_SIZE, IMAGE_SIZE),
+ 16,
+ vb.pp("prompt_encoder"),
+ )?;
+ let mask_decoder = MaskDecoder::new(
+ PROMPT_EMBED_DIM,
+ /* num_multimask_outputs */ 3,
+ /* iou_head_depth */ 3,
+ /* iou_head_hidden_dim */ 256,
+ vb.pp("mask_decoder"),
+ )?;
+ let pixel_mean =
+ Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?;
+ let pixel_std =
+ Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
+ Ok(Self {
+ image_encoder: ImageEncoder::TinyViT(image_encoder),
+ prompt_encoder,
+ mask_decoder,
+ pixel_std,
+ pixel_mean,
+ })
+ }
+
+ pub fn embeddings(&self, img: &Tensor) -> Result<Tensor> {
+ let img = self.preprocess(img)?.unsqueeze(0)?;
+ self.image_encoder.forward(&img)
+ }
+
+ pub fn forward(
+ &self,
+ img: &Tensor,
+ point: Option<(f64, f64)>,
+ multimask_output: bool,
+ ) -> Result<(Tensor, Tensor)> {
+ let (_c, original_h, original_w) = img.dims3()?;
+ let img = self.preprocess(img)?.unsqueeze(0)?;
+ let img_embeddings = self.image_encoder.forward(&img)?;
+ let (low_res_mask, iou) = self.forward_for_embeddings(
+ &img_embeddings,
+ original_h,
+ original_w,
+ point,
+ multimask_output,
+ )?;
+ let mask = low_res_mask
+ .upsample_nearest2d(IMAGE_SIZE, IMAGE_SIZE)?
+ .get(0)?
+ .i((.., ..original_h, ..original_w))?;
+ Ok((mask, iou))
+ }
+
+ pub fn forward_for_embeddings(
+ &self,
+ img_embeddings: &Tensor,
+ original_h: usize,
+ original_w: usize,
+ point: Option<(f64, f64)>,
+ multimask_output: bool,
+ ) -> Result<(Tensor, Tensor)> {
+ let image_pe = self.prompt_encoder.get_dense_pe()?;
+ let points = match point {
+ None => None,
+ Some((x, y)) => {
+ let points = Tensor::new(
+ &[[[x as f32 * original_w as f32, y as f32 * original_h as f32]]],
+ img_embeddings.device(),
+ )?;
+ let labels = Tensor::ones((1, 1), DType::F32, img_embeddings.device())?;
+ Some((points, labels))
+ }
+ };
+ let points = points.as_ref().map(|(x, y)| (x, y));
+ let (sparse_prompt_embeddings, dense_prompt_embeddings) =
+ self.prompt_encoder.forward(points, None, None)?;
+ self.mask_decoder.forward(
+ img_embeddings,
+ &image_pe,
+ &sparse_prompt_embeddings,
+ &dense_prompt_embeddings,
+ multimask_output,
+ )
+ }
+
+ pub fn unpreprocess(&self, img: &Tensor) -> Result<Tensor> {
+ let img = img
+ .broadcast_mul(&self.pixel_std)?
+ .broadcast_add(&self.pixel_mean)?;
+ img.maximum(&img.zeros_like()?)?
+ .minimum(&(img.ones_like()? * 255.)?)
+ }
+
+ pub fn preprocess(&self, img: &Tensor) -> Result<Tensor> {
+ let (_c, h, w) = img.dims3()?;
+ let img = img
+ .to_dtype(DType::F32)?
+ .broadcast_sub(&self.pixel_mean)?
+ .broadcast_div(&self.pixel_std)?;
+ if h > IMAGE_SIZE || w > IMAGE_SIZE {
+ candle::bail!("image is too large ({w}, {h}), maximum size {IMAGE_SIZE}")
+ }
+ let img = img.pad_with_zeros(1, 0, IMAGE_SIZE - h)?;
+ img.pad_with_zeros(2, 0, IMAGE_SIZE - w)
+ }
+
+ fn process_crop(
+ &self,
+ img: &Tensor,
+ cb: CropBox,
+ point_grids: &[(f64, f64)],
+ ) -> Result<Vec<crate::object_detection::Bbox<Tensor>>> {
+ // Crop the image and calculate embeddings.
+ let img = img.i((.., cb.y0..cb.y1, cb.x0..cb.x1))?;
+ let img = self.preprocess(&img)?.unsqueeze(0)?;
+ let img_embeddings = self.image_encoder.forward(&img)?;
+
+ let crop_w = cb.x1 - cb.x0;
+ let crop_h = cb.y1 - cb.y0;
+
+ // Generate masks for this crop.
+ let image_pe = self.prompt_encoder.get_dense_pe()?;
+ let points = point_grids
+ .iter()
+ .map(|&(x, y)| vec![x as f32 * crop_w as f32, y as f32 * crop_h as f32])
+ .collect::<Vec<_>>();
+
+ let mut bboxes = Vec::new();
+ for points in points.chunks(64) {
+ // Run the model on this batch.
+ let points_len = points.len();
+ let in_points = Tensor::new(points.to_vec(), img.device())?.unsqueeze(1)?;
+ let in_labels = Tensor::ones((points_len, 1), DType::F32, img.device())?;
+ let (sparse_prompt_embeddings, dense_prompt_embeddings) =
+ self.prompt_encoder
+ .forward(Some((&in_points, &in_labels)), None, None)?;
+
+ let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
+ &img_embeddings,
+ &image_pe,
+ &sparse_prompt_embeddings,
+ &dense_prompt_embeddings,
+ /* multimask_output */ true,
+ )?;
+ let low_res_mask = low_res_mask.flatten(0, 1)?;
+ let iou_predictions = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?;
+ let dev = low_res_mask.device();
+
+ for (i, iou) in iou_predictions.iter().enumerate() {
+ // Filter by predicted IoU.
+ if *iou < PRED_IOU_THRESH {
+ continue;
+ }
+ let low_res_mask = low_res_mask.get(i)?;
+
+ // Calculate stability score.
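+ // stability = |mask > threshold + offset| / |mask > threshold - offset|: masks whose
+ // area barely changes when the threshold is perturbed are considered stable and kept.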
+ let bound = Tensor::new(MODEL_MASK_THRESHOLD + STABILITY_SCORE_OFFSET, dev)?
+ .broadcast_as(low_res_mask.shape())?;
+ let intersections = low_res_mask
+ .ge(&bound)?
+ .to_dtype(DType::F32)?
+ .sum_all()?
+ .to_vec0::<f32>()?;
+ let bound = Tensor::new(MODEL_MASK_THRESHOLD - STABILITY_SCORE_OFFSET, dev)?
+ .broadcast_as(low_res_mask.shape())?;
+ let unions = low_res_mask
+ .ge(&bound)?
+ .to_dtype(DType::F32)?
+ .sum_all()?
+ .to_vec0::<f32>()?;
+ let stability_score = intersections / unions;
+ if stability_score < STABILITY_SCORE_THRESHOLD {
+ continue;
+ }
+
+ // Threshold masks and calculate boxes.
+ let low_res_mask = low_res_mask
+ .ge(&Tensor::new(0f32, dev)?.broadcast_as(low_res_mask.shape())?)?
+ .to_dtype(DType::U32)?;
+ let low_res_mask_per_x = low_res_mask.sum(0)?.to_vec1::<u32>()?;
+ let low_res_mask_per_y = low_res_mask.sum(1)?.to_vec1::<u32>()?;
+ let min_max_x = min_max_indexes(&low_res_mask_per_x);
+ let min_max_y = min_max_indexes(&low_res_mask_per_y);
+ if let Some(((x0, x1), (y0, y1))) = min_max_x.zip(min_max_y) {
+ let bbox = crate::object_detection::Bbox {
+ xmin: x0 as f32,
+ ymin: y0 as f32,
+ xmax: x1 as f32,
+ ymax: y1 as f32,
+ confidence: *iou,
+ data: low_res_mask,
+ };
+ bboxes.push(bbox);
+ }
+ // TODO:
+ // Filter boxes that touch crop boundaries
+ // Compress to RLE.
+ }
+ }
+
+ let mut bboxes = vec![bboxes];
+ // Remove duplicates within this crop.
+ crate::object_detection::non_maximum_suppression(&mut bboxes, CROP_NMS_THRESH);
+
+ // TODO: Return to the original image frame.
+ Ok(bboxes.remove(0))
+ }
+
+ pub fn generate_masks(
+ &self,
+ img: &Tensor,
+ points_per_side: usize,
+ crop_n_layer: usize,
+ crop_overlap_ratio: f64,
+ crop_n_points_downscale_factor: usize,
+ ) -> Result<Vec<crate::object_detection::Bbox<Tensor>>> {
+ let (_c, h, w) = img.dims3()?;
+ let point_grids = build_all_layer_point_grids(
+ points_per_side,
+ crop_n_layer,
+ crop_n_points_downscale_factor,
+ );
+ let crop_boxes = generate_crop_boxes((h, w), crop_n_layer, crop_overlap_ratio);
+ let mut bboxes = Vec::new();
+ for crop_box in crop_boxes.into_iter() {
+ let layer_idx = crop_box.layer_idx;
+ let b = self.process_crop(img, crop_box, &point_grids[layer_idx])?;
+ bboxes.extend(b)
+ }
+ // TODO: remove duplicates
+ Ok(bboxes)
+ }
+}
+
+// Return the first and last indexes i for which values[i] > 0
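+// For example, [0, 3, 0, 7, 0] returns Some((1, 3)); an all-zero slice returns None.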
+fn min_max_indexes(values: &[u32]) -> Option<(usize, usize)> {
+ let (mut min_i, mut max_i) = (usize::MAX, usize::MIN);
+ for (i, &s) in values.iter().enumerate() {
+ if s == 0 {
+ continue;
+ }
+ min_i = usize::min(i, min_i);
+ max_i = usize::max(i, max_i);
+ }
+ if max_i < min_i {
+ None
+ } else {
+ Some((min_i, max_i))
+ }
+}
+
+#[derive(Debug)]
+struct CropBox {
+ x0: usize,
+ y0: usize,
+ x1: usize,
+ y1: usize,
+ layer_idx: usize,
+}
+
+impl CropBox {
+ fn new(x0: usize, y0: usize, x1: usize, y1: usize, layer_idx: usize) -> Self {
+ Self {
+ x0,
+ y0,
+ x1,
+ y1,
+ layer_idx,
+ }
+ }
+}
+
+fn generate_crop_boxes(
+ (im_h, im_w): (usize, usize),
+ n_layers: usize,
+ overlap_ratio: f64,
+) -> Vec<CropBox> {
+ fn crop_len(orig_len: usize, n_crops: usize, overlap: usize) -> usize {
+ f64::ceil((overlap * (n_crops - 1) + orig_len) as f64 / n_crops as f64) as usize
+ }
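+ // For example, crop_len(1024, 2, 100) = ceil(1124 / 2) = 562: two crops of length 562
+ // with an overlap of 100 pixels exactly cover a 1024 pixel side.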
+
+ let short_side = usize::min(im_h, im_w);
+
+ let mut crop_boxes = Vec::new();
+
+ // Original image.
+ crop_boxes.push(CropBox::new(0, 0, im_w, im_h, 0));
+
+ for layer_idx in 1..=n_layers {
+ let n_crops_per_side = 1 << layer_idx;
+ let overlap = (overlap_ratio * short_side as f64 * 2. / n_crops_per_side as f64) as usize;
+ let crop_w = crop_len(im_w, n_crops_per_side, overlap);
+ let crop_h = crop_len(im_h, n_crops_per_side, overlap);
+
+ for i_x in 0..n_crops_per_side {
+ let x0 = (crop_w - overlap) * i_x;
+ for i_y in 0..n_crops_per_side {
+ let y0 = (crop_h - overlap) * i_y;
+ let x1 = usize::min(im_w, x0 + crop_w);
+ let y1 = usize::min(im_h, y0 + crop_h);
+ crop_boxes.push(CropBox::new(x0, y0, x1, y1, layer_idx));
+ }
+ }
+ }
+
+ crop_boxes
+}
+
+// Generates a 2D grid of points evenly spaced in [0,1]x[0,1].
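+// For example, n_per_side = 2 gives (0.25, 0.25), (0.25, 0.75), (0.75, 0.25) and (0.75, 0.75).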
+fn build_point_grid(n_per_side: usize) -> Vec<(f64, f64)> {
+ let offset = 1f64 / (2 * n_per_side) as f64;
+ let mut points = Vec::with_capacity(n_per_side * n_per_side);
+ for i_x in 0..n_per_side {
+ let x = offset + i_x as f64 / n_per_side as f64;
+ for i_y in 0..n_per_side {
+ let y = offset + i_y as f64 / n_per_side as f64;
+ points.push((x, y))
+ }
+ }
+ points
+}
+
+fn build_all_layer_point_grids(
+ n_per_side: usize,
+ n_layers: usize,
+ scale_per_layer: usize,
+) -> Vec<Vec<(f64, f64)>> {
+ let mut points_by_layer = Vec::with_capacity(n_layers + 1);
+ for i in 0..=n_layers {
+ let n_points = n_per_side / scale_per_layer.pow(i as u32);
+ points_by_layer.push(build_point_grid(n_points))
+ }
+ points_by_layer
+}
diff --git a/candle-transformers/src/models/segment_anything/tiny_vit.rs b/candle-transformers/src/models/segment_anything/tiny_vit.rs
new file mode 100644
index 00000000..cd2936ab
--- /dev/null
+++ b/candle-transformers/src/models/segment_anything/tiny_vit.rs
@@ -0,0 +1,633 @@
+// Adapted from:
+// https://github.com/ChaoningZhang/MobileSAM/blob/master/mobile_sam/modeling/tiny_vit_sam.py
+use candle::{IndexOp, Result, Tensor, D};
+use candle_nn::{Conv2dConfig, Module, VarBuilder};
+
+const MBCONV_EXPAND_RATIO: usize = 4;
+const MLP_RATIO: usize = 4;
+const LOCAL_CONV_SIZE: usize = 3;
+const IMG_SIZE: usize = 1024;
+const IN_CHANNELS: usize = 3;
+
+#[derive(Debug)]
+struct Conv2dBN {
+ c: candle_nn::Conv2d,
+ bn: candle_nn::BatchNorm,
+ span: tracing::Span,
+}
+
+impl Conv2dBN {
+ fn new(in_: usize, out: usize, ks: usize, cfg: Conv2dConfig, vb: VarBuilder) -> Result<Self> {
+ let c = candle_nn::conv2d_no_bias(in_, out, ks, cfg, vb.pp("c"))?;
+ let bn = candle_nn::batch_norm(out, 1e-5, vb.pp("bn"))?;
+ let span = tracing::span!(tracing::Level::TRACE, "conv2d-bn");
+ Ok(Self { c, bn, span })
+ }
+}
+
+impl Module for Conv2dBN {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ xs.apply(&self.c)?.apply(&self.bn)
+ }
+}
+
+#[derive(Debug)]
+struct PatchEmbed {
+ conv1: Conv2dBN,
+ conv2: Conv2dBN,
+ span: tracing::Span,
+}
+
+impl PatchEmbed {
+ fn new(in_chans: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> {
+ let cfg = candle_nn::Conv2dConfig {
+ stride: 2,
+ padding: 1,
+ ..Default::default()
+ };
+ let conv1 = Conv2dBN::new(in_chans, embed_dim / 2, 3, cfg, vb.pp("seq.0"))?;
+ let conv2 = Conv2dBN::new(embed_dim / 2, embed_dim, 3, cfg, vb.pp("seq.2"))?;
+ let span = tracing::span!(tracing::Level::TRACE, "patch-embed");
+ Ok(Self { conv1, conv2, span })
+ }
+}
+
+impl Module for PatchEmbed {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ xs.apply(&self.conv1)?.gelu()?.apply(&self.conv2)
+ }
+}
+
+#[derive(Debug)]
+struct MBConv {
+ conv1: Conv2dBN,
+ conv2: Conv2dBN,
+ conv3: Conv2dBN,
+ span: tracing::Span,
+}
+
+impl MBConv {
+ fn new(in_: usize, out: usize, expand_ratio: usize, vb: VarBuilder) -> Result<Self> {
+ let hidden = in_ * expand_ratio;
+ let cfg2 = candle_nn::Conv2dConfig {
+ padding: 1,
+ groups: hidden,
+ ..Default::default()
+ };
+ let conv1 = Conv2dBN::new(in_, hidden, 1, Default::default(), vb.pp("conv1"))?;
+ let conv2 = Conv2dBN::new(hidden, hidden, 3, cfg2, vb.pp("conv2"))?;
+ let conv3 = Conv2dBN::new(hidden, out, 1, Default::default(), vb.pp("conv3"))?;
+ let span = tracing::span!(tracing::Level::TRACE, "mb-conv");
+ Ok(Self {
+ conv1,
+ conv2,
+ conv3,
+ span,
+ })
+ }
+}
+
+impl Module for MBConv {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let shortcut = xs;
+ let xs = xs
+ .apply(&self.conv1)?
+ .gelu()?
+ .apply(&self.conv2)?
+ .gelu()?
+ .apply(&self.conv3)?;
+ (xs + shortcut)?.gelu()
+ }
+}
+
+#[derive(Debug)]
+struct PatchMerging {
+ conv1: Conv2dBN,
+ conv2: Conv2dBN,
+ conv3: Conv2dBN,
+ input_resolution: (usize, usize),
+ span: tracing::Span,
+}
+
+impl PatchMerging {
+ fn new(
+ input_resolution: (usize, usize),
+ dim: usize,
+ out: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
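+ // The stride-1 case presumably corresponds to the last stage of the TinyViT 5m/11m/21m
+ // variants (out = 320/448/576), where the spatial resolution is kept.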
+ let stride = if [320, 448, 576].contains(&out) { 1 } else { 2 };
+ let cfg2 = candle_nn::Conv2dConfig {
+ padding: 1,
+ stride,
+ groups: out,
+ ..Default::default()
+ };
+ let conv1 = Conv2dBN::new(dim, out, 1, Default::default(), vb.pp("conv1"))?;
+ let conv2 = Conv2dBN::new(out, out, 3, cfg2, vb.pp("conv2"))?;
+ let conv3 = Conv2dBN::new(out, out, 1, Default::default(), vb.pp("conv3"))?;
+ let span = tracing::span!(tracing::Level::TRACE, "patch-merging");
+ Ok(Self {
+ conv1,
+ conv2,
+ conv3,
+ input_resolution,
+ span,
+ })
+ }
+}
+
+impl Module for PatchMerging {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let xs = if xs.rank() == 3 {
+ let (h, w) = self.input_resolution;
+ let b = xs.dim(0)?;
+ xs.reshape((b, h, w, ()))?.permute((0, 3, 1, 2))?
+ } else {
+ xs.clone()
+ };
+ xs.apply(&self.conv1)?
+ .gelu()?
+ .apply(&self.conv2)?
+ .gelu()?
+ .apply(&self.conv3)?
+ .flatten_from(2)?
+ .transpose(1, 2)
+ }
+}
+
+#[derive(Debug)]
+struct ConvLayer {
+ blocks: Vec<MBConv>,
+ downsample: Option<PatchMerging>,
+ span: tracing::Span,
+}
+
+impl ConvLayer {
+ fn new(
+ dim: usize,
+ out: usize,
+ input_resolution: (usize, usize),
+ depth: usize,
+ downsample: bool,
+ conv_expand_ratio: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let vb_b = vb.pp("blocks");
+ let mut blocks = Vec::with_capacity(depth);
+ for index in 0..depth {
+ let block = MBConv::new(dim, dim, conv_expand_ratio, vb_b.pp(index))?;
+ blocks.push(block)
+ }
+ let downsample = if downsample {
+ let downsample = PatchMerging::new(input_resolution, dim, out, vb.pp("downsample"))?;
+ Some(downsample)
+ } else {
+ None
+ };
+ let span = tracing::span!(tracing::Level::TRACE, "conv-layer");
+ Ok(Self {
+ blocks,
+ downsample,
+ span,
+ })
+ }
+}
+
+impl Module for ConvLayer {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let mut xs = xs.clone();
+ for block in self.blocks.iter() {
+ xs = block.forward(&xs)?
+ }
+ match &self.downsample {
+ None => Ok(xs),
+ Some(downsample) => downsample.forward(&xs),
+ }
+ }
+}
+
+#[derive(Debug)]
+struct Mlp {
+ norm: candle_nn::LayerNorm,
+ fc1: super::Linear,
+ fc2: super::Linear,
+ span: tracing::Span,
+}
+
+impl Mlp {
+ fn new(in_: usize, hidden: usize, vb: VarBuilder) -> Result<Self> {
+ let norm = candle_nn::layer_norm(in_, 1e-5, vb.pp("norm"))?;
+ let fc1 = super::linear(vb.pp("fc1"), in_, hidden, true)?;
+ let fc2 = super::linear(vb.pp("fc2"), hidden, in_, true)?;
+ let span = tracing::span!(tracing::Level::TRACE, "mlp");
+ Ok(Self {
+ norm,
+ fc1,
+ fc2,
+ span,
+ })
+ }
+}
+
+impl Module for Mlp {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ xs.apply(&self.norm)?
+ .apply(&self.fc1)?
+ .gelu()?
+ .apply(&self.fc2)
+ }
+}
+
+#[derive(Debug)]
+struct Attention {
+ norm: candle_nn::LayerNorm,
+ qkv: super::Linear,
+ proj: super::Linear,
+ ab: Tensor,
+ key_dim: usize,
+ num_heads: usize,
+ d: usize,
+ dh: usize,
+ scale: f64,
+ span: tracing::Span,
+ span_matmul: tracing::Span,
+ span_softmax: tracing::Span,
+}
+
+impl Attention {
+ fn new(
+ dim: usize,
+ key_dim: usize,
+ num_heads: usize,
+ attn_ratio: usize,
+ resolution: (usize, usize),
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let d = attn_ratio * key_dim;
+ let dh = d * num_heads;
+ let nh_kd = key_dim * num_heads;
+ let h = dh + nh_kd * 2;
+ let norm = candle_nn::layer_norm(dim, 1e-5, vb.pp("norm"))?;
+ let qkv = super::linear(vb.pp("qkv"), dim, h, true)?;
+ let proj = super::linear(vb.pp("proj"), dh, dim, true)?;
+
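+ // Precompute an additive (num_heads, n, n) attention bias, where n = resolution.0 *
+ // resolution.1 and the bias between two positions only depends on their absolute
+ // (dx, dy) offset.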
+ let points = (0..resolution.0)
+ .flat_map(|x| (0..resolution.1).map(move |y| (x as i64, y as i64)))
+ .collect::<Vec<_>>();
+ let mut idxs = Vec::with_capacity(points.len() * points.len());
+ let mut attention_offsets = std::collections::HashMap::new();
+ for &(x1, y1) in points.iter() {
+ for &(x2, y2) in points.iter() {
+ let offset = ((x2 - x1).abs(), (y2 - y1).abs());
+ let l = attention_offsets.len();
+ let idx = attention_offsets.entry(offset).or_insert(l);
+ idxs.push(*idx as u32)
+ }
+ }
+ let attention_biases = vb.get((num_heads, attention_offsets.len()), "attention_biases")?;
+ let idxs = Tensor::new(idxs, attention_biases.device())?;
+ let ab =
+ attention_biases
+ .index_select(&idxs, 1)?
+ .reshape(((), points.len(), points.len()))?;
+ let span = tracing::span!(tracing::Level::TRACE, "attention");
+ let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul");
+ let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
+ Ok(Self {
+ norm,
+ qkv,
+ proj,
+ ab,
+ key_dim,
+ num_heads,
+ d,
+ dh,
+ scale: 1f64 / (key_dim as f64).sqrt(),
+ span,
+ span_matmul,
+ span_softmax,
+ })
+ }
+}
+
+impl Module for Attention {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let (b, n, _) = xs.dims3()?;
+ let xs = xs.apply(&self.norm)?;
+ let qkv = xs.apply(&self.qkv)?.reshape((b, n, self.num_heads, ()))?;
+ let q = qkv
+ .narrow(D::Minus1, 0, self.key_dim)?
+ .permute((0, 2, 1, 3))?
+ .contiguous()?;
+ let k = qkv
+ .narrow(D::Minus1, self.key_dim, self.key_dim)?
+ .permute((0, 2, 1, 3))?
+ .contiguous()?;
+ let v = qkv
+ .narrow(D::Minus1, 2 * self.key_dim, self.d)?
+ .permute((0, 2, 1, 3))?
+ .contiguous()?;
+ let attn = {
+ let _enter = self.span_matmul.enter();
+ (q.matmul(&k.t()?)? * self.scale)?
+ };
+ let attn = attn.broadcast_add(&self.ab)?;
+ let attn = {
+ let _enter = self.span_softmax.enter();
+ candle_nn::ops::softmax_last_dim(&attn)?
+ };
+ let attn = {
+ let _enter = self.span_matmul.enter();
+ attn.matmul(&v)?
+ };
+ attn.transpose(1, 2)?
+ .reshape((b, n, self.dh))?
+ .apply(&self.proj)
+ }
+}
+
+#[derive(Debug)]
+struct TinyViTBlock {
+ attn: Attention,
+ local_conv: Conv2dBN,
+ mlp: Mlp,
+ window_size: usize,
+ input_resolution: (usize, usize),
+ span: tracing::Span,
+}
+
+impl TinyViTBlock {
+ fn new(
+ dim: usize,
+ input_resolution: (usize, usize),
+ num_heads: usize,
+ window_size: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let head_dim = dim / num_heads;
+ let attn = Attention::new(
+ dim,
+ head_dim,
+ num_heads,
+ 1,
+ (window_size, window_size),
+ vb.pp("attn"),
+ )?;
+ let mlp = Mlp::new(dim, dim * MLP_RATIO, vb.pp("mlp"))?;
+ let cfg = candle_nn::Conv2dConfig {
+ padding: LOCAL_CONV_SIZE / 2,
+ groups: dim,
+ ..Default::default()
+ };
+ let local_conv = Conv2dBN::new(dim, dim, LOCAL_CONV_SIZE, cfg, vb.pp("local_conv"))?;
+ let span = tracing::span!(tracing::Level::TRACE, "tiny-vit-block");
+ Ok(Self {
+ attn,
+ local_conv,
+ mlp,
+ window_size,
+ input_resolution,
+ span,
+ })
+ }
+}
+
+impl Module for TinyViTBlock {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let (h, w) = self.input_resolution;
+ let (b, l, c) = xs.dims3()?;
+ let res_x = xs;
+ let xs = if h == self.window_size && w == self.window_size {
+ self.attn.forward(xs)?
+ } else {
+ let xs = xs.reshape((b, h, w, c))?;
+ let pad_b = (self.window_size - h % self.window_size) % self.window_size;
+ let pad_r = (self.window_size - w % self.window_size) % self.window_size;
+
+ let xs = if pad_b > 0 {
+ xs.pad_with_zeros(1, 0, pad_b)?
+ } else {
+ xs
+ };
+ let xs = if pad_r > 0 {
+ xs.pad_with_zeros(2, 0, pad_r)?
+ } else {
+ xs
+ };
+ let (p_h, p_w) = (h + pad_b, w + pad_r);
+ let n_h = p_h / self.window_size;
+ let n_w = p_w / self.window_size;
+ let xs = xs
+ .reshape((b, n_h, self.window_size, n_w, self.window_size, c))?
+ .transpose(2, 3)?
+ .reshape((b * n_h * n_w, self.window_size * self.window_size, c))?;
+ let xs = self.attn.forward(&xs)?;
+ let xs = xs
+ .reshape((b, n_h, n_w, self.window_size, self.window_size, c))?
+ .transpose(2, 3)?
+ .reshape((b, p_h, p_w, c))?;
+ let xs = if pad_r > 0 {
+ xs.i((.., .., ..w))?.contiguous()?
+ } else {
+ xs
+ };
+ let xs = if pad_b > 0 {
+ xs.i((.., ..h, ..))?.contiguous()?
+ } else {
+ xs
+ };
+ xs.reshape((b, l, c))?
+ };
+ let xs = (xs + res_x)?;
+ let xs = xs
+ .transpose(1, 2)?
+ .reshape((b, c, h, w))?
+ .apply(&self.local_conv)?
+ .reshape((b, c, l))?
+ .transpose(1, 2)?;
+ &xs + self.mlp.forward(&xs)?
+ }
+}
+
+#[derive(Debug)]
+struct BasicLayer {
+ blocks: Vec<TinyViTBlock>,
+ downsample: Option<PatchMerging>,
+ span: tracing::Span,
+}
+
+impl BasicLayer {
+ #[allow(clippy::too_many_arguments)]
+ fn new(
+ dim: usize,
+ input_resolution: (usize, usize),
+ depth: usize,
+ num_heads: usize,
+ window_size: usize,
+ downsample: bool,
+ out: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let vb_b = vb.pp("blocks");
+ let mut blocks = Vec::with_capacity(depth);
+ for index in 0..depth {
+ let block = TinyViTBlock::new(
+ dim,
+ input_resolution,
+ num_heads,
+ window_size,
+ vb_b.pp(index),
+ )?;
+ blocks.push(block)
+ }
+ let downsample = if downsample {
+ let downsample = PatchMerging::new(input_resolution, dim, out, vb.pp("downsample"))?;
+ Some(downsample)
+ } else {
+ None
+ };
+ let span = tracing::span!(tracing::Level::TRACE, "basic-layer");
+ Ok(Self {
+ blocks,
+ downsample,
+ span,
+ })
+ }
+}
+
+impl Module for BasicLayer {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let mut xs = xs.clone();
+ for block in self.blocks.iter() {
+ xs = block.forward(&xs)?
+ }
+ match &self.downsample {
+ None => Ok(xs),
+ Some(downsample) => downsample.forward(&xs),
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct TinyViT {
+ patch_embed: PatchEmbed,
+ layer0: ConvLayer,
+ layers: Vec<BasicLayer>,
+ // norm_head: candle_nn::LayerNorm,
+ // head: candle_nn::Linear,
+ neck_conv1: candle_nn::Conv2d,
+ neck_ln1: super::LayerNorm2d,
+ neck_conv2: candle_nn::Conv2d,
+ neck_ln2: super::LayerNorm2d,
+ span: tracing::Span,
+ span_neck: tracing::Span,
+}
+
+impl TinyViT {
+ pub fn new(
+ embed_dims: &[usize],
+ depths: &[usize],
+ num_heads: &[usize],
+ window_sizes: &[usize],
+ _num_classes: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let patch_embed = PatchEmbed::new(IN_CHANNELS, embed_dims[0], vb.pp("patch_embed"))?;
+ let patches_resolution = IMG_SIZE / 4;
+
+ let vb_l = vb.pp("layers");
+ let layer0 = ConvLayer::new(
+ /* dim */ embed_dims[0],
+ /* out */ embed_dims[1],
+ /* input_resolution */ (patches_resolution, patches_resolution),
+ /* depth */ depths[0],
+ /* downsample */ true,
+ /* conv_expand_ratio */ MBCONV_EXPAND_RATIO,
+ vb_l.pp(0),
+ )?;
+
+ let num_layers = embed_dims.len();
+ let mut layers = Vec::with_capacity(num_layers - 1);
+ for i_layer in 1..num_layers {
+ let patches_resolution = patches_resolution / (1 << usize::min(i_layer, 2));
+ let layer = BasicLayer::new(
+ /* dim */ embed_dims[i_layer],
+ /* input_resolution */ (patches_resolution, patches_resolution),
+ /* depth */ depths[i_layer],
+ /* num_heads */ num_heads[i_layer],
+ /* window_size */ window_sizes[i_layer],
+ /* downsample */ i_layer < num_layers - 1,
+ /* out */ embed_dims[usize::min(i_layer + 1, num_layers - 1)],
+ vb_l.pp(i_layer),
+ )?;
+ layers.push(layer)
+ }
+
+ let last_embed_dim = embed_dims[embed_dims.len() - 1];
+ // let norm_head = candle_nn::layer_norm(last_embed_dim, 1e-5, vb.pp("norm_head"))?;
+ // let head = candle_nn::linear(last_embed_dim, num_classes, vb.pp("head"))?;
+ let neck_conv1 =
+ candle_nn::conv2d_no_bias(last_embed_dim, 256, 1, Default::default(), vb.pp("neck.0"))?;
+ let neck_ln1 = super::LayerNorm2d::new(256, 1e-6, vb.pp("neck.1"))?;
+ let cfg = candle_nn::Conv2dConfig {
+ padding: 1,
+ ..Default::default()
+ };
+ let neck_conv2 = candle_nn::conv2d_no_bias(256, 256, 3, cfg, vb.pp("neck.2"))?;
+ let neck_ln2 = super::LayerNorm2d::new(256, 1e-6, vb.pp("neck.3"))?;
+
+ let span = tracing::span!(tracing::Level::TRACE, "tiny-vit");
+ let span_neck = tracing::span!(tracing::Level::TRACE, "neck");
+ Ok(Self {
+ patch_embed,
+ layer0,
+ layers,
+ neck_conv1,
+ neck_ln1,
+ neck_conv2,
+ neck_ln2,
+ span,
+ span_neck,
+ })
+ }
+}
+
+impl Module for TinyViT {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let xs = self.patch_embed.forward(xs)?;
+ let mut xs = self.layer0.forward(&xs)?;
+ for layer in self.layers.iter() {
+ xs = layer.forward(&xs)?
+ }
+ let (b, _, c) = xs.dims3()?;
+ let _enter = self.span_neck.enter();
+ xs.reshape((b, 64, 64, c))?
+ .permute((0, 3, 1, 2))?
+ .apply(&self.neck_conv1)?
+ .apply(&self.neck_ln1)?
+ .apply(&self.neck_conv2)?
+ .apply(&self.neck_ln2)
+ }
+}
+
+pub fn tiny_vit_5m(vb: VarBuilder) -> Result<TinyViT> {
+ TinyViT::new(
+ /* embed_dims */ &[64, 128, 160, 320],
+ /* depths */ &[2, 2, 6, 2],
+ /* num_heads */ &[2, 4, 5, 10],
+ /* window_sizes */ &[7, 7, 14, 7],
+ /* num_classes */ 1000,
+ vb,
+ )
+}
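+
+// Illustrative usage sketch (not part of the upstream model code): builds the 5M
+// variant with zero-initialized weights just to show the expected shapes. The
+// 1024x1024 input assumes the `IMG_SIZE` constant defined earlier in this module;
+// real use would load MobileSAM weights into the `VarBuilder` instead of relying
+// on `VarBuilder::zeros`.
+#[allow(dead_code)]
+fn tiny_vit_example() -> Result<Tensor> {
+    let dev = candle::Device::Cpu;
+    let vb = VarBuilder::zeros(candle::DType::F32, &dev);
+    let model = tiny_vit_5m(vb)?;
+    // A (1, 3, 1024, 1024) image maps to a (1, 256, 64, 64) feature map after the neck.
+    let img = Tensor::zeros((1, 3, 1024, 1024), candle::DType::F32, &dev)?;
+    model.forward(&img)
+}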
diff --git a/candle-transformers/src/models/segment_anything/transformer.rs b/candle-transformers/src/models/segment_anything/transformer.rs
new file mode 100644
index 00000000..80efb38c
--- /dev/null
+++ b/candle-transformers/src/models/segment_anything/transformer.rs
@@ -0,0 +1,221 @@
+use candle::{Result, Tensor};
+use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
+
+#[derive(Debug)]
+struct Attention {
+ q_proj: Linear,
+ k_proj: Linear,
+ v_proj: Linear,
+ out_proj: Linear,
+ num_heads: usize,
+}
+
+impl Attention {
+ fn new(
+ embedding_dim: usize,
+ num_heads: usize,
+ downsample_rate: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let internal_dim = embedding_dim / downsample_rate;
+ let q_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("q_proj"))?;
+ let k_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("k_proj"))?;
+ let v_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("v_proj"))?;
+ let out_proj = candle_nn::linear(internal_dim, embedding_dim, vb.pp("out_proj"))?;
+ Ok(Self {
+ q_proj,
+ k_proj,
+ v_proj,
+ out_proj,
+ num_heads,
+ })
+ }
+
+ fn separate_heads(&self, x: &Tensor) -> Result<Tensor> {
+ let (b, n, c) = x.dims3()?;
+ x.reshape((b, n, self.num_heads, c / self.num_heads))?
+ .transpose(1, 2)?
+ .contiguous()
+ }
+
+ fn recombine_heads(&self, x: &Tensor) -> Result<Tensor> {
+ let (b, n_heads, n_tokens, c_per_head) = x.dims4()?;
+ x.transpose(1, 2)?
+ .reshape((b, n_tokens, n_heads * c_per_head))
+ }
+
+ fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
+ let q = self.q_proj.forward(&q.contiguous()?)?;
+ let k = self.k_proj.forward(&k.contiguous()?)?;
+ let v = self.v_proj.forward(&v.contiguous()?)?;
+
+ let q = self.separate_heads(&q)?;
+ let k = self.separate_heads(&k)?;
+ let v = self.separate_heads(&v)?;
+
+ let (_, _, _, c_per_head) = q.dims4()?;
+ let attn = (q.matmul(&k.t()?)? / (c_per_head as f64).sqrt())?;
+ let attn = candle_nn::ops::softmax_last_dim(&attn)?;
+
+ let out = attn.matmul(&v)?;
+ self.recombine_heads(&out)?.apply(&self.out_proj)
+ }
+}
+
+#[derive(Debug)]
+struct TwoWayAttentionBlock {
+ self_attn: Attention,
+ norm1: LayerNorm,
+ cross_attn_token_to_image: Attention,
+ norm2: LayerNorm,
+ mlp: super::MlpBlock,
+ norm3: LayerNorm,
+ norm4: LayerNorm,
+ cross_attn_image_to_token: Attention,
+ skip_first_layer_pe: bool,
+}
+
+impl TwoWayAttentionBlock {
+ fn new(
+ embedding_dim: usize,
+ num_heads: usize,
+ mlp_dim: usize,
+ skip_first_layer_pe: bool,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let norm1 = layer_norm(embedding_dim, 1e-5, vb.pp("norm1"))?;
+ let norm2 = layer_norm(embedding_dim, 1e-5, vb.pp("norm2"))?;
+ let norm3 = layer_norm(embedding_dim, 1e-5, vb.pp("norm3"))?;
+ let norm4 = layer_norm(embedding_dim, 1e-5, vb.pp("norm4"))?;
+ let self_attn = Attention::new(embedding_dim, num_heads, 1, vb.pp("self_attn"))?;
+ let cross_attn_token_to_image = Attention::new(
+ embedding_dim,
+ num_heads,
+ 2,
+ vb.pp("cross_attn_token_to_image"),
+ )?;
+ let cross_attn_image_to_token = Attention::new(
+ embedding_dim,
+ num_heads,
+ 2,
+ vb.pp("cross_attn_image_to_token"),
+ )?;
+ let mlp = super::MlpBlock::new(
+ embedding_dim,
+ mlp_dim,
+ candle_nn::Activation::Relu,
+ vb.pp("mlp"),
+ )?;
+ Ok(Self {
+ self_attn,
+ norm1,
+ cross_attn_image_to_token,
+ norm2,
+ mlp,
+ norm3,
+ norm4,
+ cross_attn_token_to_image,
+ skip_first_layer_pe,
+ })
+ }
+
+ fn forward(
+ &self,
+ queries: &Tensor,
+ keys: &Tensor,
+ query_pe: &Tensor,
+ key_pe: &Tensor,
+ ) -> Result<(Tensor, Tensor)> {
+ // Self attention block
+ let queries = if self.skip_first_layer_pe {
+ self.self_attn.forward(queries, queries, queries)?
+ } else {
+ let q = (queries + query_pe)?;
+ let attn_out = self.self_attn.forward(&q, &q, queries)?;
+ (queries + attn_out)?
+ };
+ let queries = self.norm1.forward(&queries)?;
+
+ // Cross attention block, tokens attending to image embedding
+ let q = (&queries + query_pe)?;
+ let k = (keys + key_pe)?;
+ let attn_out = self.cross_attn_token_to_image.forward(&q, &k, keys)?;
+ let queries = (&queries + attn_out)?;
+ let queries = self.norm2.forward(&queries)?;
+
+ // MLP block
+ let mlp_out = self.mlp.forward(&queries);
+ let queries = (queries + mlp_out)?;
+ let queries = self.norm3.forward(&queries)?;
+
+ // Cross attention block, image embedding attending to tokens
+ let q = (&queries + query_pe)?;
+ let k = (keys + key_pe)?;
+ let attn_out = self.cross_attn_image_to_token.forward(&k, &q, &queries)?;
+ let keys = (keys + attn_out)?;
+ let keys = self.norm4.forward(&keys)?;
+
+ Ok((queries, keys))
+ }
+}
+
+#[derive(Debug)]
+pub struct TwoWayTransformer {
+ layers: Vec<TwoWayAttentionBlock>,
+ final_attn_token_to_image: Attention,
+ norm_final_attn: LayerNorm,
+}
+
+impl TwoWayTransformer {
+ pub fn new(
+ depth: usize,
+ embedding_dim: usize,
+ num_heads: usize,
+ mlp_dim: usize,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let vb_l = vb.pp("layers");
+ let mut layers = Vec::with_capacity(depth);
+ for i in 0..depth {
+ let layer =
+ TwoWayAttentionBlock::new(embedding_dim, num_heads, mlp_dim, i == 0, vb_l.pp(i))?;
+ layers.push(layer)
+ }
+ let final_attn_token_to_image = Attention::new(
+ embedding_dim,
+ num_heads,
+ 2,
+ vb.pp("final_attn_token_to_image"),
+ )?;
+ let norm_final_attn = layer_norm(embedding_dim, 1e-5, vb.pp("norm_final_attn"))?;
+ Ok(Self {
+ layers,
+ final_attn_token_to_image,
+ norm_final_attn,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ image_embedding: &Tensor,
+ image_pe: &Tensor,
+ point_embedding: &Tensor,
+ ) -> Result<(Tensor, Tensor)> {
+ let image_embedding = image_embedding.flatten_from(2)?.permute((0, 2, 1))?;
+ let image_pe = image_pe.flatten_from(2)?.permute((0, 2, 1))?;
+
+ let mut queries = point_embedding.clone();
+ let mut keys = image_embedding;
+
+ for layer in self.layers.iter() {
+ (queries, keys) = layer.forward(&queries, &keys, point_embedding, &image_pe)?
+ }
+
+ let q = (&queries + point_embedding)?;
+ let k = (&keys + image_pe)?;
+ let attn_out = self.final_attn_token_to_image.forward(&q, &k, &keys)?;
+ let queries = (queries + attn_out)?.apply(&self.norm_final_attn)?;
+
+ Ok((queries, keys))
+ }
+}
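+
+// Illustrative usage sketch (not part of the upstream model code): runs the
+// two-way transformer on dummy tensors with the sizes used by the SAM mask
+// decoder (depth 2, 256-dim embeddings, 8 heads, 2048 MLP dim) and
+// zero-initialized weights, purely to document the expected shapes.
+#[allow(dead_code)]
+fn two_way_transformer_example() -> Result<(Tensor, Tensor)> {
+    let dev = candle::Device::Cpu;
+    let vb = VarBuilder::zeros(candle::DType::F32, &dev);
+    let transformer = TwoWayTransformer::new(2, 256, 8, 2048, vb)?;
+    // Image embedding and its positional encoding: (batch, channels, h, w).
+    let image_embedding = Tensor::zeros((1, 256, 64, 64), candle::DType::F32, &dev)?;
+    let image_pe = Tensor::zeros((1, 256, 64, 64), candle::DType::F32, &dev)?;
+    // Sparse prompt tokens: (batch, n_tokens, channels).
+    let point_embedding = Tensor::zeros((1, 5, 256), candle::DType::F32, &dev)?;
+    // Returns the updated queries (1, 5, 256) and keys (1, 4096, 256).
+    transformer.forward(&image_embedding, &image_pe, &point_embedding)
+}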
diff --git a/candle-transformers/src/models/stable_diffusion/attention.rs b/candle-transformers/src/models/stable_diffusion/attention.rs
new file mode 100644
index 00000000..b3ea91f9
--- /dev/null
+++ b/candle-transformers/src/models/stable_diffusion/attention.rs
@@ -0,0 +1,547 @@
+//! Attention Based Building Blocks
+use candle::{DType, IndexOp, Result, Tensor, D};
+use candle_nn as nn;
+use candle_nn::Module;
+
+#[derive(Debug)]
+struct GeGlu {
+ proj: nn::Linear,
+ span: tracing::Span,
+}
+
+impl GeGlu {
+ fn new(vs: nn::VarBuilder, dim_in: usize, dim_out: usize) -> Result<Self> {
+ let proj = nn::linear(dim_in, dim_out * 2, vs.pp("proj"))?;
+ let span = tracing::span!(tracing::Level::TRACE, "geglu");
+ Ok(Self { proj, span })
+ }
+}
+
+impl Module for GeGlu {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?;
+ &hidden_states_and_gate[0] * hidden_states_and_gate[1].gelu()?
+ }
+}
+
+/// A feed-forward layer.
+#[derive(Debug)]
+struct FeedForward {
+ project_in: GeGlu,
+ linear: nn::Linear,
+ span: tracing::Span,
+}
+
+impl FeedForward {
+ // The glu parameter in the python code is unused?
+ // https://github.com/huggingface/diffusers/blob/d3d22ce5a894becb951eec03e663951b28d45135/src/diffusers/models/attention.py#L347
+    /// Creates a new feed-forward layer from an input dimension, an optional output
+    /// dimension (defaulting to the input dimension), and a multiplier used for the
+    /// intermediate hidden layer.
+ fn new(vs: nn::VarBuilder, dim: usize, dim_out: Option<usize>, mult: usize) -> Result<Self> {
+ let inner_dim = dim * mult;
+ let dim_out = dim_out.unwrap_or(dim);
+ let vs = vs.pp("net");
+ let project_in = GeGlu::new(vs.pp("0"), dim, inner_dim)?;
+ let linear = nn::linear(inner_dim, dim_out, vs.pp("2"))?;
+ let span = tracing::span!(tracing::Level::TRACE, "ff");
+ Ok(Self {
+ project_in,
+ linear,
+ span,
+ })
+ }
+}
+
+impl Module for FeedForward {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let xs = self.project_in.forward(xs)?;
+ self.linear.forward(&xs)
+ }
+}
+
+#[cfg(feature = "flash-attn")]
+fn flash_attn(
+ q: &Tensor,
+ k: &Tensor,
+ v: &Tensor,
+ softmax_scale: f32,
+ causal: bool,
+) -> Result<Tensor> {
+ candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
+}
+
+#[cfg(not(feature = "flash-attn"))]
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
+ unimplemented!("compile with '--features flash-attn'")
+}
+
+#[derive(Debug)]
+pub struct CrossAttention {
+ to_q: nn::Linear,
+ to_k: nn::Linear,
+ to_v: nn::Linear,
+ to_out: nn::Linear,
+ heads: usize,
+ scale: f64,
+ slice_size: Option<usize>,
+ span: tracing::Span,
+ span_attn: tracing::Span,
+ span_softmax: tracing::Span,
+ use_flash_attn: bool,
+}
+
+impl CrossAttention {
+ // Defaults should be heads = 8, dim_head = 64, context_dim = None
+ pub fn new(
+ vs: nn::VarBuilder,
+ query_dim: usize,
+ context_dim: Option<usize>,
+ heads: usize,
+ dim_head: usize,
+ slice_size: Option<usize>,
+ use_flash_attn: bool,
+ ) -> Result<Self> {
+ let inner_dim = dim_head * heads;
+ let context_dim = context_dim.unwrap_or(query_dim);
+ let scale = 1.0 / f64::sqrt(dim_head as f64);
+ let to_q = nn::linear_no_bias(query_dim, inner_dim, vs.pp("to_q"))?;
+ let to_k = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_k"))?;
+ let to_v = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_v"))?;
+ let to_out = nn::linear(inner_dim, query_dim, vs.pp("to_out.0"))?;
+ let span = tracing::span!(tracing::Level::TRACE, "xa");
+ let span_attn = tracing::span!(tracing::Level::TRACE, "xa-attn");
+ let span_softmax = tracing::span!(tracing::Level::TRACE, "xa-softmax");
+ Ok(Self {
+ to_q,
+ to_k,
+ to_v,
+ to_out,
+ heads,
+ scale,
+ slice_size,
+ span,
+ span_attn,
+ span_softmax,
+ use_flash_attn,
+ })
+ }
+
+ fn reshape_heads_to_batch_dim(&self, xs: &Tensor) -> Result<Tensor> {
+ let (batch_size, seq_len, dim) = xs.dims3()?;
+ xs.reshape((batch_size, seq_len, self.heads, dim / self.heads))?
+ .transpose(1, 2)?
+ .reshape((batch_size * self.heads, seq_len, dim / self.heads))
+ }
+
+ fn reshape_batch_dim_to_heads(&self, xs: &Tensor) -> Result<Tensor> {
+ let (batch_size, seq_len, dim) = xs.dims3()?;
+ xs.reshape((batch_size / self.heads, self.heads, seq_len, dim))?
+ .transpose(1, 2)?
+ .reshape((batch_size / self.heads, seq_len, dim * self.heads))
+ }
+
+ fn sliced_attention(
+ &self,
+ query: &Tensor,
+ key: &Tensor,
+ value: &Tensor,
+ slice_size: usize,
+ ) -> Result<Tensor> {
+ let batch_size_attention = query.dim(0)?;
+ let mut hidden_states = Vec::with_capacity(batch_size_attention / slice_size);
+ let in_dtype = query.dtype();
+ let query = query.to_dtype(DType::F32)?;
+ let key = key.to_dtype(DType::F32)?;
+ let value = value.to_dtype(DType::F32)?;
+
+ for i in 0..batch_size_attention / slice_size {
+ let start_idx = i * slice_size;
+ let end_idx = (i + 1) * slice_size;
+
+ let xs = query
+ .i(start_idx..end_idx)?
+ .matmul(&(key.i(start_idx..end_idx)?.t()? * self.scale)?)?;
+ let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(&value.i(start_idx..end_idx)?)?;
+ hidden_states.push(xs)
+ }
+ let hidden_states = Tensor::stack(&hidden_states, 0)?.to_dtype(in_dtype)?;
+ self.reshape_batch_dim_to_heads(&hidden_states)
+ }
+
+ fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> {
+ let _enter = self.span_attn.enter();
+ let xs = if self.use_flash_attn {
+ let init_dtype = query.dtype();
+ let q = query
+ .to_dtype(candle::DType::F16)?
+ .unsqueeze(0)?
+ .transpose(1, 2)?;
+ let k = key
+ .to_dtype(candle::DType::F16)?
+ .unsqueeze(0)?
+ .transpose(1, 2)?;
+ let v = value
+ .to_dtype(candle::DType::F16)?
+ .unsqueeze(0)?
+ .transpose(1, 2)?;
+ flash_attn(&q, &k, &v, self.scale as f32, false)?
+ .transpose(1, 2)?
+ .squeeze(0)?
+ .to_dtype(init_dtype)?
+ } else {
+ let in_dtype = query.dtype();
+ let query = query.to_dtype(DType::F32)?;
+ let key = key.to_dtype(DType::F32)?;
+ let value = value.to_dtype(DType::F32)?;
+ let xs = query.matmul(&(key.t()? * self.scale)?)?;
+ let xs = {
+ let _enter = self.span_softmax.enter();
+ nn::ops::softmax_last_dim(&xs)?
+ };
+ xs.matmul(&value)?.to_dtype(in_dtype)?
+ };
+ self.reshape_batch_dim_to_heads(&xs)
+ }
+
+ pub fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let query = self.to_q.forward(xs)?;
+ let context = context.unwrap_or(xs).contiguous()?;
+ let key = self.to_k.forward(&context)?;
+ let value = self.to_v.forward(&context)?;
+ let query = self.reshape_heads_to_batch_dim(&query)?;
+ let key = self.reshape_heads_to_batch_dim(&key)?;
+ let value = self.reshape_heads_to_batch_dim(&value)?;
+ let dim0 = query.dim(0)?;
+ let slice_size = self.slice_size.and_then(|slice_size| {
+ if dim0 < slice_size {
+ None
+ } else {
+ Some(slice_size)
+ }
+ });
+ let xs = match slice_size {
+ None => self.attention(&query, &key, &value)?,
+ Some(slice_size) => self.sliced_attention(&query, &key, &value, slice_size)?,
+ };
+ self.to_out.forward(&xs)
+ }
+}
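+
+// Illustrative usage sketch (not part of the upstream model code): a standalone
+// call with `context: None`, which makes the layer attend over its own input
+// (self-attention). The 320 channels / 8 heads / 40 dims-per-head sizes mirror
+// the first SD v1.5 down block but are otherwise arbitrary, and the
+// zero-initialized weights only make the sketch self-contained.
+#[allow(dead_code)]
+fn cross_attention_self_attn_example() -> Result<Tensor> {
+    let dev = candle::Device::Cpu;
+    let vs = nn::VarBuilder::zeros(DType::F32, &dev);
+    // query_dim = 320, no cross-attention context, no slicing, no flash attention.
+    let attn = CrossAttention::new(vs, 320, None, 8, 40, None, false)?;
+    let xs = Tensor::zeros((1, 4096, 320), DType::F32, &dev)?;
+    attn.forward(&xs, None)
+}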
+
+/// A basic Transformer block.
+#[derive(Debug)]
+struct BasicTransformerBlock {
+ attn1: CrossAttention,
+ ff: FeedForward,
+ attn2: CrossAttention,
+ norm1: nn::LayerNorm,
+ norm2: nn::LayerNorm,
+ norm3: nn::LayerNorm,
+ span: tracing::Span,
+}
+
+impl BasicTransformerBlock {
+ fn new(
+ vs: nn::VarBuilder,
+ dim: usize,
+ n_heads: usize,
+ d_head: usize,
+ context_dim: Option<usize>,
+ sliced_attention_size: Option<usize>,
+ use_flash_attn: bool,
+ ) -> Result<Self> {
+ let attn1 = CrossAttention::new(
+ vs.pp("attn1"),
+ dim,
+ None,
+ n_heads,
+ d_head,
+ sliced_attention_size,
+ use_flash_attn,
+ )?;
+ let ff = FeedForward::new(vs.pp("ff"), dim, None, 4)?;
+ let attn2 = CrossAttention::new(
+ vs.pp("attn2"),
+ dim,
+ context_dim,
+ n_heads,
+ d_head,
+ sliced_attention_size,
+ use_flash_attn,
+ )?;
+ let norm1 = nn::layer_norm(dim, 1e-5, vs.pp("norm1"))?;
+ let norm2 = nn::layer_norm(dim, 1e-5, vs.pp("norm2"))?;
+ let norm3 = nn::layer_norm(dim, 1e-5, vs.pp("norm3"))?;
+ let span = tracing::span!(tracing::Level::TRACE, "basic-transformer");
+ Ok(Self {
+ attn1,
+ ff,
+ attn2,
+ norm1,
+ norm2,
+ norm3,
+ span,
+ })
+ }
+
+ fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let xs = (self.attn1.forward(&self.norm1.forward(xs)?, None)? + xs)?;
+ let xs = (self.attn2.forward(&self.norm2.forward(&xs)?, context)? + xs)?;
+ self.ff.forward(&self.norm3.forward(&xs)?)? + xs
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct SpatialTransformerConfig {
+ pub depth: usize,
+ pub num_groups: usize,
+ pub context_dim: Option<usize>,
+ pub sliced_attention_size: Option<usize>,
+ pub use_linear_projection: bool,
+}
+
+impl Default for SpatialTransformerConfig {
+ fn default() -> Self {
+ Self {
+ depth: 1,
+ num_groups: 32,
+ context_dim: None,
+ sliced_attention_size: None,
+ use_linear_projection: false,
+ }
+ }
+}
+
+#[derive(Debug)]
+enum Proj {
+ Conv2d(nn::Conv2d),
+ Linear(nn::Linear),
+}
+
+// Aka Transformer2DModel
+#[derive(Debug)]
+pub struct SpatialTransformer {
+ norm: nn::GroupNorm,
+ proj_in: Proj,
+ transformer_blocks: Vec<BasicTransformerBlock>,
+ proj_out: Proj,
+ span: tracing::Span,
+ pub config: SpatialTransformerConfig,
+}
+
+impl SpatialTransformer {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ n_heads: usize,
+ d_head: usize,
+ use_flash_attn: bool,
+ config: SpatialTransformerConfig,
+ ) -> Result<Self> {
+ let inner_dim = n_heads * d_head;
+ let norm = nn::group_norm(config.num_groups, in_channels, 1e-6, vs.pp("norm"))?;
+ let proj_in = if config.use_linear_projection {
+ Proj::Linear(nn::linear(in_channels, inner_dim, vs.pp("proj_in"))?)
+ } else {
+ Proj::Conv2d(nn::conv2d(
+ in_channels,
+ inner_dim,
+ 1,
+ Default::default(),
+ vs.pp("proj_in"),
+ )?)
+ };
+ let mut transformer_blocks = vec![];
+ let vs_tb = vs.pp("transformer_blocks");
+ for index in 0..config.depth {
+ let tb = BasicTransformerBlock::new(
+ vs_tb.pp(&index.to_string()),
+ inner_dim,
+ n_heads,
+ d_head,
+ config.context_dim,
+ config.sliced_attention_size,
+ use_flash_attn,
+ )?;
+ transformer_blocks.push(tb)
+ }
+ let proj_out = if config.use_linear_projection {
+ Proj::Linear(nn::linear(in_channels, inner_dim, vs.pp("proj_out"))?)
+ } else {
+ Proj::Conv2d(nn::conv2d(
+ inner_dim,
+ in_channels,
+ 1,
+ Default::default(),
+ vs.pp("proj_out"),
+ )?)
+ };
+ let span = tracing::span!(tracing::Level::TRACE, "spatial-transformer");
+ Ok(Self {
+ norm,
+ proj_in,
+ transformer_blocks,
+ proj_out,
+ span,
+ config,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
+ let _enter = self.span.enter();
+        let (batch, _channel, height, width) = xs.dims4()?;
+        let residual = xs;
+        let xs = self.norm.forward(xs)?;
+        let (inner_dim, xs) = match &self.proj_in {
+            Proj::Conv2d(p) => {
+                let xs = p.forward(&xs)?;
+                let inner_dim = xs.dim(1)?;
+                let xs = xs
+                    .transpose(1, 2)?
+                    .t()?
+                    .reshape((batch, height * width, inner_dim))?;
+                (inner_dim, xs)
+            }
+            Proj::Linear(p) => {
+                let inner_dim = xs.dim(1)?;
+                let xs = xs
+                    .transpose(1, 2)?
+                    .t()?
+                    .reshape((batch, height * width, inner_dim))?;
+                (inner_dim, p.forward(&xs)?)
+            }
+        };
+        let mut xs = xs;
+        for block in self.transformer_blocks.iter() {
+            xs = block.forward(&xs, context)?
+        }
+        let xs = match &self.proj_out {
+            Proj::Conv2d(p) => p.forward(
+                &xs.reshape((batch, height, width, inner_dim))?
+                    .t()?
+                    .transpose(1, 2)?,
+            )?,
+            Proj::Linear(p) => p
+                .forward(&xs)?
+                .reshape((batch, height, width, inner_dim))?
+                .t()?
+                .transpose(1, 2)?,
+        };
+ xs + residual
+ }
+}
+
+/// Configuration for an attention block.
+#[derive(Debug, Clone, Copy)]
+pub struct AttentionBlockConfig {
+ pub num_head_channels: Option<usize>,
+ pub num_groups: usize,
+ pub rescale_output_factor: f64,
+ pub eps: f64,
+}
+
+impl Default for AttentionBlockConfig {
+ fn default() -> Self {
+ Self {
+ num_head_channels: None,
+ num_groups: 32,
+ rescale_output_factor: 1.,
+ eps: 1e-5,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct AttentionBlock {
+ group_norm: nn::GroupNorm,
+ query: nn::Linear,
+ key: nn::Linear,
+ value: nn::Linear,
+ proj_attn: nn::Linear,
+ channels: usize,
+ num_heads: usize,
+ span: tracing::Span,
+ config: AttentionBlockConfig,
+}
+
+impl AttentionBlock {
+ pub fn new(vs: nn::VarBuilder, channels: usize, config: AttentionBlockConfig) -> Result<Self> {
+ let num_head_channels = config.num_head_channels.unwrap_or(channels);
+ let num_heads = channels / num_head_channels;
+ let group_norm =
+ nn::group_norm(config.num_groups, channels, config.eps, vs.pp("group_norm"))?;
+ let (q_path, k_path, v_path, out_path) = if vs.contains_tensor("to_q.weight") {
+ ("to_q", "to_k", "to_v", "to_out.0")
+ } else {
+ ("query", "key", "value", "proj_attn")
+ };
+ let query = nn::linear(channels, channels, vs.pp(q_path))?;
+ let key = nn::linear(channels, channels, vs.pp(k_path))?;
+ let value = nn::linear(channels, channels, vs.pp(v_path))?;
+ let proj_attn = nn::linear(channels, channels, vs.pp(out_path))?;
+ let span = tracing::span!(tracing::Level::TRACE, "attn-block");
+ Ok(Self {
+ group_norm,
+ query,
+ key,
+ value,
+ proj_attn,
+ channels,
+ num_heads,
+ span,
+ config,
+ })
+ }
+
+ fn transpose_for_scores(&self, xs: Tensor) -> Result<Tensor> {
+ let (batch, t, h_times_d) = xs.dims3()?;
+ xs.reshape((batch, t, self.num_heads, h_times_d / self.num_heads))?
+ .transpose(1, 2)
+ }
+}
+
+impl Module for AttentionBlock {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let in_dtype = xs.dtype();
+ let residual = xs;
+ let (batch, channel, height, width) = xs.dims4()?;
+ let xs = self
+ .group_norm
+ .forward(xs)?
+ .reshape((batch, channel, height * width))?
+ .transpose(1, 2)?;
+
+ let query_proj = self.query.forward(&xs)?;
+ let key_proj = self.key.forward(&xs)?;
+ let value_proj = self.value.forward(&xs)?;
+
+ let query_states = self
+ .transpose_for_scores(query_proj)?
+ .to_dtype(DType::F32)?;
+ let key_states = self.transpose_for_scores(key_proj)?.to_dtype(DType::F32)?;
+ let value_states = self
+ .transpose_for_scores(value_proj)?
+ .to_dtype(DType::F32)?;
+
+ let scale = f64::powf((self.channels as f64) / (self.num_heads as f64), -0.25);
+ let attention_scores =
+            // TODO: Check whether this really needs two multiplications by `scale`.
+ (query_states * scale)?.matmul(&(key_states.t()? * scale)?)?;
+ let attention_probs = nn::ops::softmax(&attention_scores, D::Minus1)?;
+
+ let xs = attention_probs.matmul(&value_states.contiguous()?)?;
+ let xs = xs.to_dtype(in_dtype)?;
+ let xs = xs.transpose(1, 2)?.contiguous()?;
+ let xs = xs.flatten_from(D::Minus2)?;
+ let xs = self
+ .proj_attn
+ .forward(&xs)?
+ .t()?
+ .reshape((batch, channel, height, width))?;
+ (xs + residual)? / self.config.rescale_output_factor
+ }
+}
diff --git a/candle-transformers/src/models/stable_diffusion/clip.rs b/candle-transformers/src/models/stable_diffusion/clip.rs
new file mode 100644
index 00000000..e7a20270
--- /dev/null
+++ b/candle-transformers/src/models/stable_diffusion/clip.rs
@@ -0,0 +1,389 @@
+//! Contrastive Language-Image Pre-Training
+//!
+//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
+//! pairs of images and their associated texts.
+//!
+//! https://github.com/openai/CLIP
+use candle::{DType, Device, Result, Tensor, D};
+use candle_nn as nn;
+use candle_nn::Module;
+
+#[derive(Debug, Clone, Copy)]
+pub enum Activation {
+ QuickGelu,
+ Gelu,
+ GeluErf,
+}
+
+impl Module for Activation {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ match self {
+ Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?,
+ Activation::Gelu => xs.gelu(),
+ Activation::GeluErf => xs.gelu_erf(),
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct Config {
+ vocab_size: usize,
+ embed_dim: usize, // aka config.hidden_size
+ activation: Activation, // aka config.hidden_act
+ intermediate_size: usize,
+ pub max_position_embeddings: usize,
+    // The character to use for padding; EOS is used when not set.
+ pub pad_with: Option<String>,
+ num_hidden_layers: usize,
+ num_attention_heads: usize,
+ #[allow(dead_code)]
+ projection_dim: usize,
+}
+
+impl Config {
+ // The config details can be found in the "text_config" section of this json file:
+ // https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
+ pub fn v1_5() -> Self {
+ Self {
+ vocab_size: 49408,
+ embed_dim: 768,
+ intermediate_size: 3072,
+ max_position_embeddings: 77,
+ pad_with: None,
+ num_hidden_layers: 12,
+ num_attention_heads: 12,
+ projection_dim: 768,
+ activation: Activation::QuickGelu,
+ }
+ }
+
+ // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/text_encoder/config.json
+ pub fn v2_1() -> Self {
+ Self {
+ vocab_size: 49408,
+ embed_dim: 1024,
+ intermediate_size: 4096,
+ max_position_embeddings: 77,
+ pad_with: Some("!".to_string()),
+ num_hidden_layers: 23,
+ num_attention_heads: 16,
+ projection_dim: 512,
+ activation: Activation::Gelu,
+ }
+ }
+
+ // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/text_encoder/config.json
+ pub fn sdxl() -> Self {
+ Self {
+ vocab_size: 49408,
+ embed_dim: 768,
+ intermediate_size: 3072,
+ max_position_embeddings: 77,
+ pad_with: Some("!".to_string()),
+ num_hidden_layers: 12,
+ num_attention_heads: 12,
+ projection_dim: 768,
+ activation: Activation::QuickGelu,
+ }
+ }
+
+ // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/text_encoder_2/config.json
+ pub fn sdxl2() -> Self {
+ Self {
+ vocab_size: 49408,
+ embed_dim: 1280,
+ intermediate_size: 5120,
+ max_position_embeddings: 77,
+ pad_with: Some("!".to_string()),
+ num_hidden_layers: 32,
+ num_attention_heads: 20,
+ projection_dim: 1280,
+ activation: Activation::Gelu,
+ }
+ }
+
+ // https://huggingface.co/warp-ai/wuerstchen/blob/main/text_encoder/config.json
+ pub fn wuerstchen() -> Self {
+ Self {
+ vocab_size: 49408,
+ embed_dim: 1024,
+ intermediate_size: 4096,
+ max_position_embeddings: 77,
+ pad_with: None,
+ num_hidden_layers: 24,
+ num_attention_heads: 16,
+ projection_dim: 1024,
+ activation: Activation::GeluErf,
+ }
+ }
+
+ // https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/text_encoder/config.json
+ pub fn wuerstchen_prior() -> Self {
+ Self {
+ vocab_size: 49408,
+ embed_dim: 1280,
+ intermediate_size: 5120,
+ max_position_embeddings: 77,
+ pad_with: None,
+ num_hidden_layers: 32,
+ num_attention_heads: 20,
+ projection_dim: 512,
+ activation: Activation::GeluErf,
+ }
+ }
+}
+
+// CLIP Text Model
+// https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py
+#[derive(Debug)]
+struct ClipTextEmbeddings {
+ token_embedding: candle_nn::Embedding,
+ position_embedding: candle_nn::Embedding,
+ position_ids: Tensor,
+}
+
+impl ClipTextEmbeddings {
+ fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
+ let token_embedding =
+ candle_nn::embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?;
+ let position_embedding = candle_nn::embedding(
+ c.max_position_embeddings,
+ c.embed_dim,
+ vs.pp("position_embedding"),
+ )?;
+ let position_ids =
+ Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
+ Ok(ClipTextEmbeddings {
+ token_embedding,
+ position_embedding,
+ position_ids,
+ })
+ }
+}
+
+impl Module for ClipTextEmbeddings {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let token_embedding = self.token_embedding.forward(xs)?;
+ let position_embedding = self.position_embedding.forward(&self.position_ids)?;
+ token_embedding.broadcast_add(&position_embedding)
+ }
+}
+
+#[derive(Debug)]
+struct ClipAttention {
+ k_proj: candle_nn::Linear,
+ v_proj: candle_nn::Linear,
+ q_proj: candle_nn::Linear,
+ out_proj: candle_nn::Linear,
+ head_dim: usize,
+ scale: f64,
+ num_attention_heads: usize,
+}
+
+impl ClipAttention {
+ fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
+ let embed_dim = c.embed_dim;
+ let num_attention_heads = c.num_attention_heads;
+ let k_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("k_proj"))?;
+ let v_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("v_proj"))?;
+ let q_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("q_proj"))?;
+ let out_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?;
+ let head_dim = embed_dim / num_attention_heads;
+ let scale = (head_dim as f64).powf(-0.5);
+ Ok(ClipAttention {
+ k_proj,
+ v_proj,
+ q_proj,
+ out_proj,
+ head_dim,
+ scale,
+ num_attention_heads,
+ })
+ }
+
+ fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result<Tensor> {
+ xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
+ .transpose(1, 2)?
+ .contiguous()
+ }
+
+ fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
+ let in_dtype = xs.dtype();
+ let (bsz, seq_len, embed_dim) = xs.dims3()?;
+ let query_states = (self.q_proj.forward(xs)? * self.scale)?;
+ let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
+ let query_states = self
+ .shape(&query_states, seq_len, bsz)?
+ .reshape(proj_shape)?
+ .to_dtype(DType::F32)?;
+ let key_states = self
+ .shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
+ .reshape(proj_shape)?
+ .to_dtype(DType::F32)?;
+ let value_states = self
+ .shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
+ .reshape(proj_shape)?
+ .to_dtype(DType::F32)?;
+ let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
+
+ let src_len = key_states.dim(1)?;
+ let attn_weights = attn_weights
+ .reshape((bsz, self.num_attention_heads, seq_len, src_len))?
+ .broadcast_add(causal_attention_mask)?;
+ let attn_weights =
+ attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
+ let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
+
+ let attn_output = attn_weights.matmul(&value_states)?.to_dtype(in_dtype)?;
+ let attn_output = attn_output
+ .reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
+ .transpose(1, 2)?
+ .reshape((bsz, seq_len, embed_dim))?;
+ self.out_proj.forward(&attn_output)
+ }
+}
+
+#[derive(Debug)]
+struct ClipMlp {
+ fc1: candle_nn::Linear,
+ fc2: candle_nn::Linear,
+ activation: Activation,
+}
+
+impl ClipMlp {
+ fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
+ let fc1 = candle_nn::linear(c.embed_dim, c.intermediate_size, vs.pp("fc1"))?;
+ let fc2 = candle_nn::linear(c.intermediate_size, c.embed_dim, vs.pp("fc2"))?;
+ Ok(ClipMlp {
+ fc1,
+ fc2,
+ activation: c.activation,
+ })
+ }
+}
+
+impl ClipMlp {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = self.fc1.forward(xs)?;
+ self.fc2.forward(&self.activation.forward(&xs)?)
+ }
+}
+
+#[derive(Debug)]
+struct ClipEncoderLayer {
+ self_attn: ClipAttention,
+ layer_norm1: candle_nn::LayerNorm,
+ mlp: ClipMlp,
+ layer_norm2: candle_nn::LayerNorm,
+}
+
+impl ClipEncoderLayer {
+ fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
+ let self_attn = ClipAttention::new(vs.pp("self_attn"), c)?;
+ let layer_norm1 = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("layer_norm1"))?;
+ let mlp = ClipMlp::new(vs.pp("mlp"), c)?;
+ let layer_norm2 = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("layer_norm2"))?;
+ Ok(ClipEncoderLayer {
+ self_attn,
+ layer_norm1,
+ mlp,
+ layer_norm2,
+ })
+ }
+
+ fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
+ let residual = xs;
+ let xs = self.layer_norm1.forward(xs)?;
+ let xs = self.self_attn.forward(&xs, causal_attention_mask)?;
+ let xs = (xs + residual)?;
+
+ let residual = &xs;
+ let xs = self.layer_norm2.forward(&xs)?;
+ let xs = self.mlp.forward(&xs)?;
+ xs + residual
+ }
+}
+
+#[derive(Debug)]
+struct ClipEncoder {
+ layers: Vec<ClipEncoderLayer>,
+}
+
+impl ClipEncoder {
+ fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
+ let vs = vs.pp("layers");
+ let mut layers: Vec<ClipEncoderLayer> = Vec::new();
+ for index in 0..c.num_hidden_layers {
+ let layer = ClipEncoderLayer::new(vs.pp(&index.to_string()), c)?;
+ layers.push(layer)
+ }
+ Ok(ClipEncoder { layers })
+ }
+
+ fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
+ let mut xs = xs.clone();
+ for layer in self.layers.iter() {
+ xs = layer.forward(&xs, causal_attention_mask)?;
+ }
+ Ok(xs)
+ }
+}
+
+/// A CLIP transformer based model.
+#[derive(Debug)]
+pub struct ClipTextTransformer {
+ embeddings: ClipTextEmbeddings,
+ encoder: ClipEncoder,
+ final_layer_norm: candle_nn::LayerNorm,
+}
+
+impl ClipTextTransformer {
+ pub fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
+ let vs = vs.pp("text_model");
+ let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?;
+ let encoder = ClipEncoder::new(vs.pp("encoder"), c)?;
+ let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("final_layer_norm"))?;
+ Ok(ClipTextTransformer {
+ embeddings,
+ encoder,
+ final_layer_norm,
+ })
+ }
+
+ // https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678
+ fn build_causal_attention_mask(
+ bsz: usize,
+ seq_len: usize,
+ mask_after: usize,
+ device: &Device,
+ ) -> Result<Tensor> {
+ let mask: Vec<_> = (0..seq_len)
+ .flat_map(|i| {
+ (0..seq_len).map(move |j| {
+ if j > i || j > mask_after {
+ f32::MIN
+ } else {
+ 0.
+ }
+ })
+ })
+ .collect();
+ let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
+ mask.broadcast_as((bsz, seq_len, seq_len))
+ }
+
+ pub fn forward_with_mask(&self, xs: &Tensor, mask_after: usize) -> Result<Tensor> {
+ let (bsz, seq_len) = xs.dims2()?;
+ let xs = self.embeddings.forward(xs)?;
+ let causal_attention_mask =
+ Self::build_causal_attention_mask(bsz, seq_len, mask_after, xs.device())?;
+ let xs = self.encoder.forward(&xs, &causal_attention_mask)?;
+ self.final_layer_norm.forward(&xs)
+ }
+}
+
+impl Module for ClipTextTransformer {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ self.forward_with_mask(xs, usize::MAX)
+ }
+}
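+
+// Illustrative usage sketch (not part of the upstream model code): instantiates
+// the text encoder with the SD v1.5 configuration and zero-initialized weights,
+// then runs it on a dummy, already-padded prompt of 77 token ids. With a real
+// checkpoint the `VarBuilder` would be built from the safetensors weights and
+// the ids would come from the CLIP tokenizer.
+#[allow(dead_code)]
+fn clip_text_encoder_example() -> Result<Tensor> {
+    let dev = Device::Cpu;
+    let vs = nn::VarBuilder::zeros(DType::F32, &dev);
+    let cfg = Config::v1_5();
+    let text_model = ClipTextTransformer::new(vs, &cfg)?;
+    // (batch, seq_len) token ids; id 0 merely stands in for a tokenized prompt.
+    let tokens = Tensor::zeros((1, cfg.max_position_embeddings), DType::U32, &dev)?;
+    // The resulting hidden states have shape (1, 77, embed_dim).
+    text_model.forward(&tokens)
+}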
diff --git a/candle-transformers/src/models/stable_diffusion/ddim.rs b/candle-transformers/src/models/stable_diffusion/ddim.rs
new file mode 100644
index 00000000..916b7349
--- /dev/null
+++ b/candle-transformers/src/models/stable_diffusion/ddim.rs
@@ -0,0 +1,180 @@
+//! # Denoising Diffusion Implicit Models
+//!
+//! Denoising Diffusion Implicit Models (DDIM) provide a simple scheduler
+//! similar to Denoising Diffusion Probabilistic Models (DDPM). The DDPM
+//! generative process is the reverse of a Markovian process; DDIM generalizes
+//! this to non-Markovian guidance.
+//!
+//! Denoising Diffusion Implicit Models, J. Song et al., 2020.
+//! https://arxiv.org/abs/2010.02502
+use super::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType};
+use candle::{Result, Tensor};
+
+/// The configuration for the DDIM scheduler.
+#[derive(Debug, Clone, Copy)]
+pub struct DDIMSchedulerConfig {
+ /// The value of beta at the beginning of training.
+ pub beta_start: f64,
+ /// The value of beta at the end of training.
+ pub beta_end: f64,
+ /// How beta evolved during training.
+ pub beta_schedule: BetaSchedule,
+ /// The amount of noise to be added at each step.
+ pub eta: f64,
+ /// Adjust the indexes of the inference schedule by this value.
+ pub steps_offset: usize,
+ /// prediction type of the scheduler function, one of `epsilon` (predicting
+    /// the noise of the diffusion process), `sample` (directly predicting the noisy sample)
+ /// or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf)
+ pub prediction_type: PredictionType,
+ /// number of diffusion steps used to train the model
+ pub train_timesteps: usize,
+}
+
+impl Default for DDIMSchedulerConfig {
+ fn default() -> Self {
+ Self {
+ beta_start: 0.00085f64,
+ beta_end: 0.012f64,
+ beta_schedule: BetaSchedule::ScaledLinear,
+ eta: 0.,
+ steps_offset: 1,
+ prediction_type: PredictionType::Epsilon,
+ train_timesteps: 1000,
+ }
+ }
+}
+
+/// The DDIM scheduler.
+#[derive(Debug, Clone)]
+pub struct DDIMScheduler {
+ timesteps: Vec<usize>,
+ alphas_cumprod: Vec<f64>,
+ step_ratio: usize,
+ init_noise_sigma: f64,
+ pub config: DDIMSchedulerConfig,
+}
+
+// clip_sample: False, set_alpha_to_one: False
+impl DDIMScheduler {
+ /// Creates a new DDIM scheduler given the number of steps to be
+ /// used for inference as well as the number of steps that was used
+ /// during training.
+ pub fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result<Self> {
+ let step_ratio = config.train_timesteps / inference_steps;
+ let timesteps: Vec<usize> = (0..(inference_steps))
+ .map(|s| s * step_ratio + config.steps_offset)
+ .rev()
+ .collect();
+ let betas = match config.beta_schedule {
+ BetaSchedule::ScaledLinear => super::utils::linspace(
+ config.beta_start.sqrt(),
+ config.beta_end.sqrt(),
+ config.train_timesteps,
+ )?
+ .sqr()?,
+ BetaSchedule::Linear => {
+ super::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)?
+ }
+ BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?,
+ };
+ let betas = betas.to_vec1::<f64>()?;
+ let mut alphas_cumprod = Vec::with_capacity(betas.len());
+ for &beta in betas.iter() {
+ let alpha = 1.0 - beta;
+ alphas_cumprod.push(alpha * *alphas_cumprod.last().unwrap_or(&1f64))
+ }
+ Ok(Self {
+ alphas_cumprod,
+ timesteps,
+ step_ratio,
+ init_noise_sigma: 1.,
+ config,
+ })
+ }
+
+ pub fn timesteps(&self) -> &[usize] {
+ self.timesteps.as_slice()
+ }
+
+ /// Ensures interchangeability with schedulers that need to scale the denoising model input
+ /// depending on the current timestep.
+ pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> {
+ Ok(sample)
+ }
+
+ /// Performs a backward step during inference.
+ pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
+ let timestep = if timestep >= self.alphas_cumprod.len() {
+ timestep - 1
+ } else {
+ timestep
+ };
+ // https://github.com/huggingface/diffusers/blob/6e099e2c8ce4c4f5c7318e970a8c093dc5c7046e/src/diffusers/schedulers/scheduling_ddim.py#L195
+ let prev_timestep = if timestep > self.step_ratio {
+ timestep - self.step_ratio
+ } else {
+ 0
+ };
+
+ let alpha_prod_t = self.alphas_cumprod[timestep];
+ let alpha_prod_t_prev = self.alphas_cumprod[prev_timestep];
+ let beta_prod_t = 1. - alpha_prod_t;
+ let beta_prod_t_prev = 1. - alpha_prod_t_prev;
+
+ let (pred_original_sample, pred_epsilon) = match self.config.prediction_type {
+ PredictionType::Epsilon => {
+ let pred_original_sample = ((sample - (model_output * beta_prod_t.sqrt())?)?
+ * (1. / alpha_prod_t.sqrt()))?;
+ (pred_original_sample, model_output.clone())
+ }
+ PredictionType::VPrediction => {
+ let pred_original_sample =
+ ((sample * alpha_prod_t.sqrt())? - (model_output * beta_prod_t.sqrt())?)?;
+ let pred_epsilon =
+ ((model_output * alpha_prod_t.sqrt())? + (sample * beta_prod_t.sqrt())?)?;
+ (pred_original_sample, pred_epsilon)
+ }
+ PredictionType::Sample => {
+ let pred_original_sample = model_output.clone();
+ let pred_epsilon = ((sample - &pred_original_sample * alpha_prod_t.sqrt())?
+ * (1. / beta_prod_t.sqrt()))?;
+ (pred_original_sample, pred_epsilon)
+ }
+ };
+
+ let variance = (beta_prod_t_prev / beta_prod_t) * (1. - alpha_prod_t / alpha_prod_t_prev);
+ let std_dev_t = self.config.eta * variance.sqrt();
+
+ let pred_sample_direction =
+ (pred_epsilon * (1. - alpha_prod_t_prev - std_dev_t * std_dev_t).sqrt())?;
+ let prev_sample =
+ ((pred_original_sample * alpha_prod_t_prev.sqrt())? + pred_sample_direction)?;
+ if self.config.eta > 0. {
+ &prev_sample
+ + Tensor::randn(
+ 0f32,
+ std_dev_t as f32,
+ prev_sample.shape(),
+ prev_sample.device(),
+ )?
+ } else {
+ Ok(prev_sample)
+ }
+ }
+
+ pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
+ let timestep = if timestep >= self.alphas_cumprod.len() {
+ timestep - 1
+ } else {
+ timestep
+ };
+ let sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt();
+ let sqrt_one_minus_alpha_prod = (1.0 - self.alphas_cumprod[timestep]).sqrt();
+ (original * sqrt_alpha_prod)? + (noise * sqrt_one_minus_alpha_prod)?
+ }
+
+ pub fn init_noise_sigma(&self) -> f64 {
+ self.init_noise_sigma
+ }
+}
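+
+// Illustrative usage sketch (not part of the upstream model code): the typical
+// inference loop around this scheduler. `denoise` is a placeholder for the UNet
+// forward pass (it is not an API defined in this crate); the latents start from
+// scaled Gaussian noise and are refined over the scheduler's timesteps. The 30
+// inference steps are just an example value.
+#[allow(dead_code)]
+fn ddim_denoising_loop<F>(denoise: F, mut latents: Tensor) -> Result<Tensor>
+where
+    F: Fn(&Tensor, usize) -> Result<Tensor>,
+{
+    let scheduler = DDIMScheduler::new(30, DDIMSchedulerConfig::default())?;
+    latents = (latents * scheduler.init_noise_sigma())?;
+    for &timestep in scheduler.timesteps() {
+        let input = scheduler.scale_model_input(latents.clone(), timestep)?;
+        let noise_pred = denoise(&input, timestep)?;
+        latents = scheduler.step(&noise_pred, timestep, &latents)?;
+    }
+    Ok(latents)
+}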
diff --git a/candle-transformers/src/models/stable_diffusion/ddpm.rs b/candle-transformers/src/models/stable_diffusion/ddpm.rs
new file mode 100644
index 00000000..d393f39a
--- /dev/null
+++ b/candle-transformers/src/models/stable_diffusion/ddpm.rs
@@ -0,0 +1,205 @@
+use super::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType};
+use candle::{Result, Tensor};
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum DDPMVarianceType {
+ FixedSmall,
+ FixedSmallLog,
+ FixedLarge,
+ FixedLargeLog,
+ Learned,
+}
+
+impl Default for DDPMVarianceType {
+ fn default() -> Self {
+ Self::FixedSmall
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct DDPMSchedulerConfig {
+ /// The value of beta at the beginning of training.
+ pub beta_start: f64,
+ /// The value of beta at the end of training.
+ pub beta_end: f64,
+ /// How beta evolved during training.
+ pub beta_schedule: BetaSchedule,
+    /// Option to clip the predicted sample to [-1, 1] for numerical stability.
+ pub clip_sample: bool,
+ /// Option to clip the variance used when adding noise to the denoised sample.
+ pub variance_type: DDPMVarianceType,
+ /// prediction type of the scheduler function
+ pub prediction_type: PredictionType,
+ /// number of diffusion steps used to train the model.
+ pub train_timesteps: usize,
+}
+
+impl Default for DDPMSchedulerConfig {
+ fn default() -> Self {
+ Self {
+ beta_start: 0.00085,
+ beta_end: 0.012,
+ beta_schedule: BetaSchedule::ScaledLinear,
+ clip_sample: false,
+ variance_type: DDPMVarianceType::FixedSmall,
+ prediction_type: PredictionType::Epsilon,
+ train_timesteps: 1000,
+ }
+ }
+}
+
+pub struct DDPMScheduler {
+ alphas_cumprod: Vec<f64>,
+ init_noise_sigma: f64,
+ timesteps: Vec<usize>,
+ step_ratio: usize,
+ pub config: DDPMSchedulerConfig,
+}
+
+impl DDPMScheduler {
+ pub fn new(inference_steps: usize, config: DDPMSchedulerConfig) -> Result<Self> {
+ let betas = match config.beta_schedule {
+ BetaSchedule::ScaledLinear => super::utils::linspace(
+ config.beta_start.sqrt(),
+ config.beta_end.sqrt(),
+ config.train_timesteps,
+ )?
+ .sqr()?,
+ BetaSchedule::Linear => {
+ super::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)?
+ }
+ BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?,
+ };
+
+ let betas = betas.to_vec1::<f64>()?;
+ let mut alphas_cumprod = Vec::with_capacity(betas.len());
+ for &beta in betas.iter() {
+ let alpha = 1.0 - beta;
+ alphas_cumprod.push(alpha * *alphas_cumprod.last().unwrap_or(&1f64))
+ }
+
+ // min(train_timesteps, inference_steps)
+ // https://github.com/huggingface/diffusers/blob/8331da46837be40f96fbd24de6a6fb2da28acd11/src/diffusers/schedulers/scheduling_ddpm.py#L187
+ let inference_steps = inference_steps.min(config.train_timesteps);
+        // build the scheduler's timesteps, evenly spaced over the training schedule
+ let step_ratio = config.train_timesteps / inference_steps;
+ let timesteps: Vec<usize> = (0..inference_steps).map(|s| s * step_ratio).rev().collect();
+
+ Ok(Self {
+ alphas_cumprod,
+ init_noise_sigma: 1.0,
+ timesteps,
+ step_ratio,
+ config,
+ })
+ }
+
+ fn get_variance(&self, timestep: usize) -> f64 {
+ let prev_t = timestep as isize - self.step_ratio as isize;
+ let alpha_prod_t = self.alphas_cumprod[timestep];
+ let alpha_prod_t_prev = if prev_t >= 0 {
+ self.alphas_cumprod[prev_t as usize]
+ } else {
+ 1.0
+ };
+ let current_beta_t = 1. - alpha_prod_t / alpha_prod_t_prev;
+
+ // For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
+ // and sample from it to get previous sample
+ // x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
+ let variance = (1. - alpha_prod_t_prev) / (1. - alpha_prod_t) * current_beta_t;
+
+ // retrieve variance
+ match self.config.variance_type {
+ DDPMVarianceType::FixedSmall => variance.max(1e-20),
+ // for rl-diffuser https://arxiv.org/abs/2205.09991
+ DDPMVarianceType::FixedSmallLog => {
+ let variance = variance.max(1e-20).ln();
+ (variance * 0.5).exp()
+ }
+ DDPMVarianceType::FixedLarge => current_beta_t,
+ DDPMVarianceType::FixedLargeLog => current_beta_t.ln(),
+ DDPMVarianceType::Learned => variance,
+ }
+ }
+
+ pub fn timesteps(&self) -> &[usize] {
+ self.timesteps.as_slice()
+ }
+
+ /// Ensures interchangeability with schedulers that need to scale the denoising model input
+ /// depending on the current timestep.
+ pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor {
+ sample
+ }
+
+ pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
+ let prev_t = timestep as isize - self.step_ratio as isize;
+
+ // https://github.com/huggingface/diffusers/blob/df2b548e893ccb8a888467c2508756680df22821/src/diffusers/schedulers/scheduling_ddpm.py#L272
+ // 1. compute alphas, betas
+ let alpha_prod_t = self.alphas_cumprod[timestep];
+ let alpha_prod_t_prev = if prev_t >= 0 {
+ self.alphas_cumprod[prev_t as usize]
+ } else {
+ 1.0
+ };
+ let beta_prod_t = 1. - alpha_prod_t;
+ let beta_prod_t_prev = 1. - alpha_prod_t_prev;
+ let current_alpha_t = alpha_prod_t / alpha_prod_t_prev;
+ let current_beta_t = 1. - current_alpha_t;
+
+        // 2. compute the predicted original sample from the predicted noise, also called "predicted x_0" in formula (15)
+ let mut pred_original_sample = match self.config.prediction_type {
+ PredictionType::Epsilon => {
+ ((sample - model_output * beta_prod_t.sqrt())? / alpha_prod_t.sqrt())?
+ }
+ PredictionType::Sample => model_output.clone(),
+ PredictionType::VPrediction => {
+ ((sample * alpha_prod_t.sqrt())? - model_output * beta_prod_t.sqrt())?
+ }
+ };
+
+ // 3. clip predicted x_0
+ if self.config.clip_sample {
+ pred_original_sample = pred_original_sample.clamp(-1f32, 1f32)?;
+ }
+
+ // 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
+ // See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ let pred_original_sample_coeff = (alpha_prod_t_prev.sqrt() * current_beta_t) / beta_prod_t;
+ let current_sample_coeff = current_alpha_t.sqrt() * beta_prod_t_prev / beta_prod_t;
+
+ // 5. Compute predicted previous sample µ_t
+ // See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ let pred_prev_sample = ((&pred_original_sample * pred_original_sample_coeff)?
+ + sample * current_sample_coeff)?;
+
+ // https://github.com/huggingface/diffusers/blob/df2b548e893ccb8a888467c2508756680df22821/src/diffusers/schedulers/scheduling_ddpm.py#L305
+ // 6. Add noise
+ let mut variance = model_output.zeros_like()?;
+ if timestep > 0 {
+ let variance_noise = model_output.randn_like(0., 1.)?;
+ if self.config.variance_type == DDPMVarianceType::FixedSmallLog {
+ variance = (variance_noise * self.get_variance(timestep))?;
+ } else {
+ variance = (variance_noise * self.get_variance(timestep).sqrt())?;
+ }
+ }
+ &pred_prev_sample + variance
+ }
+
+ pub fn add_noise(
+ &self,
+ original_samples: &Tensor,
+ noise: Tensor,
+ timestep: usize,
+ ) -> Result<Tensor> {
+ (original_samples * self.alphas_cumprod[timestep].sqrt())?
+ + noise * (1. - self.alphas_cumprod[timestep]).sqrt()
+ }
+
+ pub fn init_noise_sigma(&self) -> f64 {
+ self.init_noise_sigma
+ }
+}
diff --git a/candle-transformers/src/models/stable_diffusion/embeddings.rs b/candle-transformers/src/models/stable_diffusion/embeddings.rs
new file mode 100644
index 00000000..0de5f9a7
--- /dev/null
+++ b/candle-transformers/src/models/stable_diffusion/embeddings.rs
@@ -0,0 +1,65 @@
+use candle::{Result, Tensor, D};
+use candle_nn as nn;
+use candle_nn::Module;
+
+#[derive(Debug)]
+pub struct TimestepEmbedding {
+ linear_1: nn::Linear,
+ linear_2: nn::Linear,
+}
+
+impl TimestepEmbedding {
+ // act_fn: "silu"
+ pub fn new(vs: nn::VarBuilder, channel: usize, time_embed_dim: usize) -> Result<Self> {
+ let linear_1 = nn::linear(channel, time_embed_dim, vs.pp("linear_1"))?;
+ let linear_2 = nn::linear(time_embed_dim, time_embed_dim, vs.pp("linear_2"))?;
+ Ok(Self { linear_1, linear_2 })
+ }
+}
+
+impl Module for TimestepEmbedding {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = nn::ops::silu(&self.linear_1.forward(xs)?)?;
+ self.linear_2.forward(&xs)
+ }
+}
+
+#[derive(Debug)]
+pub struct Timesteps {
+ num_channels: usize,
+ flip_sin_to_cos: bool,
+ downscale_freq_shift: f64,
+}
+
+impl Timesteps {
+ pub fn new(num_channels: usize, flip_sin_to_cos: bool, downscale_freq_shift: f64) -> Self {
+ Self {
+ num_channels,
+ flip_sin_to_cos,
+ downscale_freq_shift,
+ }
+ }
+}
+
+impl Module for Timesteps {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let half_dim = (self.num_channels / 2) as u32;
+ let exponent = (Tensor::arange(0, half_dim, xs.device())?.to_dtype(candle::DType::F32)?
+ * -f64::ln(10000.))?;
+ let exponent = (exponent / (half_dim as f64 - self.downscale_freq_shift))?;
+ let emb = exponent.exp()?.to_dtype(xs.dtype())?;
+ // emb = timesteps[:, None].float() * emb[None, :]
+ let emb = xs.unsqueeze(D::Minus1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
+ let (cos, sin) = (emb.cos()?, emb.sin()?);
+ let emb = if self.flip_sin_to_cos {
+ Tensor::cat(&[&cos, &sin], D::Minus1)?
+ } else {
+ Tensor::cat(&[&sin, &cos], D::Minus1)?
+ };
+ if self.num_channels % 2 == 1 {
+ emb.pad_with_zeros(D::Minus2, 0, 1)
+ } else {
+ Ok(emb)
+ }
+ }
+}
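+
+// Illustrative usage sketch (not part of the upstream model code): computes the
+// sinusoidal embedding for a small batch of diffusion timesteps. The 320 channels
+// and `flip_sin_to_cos = true` follow the SD v1.5 UNet settings; the result would
+// normally be fed through `TimestepEmbedding` afterwards.
+#[allow(dead_code)]
+fn timesteps_example() -> Result<Tensor> {
+    let dev = candle::Device::Cpu;
+    let timesteps = Timesteps::new(320, true, 0.);
+    // One timestep per batch element, e.g. t = 981 and t = 961.
+    let ts = Tensor::new(&[981f32, 961f32], &dev)?;
+    // The output shape is (2, 320), i.e. (batch, num_channels).
+    timesteps.forward(&ts)
+}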
diff --git a/candle-transformers/src/models/stable_diffusion/mod.rs b/candle-transformers/src/models/stable_diffusion/mod.rs
new file mode 100644
index 00000000..c6f1b904
--- /dev/null
+++ b/candle-transformers/src/models/stable_diffusion/mod.rs
@@ -0,0 +1,303 @@
+pub mod attention;
+pub mod clip;
+pub mod ddim;
+pub mod ddpm;
+pub mod embeddings;
+pub mod resnet;
+pub mod schedulers;
+pub mod unet_2d;
+pub mod unet_2d_blocks;
+pub mod utils;
+pub mod vae;
+
+use candle::{DType, Device, Result};
+use candle_nn as nn;
+
+#[derive(Clone, Debug)]
+pub struct StableDiffusionConfig {
+ pub width: usize,
+ pub height: usize,
+ pub clip: clip::Config,
+ pub clip2: Option<clip::Config>,
+ autoencoder: vae::AutoEncoderKLConfig,
+ unet: unet_2d::UNet2DConditionModelConfig,
+ scheduler: ddim::DDIMSchedulerConfig,
+}
+
+impl StableDiffusionConfig {
+ pub fn v1_5(
+ sliced_attention_size: Option<usize>,
+ height: Option<usize>,
+ width: Option<usize>,
+ ) -> Self {
+ let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
+ out_channels,
+ use_cross_attn,
+ attention_head_dim,
+ };
+ // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/unet/config.json
+ let unet = unet_2d::UNet2DConditionModelConfig {
+ blocks: vec![
+ bc(320, Some(1), 8),
+ bc(640, Some(1), 8),
+ bc(1280, Some(1), 8),
+ bc(1280, None, 8),
+ ],
+ center_input_sample: false,
+ cross_attention_dim: 768,
+ downsample_padding: 1,
+ flip_sin_to_cos: true,
+ freq_shift: 0.,
+ layers_per_block: 2,
+ mid_block_scale_factor: 1.,
+ norm_eps: 1e-5,
+ norm_num_groups: 32,
+ sliced_attention_size,
+ use_linear_projection: false,
+ };
+ let autoencoder = vae::AutoEncoderKLConfig {
+ block_out_channels: vec![128, 256, 512, 512],
+ layers_per_block: 2,
+ latent_channels: 4,
+ norm_num_groups: 32,
+ };
+ let height = if let Some(height) = height {
+ assert_eq!(height % 8, 0, "height has to be divisible by 8");
+ height
+ } else {
+ 512
+ };
+
+ let width = if let Some(width) = width {
+ assert_eq!(width % 8, 0, "width has to be divisible by 8");
+ width
+ } else {
+ 512
+ };
+
+ Self {
+ width,
+ height,
+ clip: clip::Config::v1_5(),
+ clip2: None,
+ autoencoder,
+ scheduler: Default::default(),
+ unet,
+ }
+ }
+
+ fn v2_1_(
+ sliced_attention_size: Option<usize>,
+ height: Option<usize>,
+ width: Option<usize>,
+ prediction_type: schedulers::PredictionType,
+ ) -> Self {
+ let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
+ out_channels,
+ use_cross_attn,
+ attention_head_dim,
+ };
+ // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/unet/config.json
+ let unet = unet_2d::UNet2DConditionModelConfig {
+ blocks: vec![
+ bc(320, Some(1), 5),
+ bc(640, Some(1), 10),
+ bc(1280, Some(1), 20),
+ bc(1280, None, 20),
+ ],
+ center_input_sample: false,
+ cross_attention_dim: 1024,
+ downsample_padding: 1,
+ flip_sin_to_cos: true,
+ freq_shift: 0.,
+ layers_per_block: 2,
+ mid_block_scale_factor: 1.,
+ norm_eps: 1e-5,
+ norm_num_groups: 32,
+ sliced_attention_size,
+ use_linear_projection: true,
+ };
+ // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/vae/config.json
+ let autoencoder = vae::AutoEncoderKLConfig {
+ block_out_channels: vec![128, 256, 512, 512],
+ layers_per_block: 2,
+ latent_channels: 4,
+ norm_num_groups: 32,
+ };
+ let scheduler = ddim::DDIMSchedulerConfig {
+ prediction_type,
+ ..Default::default()
+ };
+
+ let height = if let Some(height) = height {
+ assert_eq!(height % 8, 0, "height has to be divisible by 8");
+ height
+ } else {
+ 768
+ };
+
+ let width = if let Some(width) = width {
+ assert_eq!(width % 8, 0, "width has to be divisible by 8");
+ width
+ } else {
+ 768
+ };
+
+ Self {
+ width,
+ height,
+ clip: clip::Config::v2_1(),
+ clip2: None,
+ autoencoder,
+ scheduler,
+ unet,
+ }
+ }
+
+ pub fn v2_1(
+ sliced_attention_size: Option<usize>,
+ height: Option<usize>,
+ width: Option<usize>,
+ ) -> Self {
+ // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/scheduler/scheduler_config.json
+ Self::v2_1_(
+ sliced_attention_size,
+ height,
+ width,
+ schedulers::PredictionType::VPrediction,
+ )
+ }
+
+ fn sdxl_(
+ sliced_attention_size: Option<usize>,
+ height: Option<usize>,
+ width: Option<usize>,
+ prediction_type: schedulers::PredictionType,
+ ) -> Self {
+ let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
+ out_channels,
+ use_cross_attn,
+ attention_head_dim,
+ };
+ // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/unet/config.json
+ let unet = unet_2d::UNet2DConditionModelConfig {
+ blocks: vec![
+ bc(320, None, 5),
+ bc(640, Some(2), 10),
+ bc(1280, Some(10), 20),
+ ],
+ center_input_sample: false,
+ cross_attention_dim: 2048,
+ downsample_padding: 1,
+ flip_sin_to_cos: true,
+ freq_shift: 0.,
+ layers_per_block: 2,
+ mid_block_scale_factor: 1.,
+ norm_eps: 1e-5,
+ norm_num_groups: 32,
+ sliced_attention_size,
+ use_linear_projection: true,
+ };
+ // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/vae/config.json
+ let autoencoder = vae::AutoEncoderKLConfig {
+ block_out_channels: vec![128, 256, 512, 512],
+ layers_per_block: 2,
+ latent_channels: 4,
+ norm_num_groups: 32,
+ };
+ let scheduler = ddim::DDIMSchedulerConfig {
+ prediction_type,
+ ..Default::default()
+ };
+
+ let height = if let Some(height) = height {
+ assert_eq!(height % 8, 0, "height has to be divisible by 8");
+ height
+ } else {
+ 1024
+ };
+
+ let width = if let Some(width) = width {
+ assert_eq!(width % 8, 0, "width has to be divisible by 8");
+ width
+ } else {
+ 1024
+ };
+
+ Self {
+ width,
+ height,
+ clip: clip::Config::sdxl(),
+ clip2: Some(clip::Config::sdxl2()),
+ autoencoder,
+ scheduler,
+ unet,
+ }
+ }
+
+ pub fn sdxl(
+ sliced_attention_size: Option<usize>,
+ height: Option<usize>,
+ width: Option<usize>,
+ ) -> Self {
+ Self::sdxl_(
+ sliced_attention_size,
+ height,
+ width,
+ // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/scheduler/scheduler_config.json
+ schedulers::PredictionType::Epsilon,
+ )
+ }
+
+ pub fn build_vae<P: AsRef<std::path::Path>>(
+ &self,
+ vae_weights: P,
+ device: &Device,
+ dtype: DType,
+ ) -> Result<vae::AutoEncoderKL> {
+ let weights = unsafe { candle::safetensors::MmapedFile::new(vae_weights)? };
+ let weights = weights.deserialize()?;
+ let vs_ae = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
+ // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
+ let autoencoder = vae::AutoEncoderKL::new(vs_ae, 3, 3, self.autoencoder.clone())?;
+ Ok(autoencoder)
+ }
+
+ pub fn build_unet<P: AsRef<std::path::Path>>(
+ &self,
+ unet_weights: P,
+ device: &Device,
+ in_channels: usize,
+ use_flash_attn: bool,
+ dtype: DType,
+ ) -> Result<unet_2d::UNet2DConditionModel> {
+ let weights = unsafe { candle::safetensors::MmapedFile::new(unet_weights)? };
+ let weights = weights.deserialize()?;
+ let vs_unet = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
+ let unet = unet_2d::UNet2DConditionModel::new(
+ vs_unet,
+ in_channels,
+ 4,
+ use_flash_attn,
+ self.unet.clone(),
+ )?;
+ Ok(unet)
+ }
+
+ pub fn build_scheduler(&self, n_steps: usize) -> Result<ddim::DDIMScheduler> {
+ ddim::DDIMScheduler::new(n_steps, self.scheduler)
+ }
+}
+
+pub fn build_clip_transformer<P: AsRef<std::path::Path>>(
+ clip: &clip::Config,
+ clip_weights: P,
+ device: &Device,
+ dtype: DType,
+) -> Result<clip::ClipTextTransformer> {
+ let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
+ let weights = weights.deserialize()?;
+ let vs = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
+ let text_model = clip::ClipTextTransformer::new(vs, clip)?;
+ Ok(text_model)
+}
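+
+// A rough sketch of how the builders above are typically chained when assembling a
+// pipeline; the config type name `StableDiffusionConfig`, the weight file paths and
+// the `clip_config` value are assumptions made for illustration, not items defined
+// in this file:
+//
+//     let cfg = StableDiffusionConfig::v2_1(None, None, None);
+//     let device = Device::Cpu;
+//     let vae = cfg.build_vae("vae.safetensors", &device, DType::F32)?;
+//     let unet = cfg.build_unet("unet.safetensors", &device, 4, false, DType::F32)?;
+//     let scheduler = cfg.build_scheduler(30)?;
+//     let text_model = build_clip_transformer(&clip_config, "clip.safetensors", &device, DType::F32)?;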
diff --git a/candle-transformers/src/models/stable_diffusion/resnet.rs b/candle-transformers/src/models/stable_diffusion/resnet.rs
new file mode 100644
index 00000000..0d818115
--- /dev/null
+++ b/candle-transformers/src/models/stable_diffusion/resnet.rs
@@ -0,0 +1,138 @@
+//! ResNet Building Blocks
+//!
+//! Some Residual Network blocks used in UNet models.
+//!
+//! Deep Residual Learning for Image Recognition, K. He et al., 2015.
+//! https://arxiv.org/abs/1512.03385
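+//!
+//! A minimal usage sketch (the variable builder `vb`, the input `xs` and the time
+//! embedding `temb` are assumptions made for the example, not items from this module):
+//!
+//! ```ignore
+//! let cfg = ResnetBlock2DConfig { out_channels: Some(320), ..Default::default() };
+//! let block = ResnetBlock2D::new(vb, 320, cfg)?;
+//! let ys = block.forward(&xs, Some(&temb))?;
+//! ```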
+use super::utils::{conv2d, Conv2d};
+use candle::{Result, Tensor, D};
+use candle_nn as nn;
+use candle_nn::Module;
+
+/// Configuration for a ResNet block.
+#[derive(Debug, Clone, Copy)]
+pub struct ResnetBlock2DConfig {
+ /// The number of output channels, defaults to the number of input channels.
+ pub out_channels: Option<usize>,
+ pub temb_channels: Option<usize>,
+ /// The number of groups to use in group normalization.
+ pub groups: usize,
+ pub groups_out: Option<usize>,
+ /// The epsilon to be used in the group normalization operations.
+ pub eps: f64,
+ /// Whether to use a 2D convolution in the skip connection. When using None,
+ /// such a convolution is used if the number of input channels is different from
+ /// the number of output channels.
+ pub use_in_shortcut: Option<bool>,
+ // non_linearity: silu
+ /// The final output is scaled by dividing by this value.
+ pub output_scale_factor: f64,
+}
+
+impl Default for ResnetBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ out_channels: None,
+ temb_channels: Some(512),
+ groups: 32,
+ groups_out: None,
+ eps: 1e-6,
+ use_in_shortcut: None,
+ output_scale_factor: 1.,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct ResnetBlock2D {
+ norm1: nn::GroupNorm,
+ conv1: Conv2d,
+ norm2: nn::GroupNorm,
+ conv2: Conv2d,
+ time_emb_proj: Option<nn::Linear>,
+ conv_shortcut: Option<Conv2d>,
+ span: tracing::Span,
+ config: ResnetBlock2DConfig,
+}
+
+impl ResnetBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ config: ResnetBlock2DConfig,
+ ) -> Result<Self> {
+ let out_channels = config.out_channels.unwrap_or(in_channels);
+ let conv_cfg = nn::Conv2dConfig {
+ stride: 1,
+ padding: 1,
+ groups: 1,
+ dilation: 1,
+ };
+ let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
+ let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
+ let groups_out = config.groups_out.unwrap_or(config.groups);
+ let norm2 = nn::group_norm(groups_out, out_channels, config.eps, vs.pp("norm2"))?;
+ let conv2 = conv2d(out_channels, out_channels, 3, conv_cfg, vs.pp("conv2"))?;
+ let use_in_shortcut = config
+ .use_in_shortcut
+ .unwrap_or(in_channels != out_channels);
+ let conv_shortcut = if use_in_shortcut {
+ let conv_cfg = nn::Conv2dConfig {
+ stride: 1,
+ padding: 0,
+ groups: 1,
+ dilation: 1,
+ };
+ Some(conv2d(
+ in_channels,
+ out_channels,
+ 1,
+ conv_cfg,
+ vs.pp("conv_shortcut"),
+ )?)
+ } else {
+ None
+ };
+ let time_emb_proj = match config.temb_channels {
+ None => None,
+ Some(temb_channels) => Some(nn::linear(
+ temb_channels,
+ out_channels,
+ vs.pp("time_emb_proj"),
+ )?),
+ };
+ let span = tracing::span!(tracing::Level::TRACE, "resnet2d");
+ Ok(Self {
+ norm1,
+ conv1,
+ norm2,
+ conv2,
+ time_emb_proj,
+ span,
+ config,
+ conv_shortcut,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let shortcut_xs = match &self.conv_shortcut {
+ Some(conv_shortcut) => conv_shortcut.forward(xs)?,
+ None => xs.clone(),
+ };
+ let xs = self.norm1.forward(xs)?;
+ let xs = self.conv1.forward(&nn::ops::silu(&xs)?)?;
+ let xs = match (temb, &self.time_emb_proj) {
+ (Some(temb), Some(time_emb_proj)) => time_emb_proj
+ .forward(&nn::ops::silu(temb)?)?
+ .unsqueeze(D::Minus1)?
+ .unsqueeze(D::Minus1)?
+ .broadcast_add(&xs)?,
+ _ => xs,
+ };
+ let xs = self
+ .conv2
+ .forward(&nn::ops::silu(&self.norm2.forward(&xs)?)?)?;
+ (shortcut_xs + xs)? / self.config.output_scale_factor
+ }
+}
diff --git a/candle-transformers/src/models/stable_diffusion/schedulers.rs b/candle-transformers/src/models/stable_diffusion/schedulers.rs
new file mode 100644
index 00000000..3f6a1d72
--- /dev/null
+++ b/candle-transformers/src/models/stable_diffusion/schedulers.rs
@@ -0,0 +1,45 @@
+#![allow(dead_code)]
+//! # Diffusion schedulers
+//!
+//! Noise schedulers can be used to set the trade-off between
+//! inference speed and quality.
+
+use candle::{Result, Tensor};
+
+/// This represents how beta ranges from its minimum value to the maximum
+/// during training.
+#[derive(Debug, Clone, Copy)]
+pub enum BetaSchedule {
+ /// Linear interpolation.
+ Linear,
+ /// Linear interpolation of the square root of beta.
+ ScaledLinear,
+ /// Glide cosine schedule
+ SquaredcosCapV2,
+}
+
+#[derive(Debug, Clone, Copy)]
+pub enum PredictionType {
+ Epsilon,
+ VPrediction,
+ Sample,
+}
+
+/// Create a beta schedule that discretizes the given alpha_t_bar function, which defines the
+/// cumulative product of `(1-beta)` over time `t` in `[0, 1]`.
+///
+/// The `alpha_bar` function takes an argument `t` and returns the cumulative product of
+/// `(1-beta)` up to that point of the diffusion process.
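+///
+/// For example, with `num_diffusion_timesteps = 1000` the first entry is
+/// `1 - alpha_bar(1/1000) / alpha_bar(0)` where
+/// `alpha_bar(t) = cos((t + 0.008) / 1.008 * pi / 2)^2`, clipped to `max_beta`.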
+pub(crate) fn betas_for_alpha_bar(num_diffusion_timesteps: usize, max_beta: f64) -> Result<Tensor> {
+    let alpha_bar = |time_step: f64| {
+        f64::cos((time_step + 0.008) / 1.008 * std::f64::consts::FRAC_PI_2).powi(2)
+    };
+    let mut betas = Vec::with_capacity(num_diffusion_timesteps);
+    for i in 0..num_diffusion_timesteps {
+        // Use floating-point division here, integer division would collapse t1 and t2 to 0.
+        let t1 = i as f64 / num_diffusion_timesteps as f64;
+        let t2 = (i + 1) as f64 / num_diffusion_timesteps as f64;
+        betas.push((1.0 - alpha_bar(t2) / alpha_bar(t1)).min(max_beta));
+ }
+ let betas_len = betas.len();
+ Tensor::from_vec(betas, betas_len, &candle::Device::Cpu)
+}
diff --git a/candle-transformers/src/models/stable_diffusion/unet_2d.rs b/candle-transformers/src/models/stable_diffusion/unet_2d.rs
new file mode 100644
index 00000000..a3ed136e
--- /dev/null
+++ b/candle-transformers/src/models/stable_diffusion/unet_2d.rs
@@ -0,0 +1,401 @@
+//! 2D UNet Denoising Models
+//!
+//! The 2D UNet models take as input a noisy sample and the current diffusion
+//! timestep and return a denoised version of the input.
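+//!
+//! A rough sketch of the expected call pattern (the variable builder `vs`, the latent
+//! tensor `latents` and the text embedding `text_emb` are assumptions made for the
+//! example):
+//!
+//! ```ignore
+//! let cfg = UNet2DConditionModelConfig::default();
+//! let unet = UNet2DConditionModel::new(vs, 4, 4, /* use_flash_attn */ false, cfg)?;
+//! // Predict the noise residual for the current timestep.
+//! let noise_pred = unet.forward(&latents, 999., &text_emb)?;
+//! ```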
+use super::embeddings::{TimestepEmbedding, Timesteps};
+use super::unet_2d_blocks::*;
+use super::utils::{conv2d, Conv2d};
+use candle::{Result, Tensor};
+use candle_nn as nn;
+use candle_nn::Module;
+
+#[derive(Debug, Clone, Copy)]
+pub struct BlockConfig {
+ pub out_channels: usize,
+ /// When `None` no cross-attn is used, when `Some(d)` then cross-attn is used and `d` is the
+ /// number of transformer blocks to be used.
+ pub use_cross_attn: Option<usize>,
+ pub attention_head_dim: usize,
+}
+
+#[derive(Debug, Clone)]
+pub struct UNet2DConditionModelConfig {
+ pub center_input_sample: bool,
+ pub flip_sin_to_cos: bool,
+ pub freq_shift: f64,
+ pub blocks: Vec<BlockConfig>,
+ pub layers_per_block: usize,
+ pub downsample_padding: usize,
+ pub mid_block_scale_factor: f64,
+ pub norm_num_groups: usize,
+ pub norm_eps: f64,
+ pub cross_attention_dim: usize,
+ pub sliced_attention_size: Option<usize>,
+ pub use_linear_projection: bool,
+}
+
+impl Default for UNet2DConditionModelConfig {
+ fn default() -> Self {
+ Self {
+ center_input_sample: false,
+ flip_sin_to_cos: true,
+ freq_shift: 0.,
+ blocks: vec![
+ BlockConfig {
+ out_channels: 320,
+ use_cross_attn: Some(1),
+ attention_head_dim: 8,
+ },
+ BlockConfig {
+ out_channels: 640,
+ use_cross_attn: Some(1),
+ attention_head_dim: 8,
+ },
+ BlockConfig {
+ out_channels: 1280,
+ use_cross_attn: Some(1),
+ attention_head_dim: 8,
+ },
+ BlockConfig {
+ out_channels: 1280,
+ use_cross_attn: None,
+ attention_head_dim: 8,
+ },
+ ],
+ layers_per_block: 2,
+ downsample_padding: 1,
+ mid_block_scale_factor: 1.,
+ norm_num_groups: 32,
+ norm_eps: 1e-5,
+ cross_attention_dim: 1280,
+ sliced_attention_size: None,
+ use_linear_projection: false,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub(crate) enum UNetDownBlock {
+ Basic(DownBlock2D),
+ CrossAttn(CrossAttnDownBlock2D),
+}
+
+#[derive(Debug)]
+enum UNetUpBlock {
+ Basic(UpBlock2D),
+ CrossAttn(CrossAttnUpBlock2D),
+}
+
+#[derive(Debug)]
+pub struct UNet2DConditionModel {
+ conv_in: Conv2d,
+ time_proj: Timesteps,
+ time_embedding: TimestepEmbedding,
+ down_blocks: Vec<UNetDownBlock>,
+ mid_block: UNetMidBlock2DCrossAttn,
+ up_blocks: Vec<UNetUpBlock>,
+ conv_norm_out: nn::GroupNorm,
+ conv_out: Conv2d,
+ span: tracing::Span,
+ config: UNet2DConditionModelConfig,
+}
+
+impl UNet2DConditionModel {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ use_flash_attn: bool,
+ config: UNet2DConditionModelConfig,
+ ) -> Result<Self> {
+ let n_blocks = config.blocks.len();
+ let b_channels = config.blocks[0].out_channels;
+ let bl_channels = config.blocks.last().unwrap().out_channels;
+ let bl_attention_head_dim = config.blocks.last().unwrap().attention_head_dim;
+ let time_embed_dim = b_channels * 4;
+ let conv_cfg = nn::Conv2dConfig {
+ padding: 1,
+ ..Default::default()
+ };
+ let conv_in = conv2d(in_channels, b_channels, 3, conv_cfg, vs.pp("conv_in"))?;
+
+ let time_proj = Timesteps::new(b_channels, config.flip_sin_to_cos, config.freq_shift);
+ let time_embedding =
+ TimestepEmbedding::new(vs.pp("time_embedding"), b_channels, time_embed_dim)?;
+
+ let vs_db = vs.pp("down_blocks");
+ let down_blocks = (0..n_blocks)
+ .map(|i| {
+ let BlockConfig {
+ out_channels,
+ use_cross_attn,
+ attention_head_dim,
+ } = config.blocks[i];
+
+ // Enable automatic attention slicing if the config sliced_attention_size is set to 0.
+ let sliced_attention_size = match config.sliced_attention_size {
+ Some(0) => Some(attention_head_dim / 2),
+ _ => config.sliced_attention_size,
+ };
+
+ let in_channels = if i > 0 {
+ config.blocks[i - 1].out_channels
+ } else {
+ b_channels
+ };
+ let db_cfg = DownBlock2DConfig {
+ num_layers: config.layers_per_block,
+ resnet_eps: config.norm_eps,
+ resnet_groups: config.norm_num_groups,
+ add_downsample: i < n_blocks - 1,
+ downsample_padding: config.downsample_padding,
+ ..Default::default()
+ };
+ if let Some(transformer_layers_per_block) = use_cross_attn {
+ let config = CrossAttnDownBlock2DConfig {
+ downblock: db_cfg,
+ attn_num_head_channels: attention_head_dim,
+ cross_attention_dim: config.cross_attention_dim,
+ sliced_attention_size,
+ use_linear_projection: config.use_linear_projection,
+ transformer_layers_per_block,
+ };
+ let block = CrossAttnDownBlock2D::new(
+ vs_db.pp(&i.to_string()),
+ in_channels,
+ out_channels,
+ Some(time_embed_dim),
+ use_flash_attn,
+ config,
+ )?;
+ Ok(UNetDownBlock::CrossAttn(block))
+ } else {
+ let block = DownBlock2D::new(
+ vs_db.pp(&i.to_string()),
+ in_channels,
+ out_channels,
+ Some(time_embed_dim),
+ db_cfg,
+ )?;
+ Ok(UNetDownBlock::Basic(block))
+ }
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ // https://github.com/huggingface/diffusers/blob/a76f2ad538e73b34d5fe7be08c8eb8ab38c7e90c/src/diffusers/models/unet_2d_condition.py#L462
+ let mid_transformer_layers_per_block = match config.blocks.last() {
+ None => 1,
+ Some(block) => block.use_cross_attn.unwrap_or(1),
+ };
+ let mid_cfg = UNetMidBlock2DCrossAttnConfig {
+ resnet_eps: config.norm_eps,
+ output_scale_factor: config.mid_block_scale_factor,
+ cross_attn_dim: config.cross_attention_dim,
+ attn_num_head_channels: bl_attention_head_dim,
+ resnet_groups: Some(config.norm_num_groups),
+ use_linear_projection: config.use_linear_projection,
+ transformer_layers_per_block: mid_transformer_layers_per_block,
+ ..Default::default()
+ };
+
+ let mid_block = UNetMidBlock2DCrossAttn::new(
+ vs.pp("mid_block"),
+ bl_channels,
+ Some(time_embed_dim),
+ use_flash_attn,
+ mid_cfg,
+ )?;
+
+ let vs_ub = vs.pp("up_blocks");
+ let up_blocks = (0..n_blocks)
+ .map(|i| {
+ let BlockConfig {
+ out_channels,
+ use_cross_attn,
+ attention_head_dim,
+ } = config.blocks[n_blocks - 1 - i];
+
+ // Enable automatic attention slicing if the config sliced_attention_size is set to 0.
+ let sliced_attention_size = match config.sliced_attention_size {
+ Some(0) => Some(attention_head_dim / 2),
+ _ => config.sliced_attention_size,
+ };
+
+ let prev_out_channels = if i > 0 {
+ config.blocks[n_blocks - i].out_channels
+ } else {
+ bl_channels
+ };
+ let in_channels = {
+ let index = if i == n_blocks - 1 {
+ 0
+ } else {
+ n_blocks - i - 2
+ };
+ config.blocks[index].out_channels
+ };
+ let ub_cfg = UpBlock2DConfig {
+ num_layers: config.layers_per_block + 1,
+ resnet_eps: config.norm_eps,
+ resnet_groups: config.norm_num_groups,
+ add_upsample: i < n_blocks - 1,
+ ..Default::default()
+ };
+ if let Some(transformer_layers_per_block) = use_cross_attn {
+ let config = CrossAttnUpBlock2DConfig {
+ upblock: ub_cfg,
+ attn_num_head_channels: attention_head_dim,
+ cross_attention_dim: config.cross_attention_dim,
+ sliced_attention_size,
+ use_linear_projection: config.use_linear_projection,
+ transformer_layers_per_block,
+ };
+ let block = CrossAttnUpBlock2D::new(
+ vs_ub.pp(&i.to_string()),
+ in_channels,
+ prev_out_channels,
+ out_channels,
+ Some(time_embed_dim),
+ use_flash_attn,
+ config,
+ )?;
+ Ok(UNetUpBlock::CrossAttn(block))
+ } else {
+ let block = UpBlock2D::new(
+ vs_ub.pp(&i.to_string()),
+ in_channels,
+ prev_out_channels,
+ out_channels,
+ Some(time_embed_dim),
+ ub_cfg,
+ )?;
+ Ok(UNetUpBlock::Basic(block))
+ }
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ let conv_norm_out = nn::group_norm(
+ config.norm_num_groups,
+ b_channels,
+ config.norm_eps,
+ vs.pp("conv_norm_out"),
+ )?;
+ let conv_out = conv2d(b_channels, out_channels, 3, conv_cfg, vs.pp("conv_out"))?;
+ let span = tracing::span!(tracing::Level::TRACE, "unet2d");
+ Ok(Self {
+ conv_in,
+ time_proj,
+ time_embedding,
+ down_blocks,
+ mid_block,
+ up_blocks,
+ conv_norm_out,
+ conv_out,
+ span,
+ config,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ xs: &Tensor,
+ timestep: f64,
+ encoder_hidden_states: &Tensor,
+ ) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.forward_with_additional_residuals(xs, timestep, encoder_hidden_states, None, None)
+ }
+
+ pub fn forward_with_additional_residuals(
+ &self,
+ xs: &Tensor,
+ timestep: f64,
+ encoder_hidden_states: &Tensor,
+ down_block_additional_residuals: Option<&[Tensor]>,
+ mid_block_additional_residual: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let (bsize, _channels, height, width) = xs.dims4()?;
+ let device = xs.device();
+ let n_blocks = self.config.blocks.len();
+ let num_upsamplers = n_blocks - 1;
+ let default_overall_up_factor = 2usize.pow(num_upsamplers as u32);
+ let forward_upsample_size =
+ height % default_overall_up_factor != 0 || width % default_overall_up_factor != 0;
+ // 0. center input if necessary
+ let xs = if self.config.center_input_sample {
+ ((xs * 2.0)? - 1.0)?
+ } else {
+ xs.clone()
+ };
+ // 1. time
+ let emb = (Tensor::ones(bsize, xs.dtype(), device)? * timestep)?;
+ let emb = self.time_proj.forward(&emb)?;
+ let emb = self.time_embedding.forward(&emb)?;
+ // 2. pre-process
+ let xs = self.conv_in.forward(&xs)?;
+ // 3. down
+ let mut down_block_res_xs = vec![xs.clone()];
+ let mut xs = xs;
+ for down_block in self.down_blocks.iter() {
+ let (_xs, res_xs) = match down_block {
+ UNetDownBlock::Basic(b) => b.forward(&xs, Some(&emb))?,
+ UNetDownBlock::CrossAttn(b) => {
+ b.forward(&xs, Some(&emb), Some(encoder_hidden_states))?
+ }
+ };
+ down_block_res_xs.extend(res_xs);
+ xs = _xs;
+ }
+
+ let new_down_block_res_xs =
+ if let Some(down_block_additional_residuals) = down_block_additional_residuals {
+ let mut v = vec![];
+                // A previous version of this code had a bug: the addition was performed
+                // in place via `+=`, which modified the input of the mid block.
+ for (i, residuals) in down_block_additional_residuals.iter().enumerate() {
+ v.push((&down_block_res_xs[i] + residuals)?)
+ }
+ v
+ } else {
+ down_block_res_xs
+ };
+ let mut down_block_res_xs = new_down_block_res_xs;
+
+ // 4. mid
+ let xs = self
+ .mid_block
+ .forward(&xs, Some(&emb), Some(encoder_hidden_states))?;
+ let xs = match mid_block_additional_residual {
+ None => xs,
+ Some(m) => (m + xs)?,
+ };
+ // 5. up
+ let mut xs = xs;
+ let mut upsample_size = None;
+ for (i, up_block) in self.up_blocks.iter().enumerate() {
+ let n_resnets = match up_block {
+ UNetUpBlock::Basic(b) => b.resnets.len(),
+ UNetUpBlock::CrossAttn(b) => b.upblock.resnets.len(),
+ };
+ let res_xs = down_block_res_xs.split_off(down_block_res_xs.len() - n_resnets);
+ if i < n_blocks - 1 && forward_upsample_size {
+ let (_, _, h, w) = down_block_res_xs.last().unwrap().dims4()?;
+ upsample_size = Some((h, w))
+ }
+ xs = match up_block {
+ UNetUpBlock::Basic(b) => b.forward(&xs, &res_xs, Some(&emb), upsample_size)?,
+ UNetUpBlock::CrossAttn(b) => b.forward(
+ &xs,
+ &res_xs,
+ Some(&emb),
+ upsample_size,
+ Some(encoder_hidden_states),
+ )?,
+ };
+ }
+ // 6. post-process
+ let xs = self.conv_norm_out.forward(&xs)?;
+ let xs = nn::ops::silu(&xs)?;
+ self.conv_out.forward(&xs)
+ }
+}
diff --git a/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs b/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs
new file mode 100644
index 00000000..29510cef
--- /dev/null
+++ b/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs
@@ -0,0 +1,868 @@
+//! 2D UNet Building Blocks
+//!
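+//! Down-sampling and up-sampling layers, residual blocks with optional
+//! cross-attention, and the mid blocks used both by the diffusion UNet and by the
+//! VAE encoder/decoder.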
+use super::attention::{
+ AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
+};
+use super::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
+use super::utils::{conv2d, Conv2d};
+use candle::{Module, Result, Tensor, D};
+use candle_nn as nn;
+
+#[derive(Debug)]
+struct Downsample2D {
+ conv: Option<Conv2d>,
+ padding: usize,
+ span: tracing::Span,
+}
+
+impl Downsample2D {
+ fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ use_conv: bool,
+ out_channels: usize,
+ padding: usize,
+ ) -> Result<Self> {
+ let conv = if use_conv {
+ let config = nn::Conv2dConfig {
+ stride: 2,
+ padding,
+ ..Default::default()
+ };
+ let conv = conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
+ Some(conv)
+ } else {
+ None
+ };
+ let span = tracing::span!(tracing::Level::TRACE, "downsample2d");
+ Ok(Self {
+ conv,
+ padding,
+ span,
+ })
+ }
+}
+
+impl Module for Downsample2D {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ match &self.conv {
+ None => xs.avg_pool2d(2),
+ Some(conv) => {
+ if self.padding == 0 {
+ let xs = xs
+ .pad_with_zeros(D::Minus1, 0, 1)?
+ .pad_with_zeros(D::Minus2, 0, 1)?;
+ conv.forward(&xs)
+ } else {
+ conv.forward(xs)
+ }
+ }
+ }
+ }
+}
+
+// This does not support the conv-transpose mode.
+#[derive(Debug)]
+struct Upsample2D {
+ conv: Conv2d,
+ span: tracing::Span,
+}
+
+impl Upsample2D {
+ fn new(vs: nn::VarBuilder, in_channels: usize, out_channels: usize) -> Result<Self> {
+ let config = nn::Conv2dConfig {
+ padding: 1,
+ ..Default::default()
+ };
+ let conv = conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
+ let span = tracing::span!(tracing::Level::TRACE, "upsample2d");
+ Ok(Self { conv, span })
+ }
+}
+
+impl Upsample2D {
+ fn forward(&self, xs: &Tensor, size: Option<(usize, usize)>) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let xs = match size {
+ None => {
+ let (_bsize, _channels, h, w) = xs.dims4()?;
+ xs.upsample_nearest2d(2 * h, 2 * w)?
+ }
+ Some((h, w)) => xs.upsample_nearest2d(h, w)?,
+ };
+ self.conv.forward(&xs)
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct DownEncoderBlock2DConfig {
+ pub num_layers: usize,
+ pub resnet_eps: f64,
+ pub resnet_groups: usize,
+ pub output_scale_factor: f64,
+ pub add_downsample: bool,
+ pub downsample_padding: usize,
+}
+
+impl Default for DownEncoderBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ num_layers: 1,
+ resnet_eps: 1e-6,
+ resnet_groups: 32,
+ output_scale_factor: 1.,
+ add_downsample: true,
+ downsample_padding: 1,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct DownEncoderBlock2D {
+ resnets: Vec<ResnetBlock2D>,
+ downsampler: Option<Downsample2D>,
+ span: tracing::Span,
+ pub config: DownEncoderBlock2DConfig,
+}
+
+impl DownEncoderBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ config: DownEncoderBlock2DConfig,
+ ) -> Result<Self> {
+ let resnets: Vec<_> = {
+ let vs = vs.pp("resnets");
+ let conv_cfg = ResnetBlock2DConfig {
+ eps: config.resnet_eps,
+ out_channels: Some(out_channels),
+ groups: config.resnet_groups,
+ output_scale_factor: config.output_scale_factor,
+ temb_channels: None,
+ ..Default::default()
+ };
+ (0..(config.num_layers))
+ .map(|i| {
+ let in_channels = if i == 0 { in_channels } else { out_channels };
+ ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg)
+ })
+ .collect::<Result<Vec<_>>>()?
+ };
+ let downsampler = if config.add_downsample {
+ let downsample = Downsample2D::new(
+ vs.pp("downsamplers").pp("0"),
+ out_channels,
+ true,
+ out_channels,
+ config.downsample_padding,
+ )?;
+ Some(downsample)
+ } else {
+ None
+ };
+ let span = tracing::span!(tracing::Level::TRACE, "down-enc2d");
+ Ok(Self {
+ resnets,
+ downsampler,
+ span,
+ config,
+ })
+ }
+}
+
+impl Module for DownEncoderBlock2D {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let mut xs = xs.clone();
+ for resnet in self.resnets.iter() {
+ xs = resnet.forward(&xs, None)?
+ }
+ match &self.downsampler {
+ Some(downsampler) => downsampler.forward(&xs),
+ None => Ok(xs),
+ }
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct UpDecoderBlock2DConfig {
+ pub num_layers: usize,
+ pub resnet_eps: f64,
+ pub resnet_groups: usize,
+ pub output_scale_factor: f64,
+ pub add_upsample: bool,
+}
+
+impl Default for UpDecoderBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ num_layers: 1,
+ resnet_eps: 1e-6,
+ resnet_groups: 32,
+ output_scale_factor: 1.,
+ add_upsample: true,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct UpDecoderBlock2D {
+ resnets: Vec<ResnetBlock2D>,
+ upsampler: Option<Upsample2D>,
+ span: tracing::Span,
+ pub config: UpDecoderBlock2DConfig,
+}
+
+impl UpDecoderBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ config: UpDecoderBlock2DConfig,
+ ) -> Result<Self> {
+ let resnets: Vec<_> = {
+ let vs = vs.pp("resnets");
+ let conv_cfg = ResnetBlock2DConfig {
+ out_channels: Some(out_channels),
+ eps: config.resnet_eps,
+ groups: config.resnet_groups,
+ output_scale_factor: config.output_scale_factor,
+ temb_channels: None,
+ ..Default::default()
+ };
+ (0..(config.num_layers))
+ .map(|i| {
+ let in_channels = if i == 0 { in_channels } else { out_channels };
+ ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg)
+ })
+ .collect::<Result<Vec<_>>>()?
+ };
+ let upsampler = if config.add_upsample {
+ let upsample =
+ Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?;
+ Some(upsample)
+ } else {
+ None
+ };
+ let span = tracing::span!(tracing::Level::TRACE, "up-dec2d");
+ Ok(Self {
+ resnets,
+ upsampler,
+ span,
+ config,
+ })
+ }
+}
+
+impl Module for UpDecoderBlock2D {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let mut xs = xs.clone();
+ for resnet in self.resnets.iter() {
+ xs = resnet.forward(&xs, None)?
+ }
+ match &self.upsampler {
+ Some(upsampler) => upsampler.forward(&xs, None),
+ None => Ok(xs),
+ }
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct UNetMidBlock2DConfig {
+ pub num_layers: usize,
+ pub resnet_eps: f64,
+ pub resnet_groups: Option<usize>,
+ pub attn_num_head_channels: Option<usize>,
+ // attention_type "default"
+ pub output_scale_factor: f64,
+}
+
+impl Default for UNetMidBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ num_layers: 1,
+ resnet_eps: 1e-6,
+ resnet_groups: Some(32),
+ attn_num_head_channels: Some(1),
+ output_scale_factor: 1.,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct UNetMidBlock2D {
+ resnet: ResnetBlock2D,
+ attn_resnets: Vec<(AttentionBlock, ResnetBlock2D)>,
+ span: tracing::Span,
+ pub config: UNetMidBlock2DConfig,
+}
+
+impl UNetMidBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ temb_channels: Option<usize>,
+ config: UNetMidBlock2DConfig,
+ ) -> Result<Self> {
+ let vs_resnets = vs.pp("resnets");
+ let vs_attns = vs.pp("attentions");
+ let resnet_groups = config
+ .resnet_groups
+ .unwrap_or_else(|| usize::min(in_channels / 4, 32));
+ let resnet_cfg = ResnetBlock2DConfig {
+ eps: config.resnet_eps,
+ groups: resnet_groups,
+ output_scale_factor: config.output_scale_factor,
+ temb_channels,
+ ..Default::default()
+ };
+ let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
+ let attn_cfg = AttentionBlockConfig {
+ num_head_channels: config.attn_num_head_channels,
+ num_groups: resnet_groups,
+ rescale_output_factor: config.output_scale_factor,
+ eps: config.resnet_eps,
+ };
+ let mut attn_resnets = vec![];
+ for index in 0..config.num_layers {
+ let attn = AttentionBlock::new(vs_attns.pp(&index.to_string()), in_channels, attn_cfg)?;
+ let resnet = ResnetBlock2D::new(
+ vs_resnets.pp(&(index + 1).to_string()),
+ in_channels,
+ resnet_cfg,
+ )?;
+ attn_resnets.push((attn, resnet))
+ }
+ let span = tracing::span!(tracing::Level::TRACE, "mid2d");
+ Ok(Self {
+ resnet,
+ attn_resnets,
+ span,
+ config,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let mut xs = self.resnet.forward(xs, temb)?;
+ for (attn, resnet) in self.attn_resnets.iter() {
+ xs = resnet.forward(&attn.forward(&xs)?, temb)?
+ }
+ Ok(xs)
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct UNetMidBlock2DCrossAttnConfig {
+ pub num_layers: usize,
+ pub resnet_eps: f64,
+ pub resnet_groups: Option<usize>,
+ pub attn_num_head_channels: usize,
+ // attention_type "default"
+ pub output_scale_factor: f64,
+ pub cross_attn_dim: usize,
+ pub sliced_attention_size: Option<usize>,
+ pub use_linear_projection: bool,
+ pub transformer_layers_per_block: usize,
+}
+
+impl Default for UNetMidBlock2DCrossAttnConfig {
+ fn default() -> Self {
+ Self {
+ num_layers: 1,
+ resnet_eps: 1e-6,
+ resnet_groups: Some(32),
+ attn_num_head_channels: 1,
+ output_scale_factor: 1.,
+ cross_attn_dim: 1280,
+ sliced_attention_size: None, // Sliced attention disabled
+ use_linear_projection: false,
+ transformer_layers_per_block: 1,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct UNetMidBlock2DCrossAttn {
+ resnet: ResnetBlock2D,
+ attn_resnets: Vec<(SpatialTransformer, ResnetBlock2D)>,
+ span: tracing::Span,
+ pub config: UNetMidBlock2DCrossAttnConfig,
+}
+
+impl UNetMidBlock2DCrossAttn {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ temb_channels: Option<usize>,
+ use_flash_attn: bool,
+ config: UNetMidBlock2DCrossAttnConfig,
+ ) -> Result<Self> {
+ let vs_resnets = vs.pp("resnets");
+ let vs_attns = vs.pp("attentions");
+ let resnet_groups = config
+ .resnet_groups
+ .unwrap_or_else(|| usize::min(in_channels / 4, 32));
+ let resnet_cfg = ResnetBlock2DConfig {
+ eps: config.resnet_eps,
+ groups: resnet_groups,
+ output_scale_factor: config.output_scale_factor,
+ temb_channels,
+ ..Default::default()
+ };
+ let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
+ let n_heads = config.attn_num_head_channels;
+ let attn_cfg = SpatialTransformerConfig {
+ depth: config.transformer_layers_per_block,
+ num_groups: resnet_groups,
+ context_dim: Some(config.cross_attn_dim),
+ sliced_attention_size: config.sliced_attention_size,
+ use_linear_projection: config.use_linear_projection,
+ };
+ let mut attn_resnets = vec![];
+ for index in 0..config.num_layers {
+ let attn = SpatialTransformer::new(
+ vs_attns.pp(&index.to_string()),
+ in_channels,
+ n_heads,
+ in_channels / n_heads,
+ use_flash_attn,
+ attn_cfg,
+ )?;
+ let resnet = ResnetBlock2D::new(
+ vs_resnets.pp(&(index + 1).to_string()),
+ in_channels,
+ resnet_cfg,
+ )?;
+ attn_resnets.push((attn, resnet))
+ }
+ let span = tracing::span!(tracing::Level::TRACE, "xa-mid2d");
+ Ok(Self {
+ resnet,
+ attn_resnets,
+ span,
+ config,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ xs: &Tensor,
+ temb: Option<&Tensor>,
+ encoder_hidden_states: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let mut xs = self.resnet.forward(xs, temb)?;
+ for (attn, resnet) in self.attn_resnets.iter() {
+ xs = resnet.forward(&attn.forward(&xs, encoder_hidden_states)?, temb)?
+ }
+ Ok(xs)
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct DownBlock2DConfig {
+ pub num_layers: usize,
+ pub resnet_eps: f64,
+ // resnet_time_scale_shift: "default"
+ // resnet_act_fn: "swish"
+ pub resnet_groups: usize,
+ pub output_scale_factor: f64,
+ pub add_downsample: bool,
+ pub downsample_padding: usize,
+}
+
+impl Default for DownBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ num_layers: 1,
+ resnet_eps: 1e-6,
+ resnet_groups: 32,
+ output_scale_factor: 1.,
+ add_downsample: true,
+ downsample_padding: 1,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct DownBlock2D {
+ resnets: Vec<ResnetBlock2D>,
+ downsampler: Option<Downsample2D>,
+ span: tracing::Span,
+ pub config: DownBlock2DConfig,
+}
+
+impl DownBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ temb_channels: Option<usize>,
+ config: DownBlock2DConfig,
+ ) -> Result<Self> {
+ let vs_resnets = vs.pp("resnets");
+ let resnet_cfg = ResnetBlock2DConfig {
+ out_channels: Some(out_channels),
+ eps: config.resnet_eps,
+ output_scale_factor: config.output_scale_factor,
+ temb_channels,
+ ..Default::default()
+ };
+ let resnets = (0..config.num_layers)
+ .map(|i| {
+ let in_channels = if i == 0 { in_channels } else { out_channels };
+ ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
+ })
+ .collect::<Result<Vec<_>>>()?;
+ let downsampler = if config.add_downsample {
+ let downsampler = Downsample2D::new(
+ vs.pp("downsamplers").pp("0"),
+ out_channels,
+ true,
+ out_channels,
+ config.downsample_padding,
+ )?;
+ Some(downsampler)
+ } else {
+ None
+ };
+ let span = tracing::span!(tracing::Level::TRACE, "down2d");
+ Ok(Self {
+ resnets,
+ downsampler,
+ span,
+ config,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<(Tensor, Vec<Tensor>)> {
+ let _enter = self.span.enter();
+ let mut xs = xs.clone();
+ let mut output_states = vec![];
+ for resnet in self.resnets.iter() {
+ xs = resnet.forward(&xs, temb)?;
+ output_states.push(xs.clone());
+ }
+ let xs = match &self.downsampler {
+ Some(downsampler) => {
+ let xs = downsampler.forward(&xs)?;
+ output_states.push(xs.clone());
+ xs
+ }
+ None => xs,
+ };
+ Ok((xs, output_states))
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct CrossAttnDownBlock2DConfig {
+ pub downblock: DownBlock2DConfig,
+ pub attn_num_head_channels: usize,
+ pub cross_attention_dim: usize,
+ // attention_type: "default"
+ pub sliced_attention_size: Option<usize>,
+ pub use_linear_projection: bool,
+ pub transformer_layers_per_block: usize,
+}
+
+impl Default for CrossAttnDownBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ downblock: Default::default(),
+ attn_num_head_channels: 1,
+ cross_attention_dim: 1280,
+ sliced_attention_size: None,
+ use_linear_projection: false,
+ transformer_layers_per_block: 1,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct CrossAttnDownBlock2D {
+ downblock: DownBlock2D,
+ attentions: Vec<SpatialTransformer>,
+ span: tracing::Span,
+ pub config: CrossAttnDownBlock2DConfig,
+}
+
+impl CrossAttnDownBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ temb_channels: Option<usize>,
+ use_flash_attn: bool,
+ config: CrossAttnDownBlock2DConfig,
+ ) -> Result<Self> {
+ let downblock = DownBlock2D::new(
+ vs.clone(),
+ in_channels,
+ out_channels,
+ temb_channels,
+ config.downblock,
+ )?;
+ let n_heads = config.attn_num_head_channels;
+ let cfg = SpatialTransformerConfig {
+ depth: config.transformer_layers_per_block,
+ context_dim: Some(config.cross_attention_dim),
+ num_groups: config.downblock.resnet_groups,
+ sliced_attention_size: config.sliced_attention_size,
+ use_linear_projection: config.use_linear_projection,
+ };
+ let vs_attn = vs.pp("attentions");
+ let attentions = (0..config.downblock.num_layers)
+ .map(|i| {
+ SpatialTransformer::new(
+ vs_attn.pp(&i.to_string()),
+ out_channels,
+ n_heads,
+ out_channels / n_heads,
+ use_flash_attn,
+ cfg,
+ )
+ })
+ .collect::<Result<Vec<_>>>()?;
+ let span = tracing::span!(tracing::Level::TRACE, "xa-down2d");
+ Ok(Self {
+ downblock,
+ attentions,
+ span,
+ config,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ xs: &Tensor,
+ temb: Option<&Tensor>,
+ encoder_hidden_states: Option<&Tensor>,
+ ) -> Result<(Tensor, Vec<Tensor>)> {
+ let _enter = self.span.enter();
+ let mut output_states = vec![];
+ let mut xs = xs.clone();
+ for (resnet, attn) in self.downblock.resnets.iter().zip(self.attentions.iter()) {
+ xs = resnet.forward(&xs, temb)?;
+ xs = attn.forward(&xs, encoder_hidden_states)?;
+ output_states.push(xs.clone());
+ }
+ let xs = match &self.downblock.downsampler {
+ Some(downsampler) => {
+ let xs = downsampler.forward(&xs)?;
+ output_states.push(xs.clone());
+ xs
+ }
+ None => xs,
+ };
+ Ok((xs, output_states))
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct UpBlock2DConfig {
+ pub num_layers: usize,
+ pub resnet_eps: f64,
+ // resnet_time_scale_shift: "default"
+ // resnet_act_fn: "swish"
+ pub resnet_groups: usize,
+ pub output_scale_factor: f64,
+ pub add_upsample: bool,
+}
+
+impl Default for UpBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ num_layers: 1,
+ resnet_eps: 1e-6,
+ resnet_groups: 32,
+ output_scale_factor: 1.,
+ add_upsample: true,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct UpBlock2D {
+ pub resnets: Vec<ResnetBlock2D>,
+ upsampler: Option<Upsample2D>,
+ span: tracing::Span,
+ pub config: UpBlock2DConfig,
+}
+
+impl UpBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ prev_output_channels: usize,
+ out_channels: usize,
+ temb_channels: Option<usize>,
+ config: UpBlock2DConfig,
+ ) -> Result<Self> {
+ let vs_resnets = vs.pp("resnets");
+ let resnet_cfg = ResnetBlock2DConfig {
+ out_channels: Some(out_channels),
+ temb_channels,
+ eps: config.resnet_eps,
+ output_scale_factor: config.output_scale_factor,
+ ..Default::default()
+ };
+ let resnets = (0..config.num_layers)
+ .map(|i| {
+ let res_skip_channels = if i == config.num_layers - 1 {
+ in_channels
+ } else {
+ out_channels
+ };
+ let resnet_in_channels = if i == 0 {
+ prev_output_channels
+ } else {
+ out_channels
+ };
+ let in_channels = resnet_in_channels + res_skip_channels;
+ ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
+ })
+ .collect::<Result<Vec<_>>>()?;
+ let upsampler = if config.add_upsample {
+ let upsampler =
+ Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?;
+ Some(upsampler)
+ } else {
+ None
+ };
+ let span = tracing::span!(tracing::Level::TRACE, "up2d");
+ Ok(Self {
+ resnets,
+ upsampler,
+ span,
+ config,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ xs: &Tensor,
+ res_xs: &[Tensor],
+ temb: Option<&Tensor>,
+ upsample_size: Option<(usize, usize)>,
+ ) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let mut xs = xs.clone();
+ for (index, resnet) in self.resnets.iter().enumerate() {
+ xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
+ xs = xs.contiguous()?;
+ xs = resnet.forward(&xs, temb)?;
+ }
+ match &self.upsampler {
+ Some(upsampler) => upsampler.forward(&xs, upsample_size),
+ None => Ok(xs),
+ }
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct CrossAttnUpBlock2DConfig {
+ pub upblock: UpBlock2DConfig,
+ pub attn_num_head_channels: usize,
+ pub cross_attention_dim: usize,
+ // attention_type: "default"
+ pub sliced_attention_size: Option<usize>,
+ pub use_linear_projection: bool,
+ pub transformer_layers_per_block: usize,
+}
+
+impl Default for CrossAttnUpBlock2DConfig {
+ fn default() -> Self {
+ Self {
+ upblock: Default::default(),
+ attn_num_head_channels: 1,
+ cross_attention_dim: 1280,
+ sliced_attention_size: None,
+ use_linear_projection: false,
+ transformer_layers_per_block: 1,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct CrossAttnUpBlock2D {
+ pub upblock: UpBlock2D,
+ pub attentions: Vec<SpatialTransformer>,
+ span: tracing::Span,
+ pub config: CrossAttnUpBlock2DConfig,
+}
+
+impl CrossAttnUpBlock2D {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ prev_output_channels: usize,
+ out_channels: usize,
+ temb_channels: Option<usize>,
+ use_flash_attn: bool,
+ config: CrossAttnUpBlock2DConfig,
+ ) -> Result<Self> {
+ let upblock = UpBlock2D::new(
+ vs.clone(),
+ in_channels,
+ prev_output_channels,
+ out_channels,
+ temb_channels,
+ config.upblock,
+ )?;
+ let n_heads = config.attn_num_head_channels;
+ let cfg = SpatialTransformerConfig {
+ depth: config.transformer_layers_per_block,
+ context_dim: Some(config.cross_attention_dim),
+ num_groups: config.upblock.resnet_groups,
+ sliced_attention_size: config.sliced_attention_size,
+ use_linear_projection: config.use_linear_projection,
+ };
+ let vs_attn = vs.pp("attentions");
+ let attentions = (0..config.upblock.num_layers)
+ .map(|i| {
+ SpatialTransformer::new(
+ vs_attn.pp(&i.to_string()),
+ out_channels,
+ n_heads,
+ out_channels / n_heads,
+ use_flash_attn,
+ cfg,
+ )
+ })
+ .collect::<Result<Vec<_>>>()?;
+ let span = tracing::span!(tracing::Level::TRACE, "xa-up2d");
+ Ok(Self {
+ upblock,
+ attentions,
+ span,
+ config,
+ })
+ }
+
+ pub fn forward(
+ &self,
+ xs: &Tensor,
+ res_xs: &[Tensor],
+ temb: Option<&Tensor>,
+ upsample_size: Option<(usize, usize)>,
+ encoder_hidden_states: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let mut xs = xs.clone();
+ for (index, resnet) in self.upblock.resnets.iter().enumerate() {
+ xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
+ xs = xs.contiguous()?;
+ xs = resnet.forward(&xs, temb)?;
+ xs = self.attentions[index].forward(&xs, encoder_hidden_states)?;
+ }
+ match &self.upblock.upsampler {
+ Some(upsampler) => upsampler.forward(&xs, upsample_size),
+ None => Ok(xs),
+ }
+ }
+}
diff --git a/candle-transformers/src/models/stable_diffusion/utils.rs b/candle-transformers/src/models/stable_diffusion/utils.rs
new file mode 100644
index 00000000..c62f17af
--- /dev/null
+++ b/candle-transformers/src/models/stable_diffusion/utils.rs
@@ -0,0 +1,39 @@
+use candle::{Device, Result, Tensor};
+use candle_nn::Module;
+
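+/// Evenly spaced values over the inclusive range `[start, stop]`; for example
+/// `linspace(0., 1., 5)` yields `[0.0, 0.25, 0.5, 0.75, 1.0]`.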
+pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
+    if steps <= 1 {
+        candle::bail!("cannot use linspace with steps {steps} <= 1")
+    }
+ let delta = (stop - start) / (steps - 1) as f64;
+ let vs = (0..steps)
+ .map(|step| start + step as f64 * delta)
+ .collect::<Vec<_>>();
+ Tensor::from_vec(vs, steps, &Device::Cpu)
+}
+
+// Wrap the conv2d op to provide some tracing.
+#[derive(Debug)]
+pub struct Conv2d {
+ inner: candle_nn::Conv2d,
+ span: tracing::Span,
+}
+
+impl Conv2d {
+ pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(x)
+ }
+}
+
+pub fn conv2d(
+ in_channels: usize,
+ out_channels: usize,
+ kernel_size: usize,
+ cfg: candle_nn::Conv2dConfig,
+ vs: candle_nn::VarBuilder,
+) -> Result<Conv2d> {
+ let span = tracing::span!(tracing::Level::TRACE, "conv2d");
+ let inner = candle_nn::conv2d(in_channels, out_channels, kernel_size, cfg, vs)?;
+ Ok(Conv2d { inner, span })
+}
diff --git a/candle-transformers/src/models/stable_diffusion/vae.rs b/candle-transformers/src/models/stable_diffusion/vae.rs
new file mode 100644
index 00000000..21709afe
--- /dev/null
+++ b/candle-transformers/src/models/stable_diffusion/vae.rs
@@ -0,0 +1,380 @@
+#![allow(dead_code)]
+//! # Variational Auto-Encoder (VAE) Models.
+//!
+//! Auto-encoder models compress their input to a usually smaller latent space
+//! before expanding it back to its original shape, so the latent values act as a
+//! compressed representation of the original information.
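+//!
+//! A minimal sketch of the round trip through the auto-encoder (the variable builder
+//! `vs` and the image tensor `img` are assumptions made for the example):
+//!
+//! ```ignore
+//! let vae = AutoEncoderKL::new(vs, 3, 3, AutoEncoderKLConfig::default())?;
+//! let dist = vae.encode(&img)?;              // distribution over the latent space
+//! let latents = dist.sample()?;              // reparameterized sample
+//! let reconstruction = vae.decode(&latents)?;
+//! ```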
+use super::unet_2d_blocks::{
+ DownEncoderBlock2D, DownEncoderBlock2DConfig, UNetMidBlock2D, UNetMidBlock2DConfig,
+ UpDecoderBlock2D, UpDecoderBlock2DConfig,
+};
+use candle::{Result, Tensor};
+use candle_nn as nn;
+use candle_nn::Module;
+
+#[derive(Debug, Clone)]
+struct EncoderConfig {
+ // down_block_types: DownEncoderBlock2D
+ block_out_channels: Vec<usize>,
+ layers_per_block: usize,
+ norm_num_groups: usize,
+ double_z: bool,
+}
+
+impl Default for EncoderConfig {
+ fn default() -> Self {
+ Self {
+ block_out_channels: vec![64],
+ layers_per_block: 2,
+ norm_num_groups: 32,
+ double_z: true,
+ }
+ }
+}
+
+#[derive(Debug)]
+struct Encoder {
+ conv_in: nn::Conv2d,
+ down_blocks: Vec<DownEncoderBlock2D>,
+ mid_block: UNetMidBlock2D,
+ conv_norm_out: nn::GroupNorm,
+ conv_out: nn::Conv2d,
+ #[allow(dead_code)]
+ config: EncoderConfig,
+}
+
+impl Encoder {
+ fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ config: EncoderConfig,
+ ) -> Result<Self> {
+ let conv_cfg = nn::Conv2dConfig {
+ padding: 1,
+ ..Default::default()
+ };
+ let conv_in = nn::conv2d(
+ in_channels,
+ config.block_out_channels[0],
+ 3,
+ conv_cfg,
+ vs.pp("conv_in"),
+ )?;
+ let mut down_blocks = vec![];
+ let vs_down_blocks = vs.pp("down_blocks");
+ for index in 0..config.block_out_channels.len() {
+ let out_channels = config.block_out_channels[index];
+ let in_channels = if index > 0 {
+ config.block_out_channels[index - 1]
+ } else {
+ config.block_out_channels[0]
+ };
+ let is_final = index + 1 == config.block_out_channels.len();
+ let cfg = DownEncoderBlock2DConfig {
+ num_layers: config.layers_per_block,
+ resnet_eps: 1e-6,
+ resnet_groups: config.norm_num_groups,
+ add_downsample: !is_final,
+ downsample_padding: 0,
+ ..Default::default()
+ };
+ let down_block = DownEncoderBlock2D::new(
+ vs_down_blocks.pp(&index.to_string()),
+ in_channels,
+ out_channels,
+ cfg,
+ )?;
+ down_blocks.push(down_block)
+ }
+ let last_block_out_channels = *config.block_out_channels.last().unwrap();
+ let mid_cfg = UNetMidBlock2DConfig {
+ resnet_eps: 1e-6,
+ output_scale_factor: 1.,
+ attn_num_head_channels: None,
+ resnet_groups: Some(config.norm_num_groups),
+ ..Default::default()
+ };
+ let mid_block =
+ UNetMidBlock2D::new(vs.pp("mid_block"), last_block_out_channels, None, mid_cfg)?;
+ let conv_norm_out = nn::group_norm(
+ config.norm_num_groups,
+ last_block_out_channels,
+ 1e-6,
+ vs.pp("conv_norm_out"),
+ )?;
+ let conv_out_channels = if config.double_z {
+ 2 * out_channels
+ } else {
+ out_channels
+ };
+ let conv_cfg = nn::Conv2dConfig {
+ padding: 1,
+ ..Default::default()
+ };
+ let conv_out = nn::conv2d(
+ last_block_out_channels,
+ conv_out_channels,
+ 3,
+ conv_cfg,
+ vs.pp("conv_out"),
+ )?;
+ Ok(Self {
+ conv_in,
+ down_blocks,
+ mid_block,
+ conv_norm_out,
+ conv_out,
+ config,
+ })
+ }
+}
+
+impl Encoder {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut xs = xs.apply(&self.conv_in)?;
+ for down_block in self.down_blocks.iter() {
+ xs = xs.apply(down_block)?
+ }
+ let xs = self
+ .mid_block
+ .forward(&xs, None)?
+ .apply(&self.conv_norm_out)?;
+ nn::ops::silu(&xs)?.apply(&self.conv_out)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct DecoderConfig {
+ // up_block_types: UpDecoderBlock2D
+ block_out_channels: Vec<usize>,
+ layers_per_block: usize,
+ norm_num_groups: usize,
+}
+
+impl Default for DecoderConfig {
+ fn default() -> Self {
+ Self {
+ block_out_channels: vec![64],
+ layers_per_block: 2,
+ norm_num_groups: 32,
+ }
+ }
+}
+
+#[derive(Debug)]
+struct Decoder {
+ conv_in: nn::Conv2d,
+ up_blocks: Vec<UpDecoderBlock2D>,
+ mid_block: UNetMidBlock2D,
+ conv_norm_out: nn::GroupNorm,
+ conv_out: nn::Conv2d,
+ #[allow(dead_code)]
+ config: DecoderConfig,
+}
+
+impl Decoder {
+ fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ config: DecoderConfig,
+ ) -> Result<Self> {
+ let n_block_out_channels = config.block_out_channels.len();
+ let last_block_out_channels = *config.block_out_channels.last().unwrap();
+ let conv_cfg = nn::Conv2dConfig {
+ padding: 1,
+ ..Default::default()
+ };
+ let conv_in = nn::conv2d(
+ in_channels,
+ last_block_out_channels,
+ 3,
+ conv_cfg,
+ vs.pp("conv_in"),
+ )?;
+ let mid_cfg = UNetMidBlock2DConfig {
+ resnet_eps: 1e-6,
+ output_scale_factor: 1.,
+ attn_num_head_channels: None,
+ resnet_groups: Some(config.norm_num_groups),
+ ..Default::default()
+ };
+ let mid_block =
+ UNetMidBlock2D::new(vs.pp("mid_block"), last_block_out_channels, None, mid_cfg)?;
+ let mut up_blocks = vec![];
+ let vs_up_blocks = vs.pp("up_blocks");
+ let reversed_block_out_channels: Vec<_> =
+ config.block_out_channels.iter().copied().rev().collect();
+ for index in 0..n_block_out_channels {
+ let out_channels = reversed_block_out_channels[index];
+ let in_channels = if index > 0 {
+ reversed_block_out_channels[index - 1]
+ } else {
+ reversed_block_out_channels[0]
+ };
+ let is_final = index + 1 == n_block_out_channels;
+ let cfg = UpDecoderBlock2DConfig {
+ num_layers: config.layers_per_block + 1,
+ resnet_eps: 1e-6,
+ resnet_groups: config.norm_num_groups,
+ add_upsample: !is_final,
+ ..Default::default()
+ };
+ let up_block = UpDecoderBlock2D::new(
+ vs_up_blocks.pp(&index.to_string()),
+ in_channels,
+ out_channels,
+ cfg,
+ )?;
+ up_blocks.push(up_block)
+ }
+ let conv_norm_out = nn::group_norm(
+ config.norm_num_groups,
+ config.block_out_channels[0],
+ 1e-6,
+ vs.pp("conv_norm_out"),
+ )?;
+ let conv_cfg = nn::Conv2dConfig {
+ padding: 1,
+ ..Default::default()
+ };
+ let conv_out = nn::conv2d(
+ config.block_out_channels[0],
+ out_channels,
+ 3,
+ conv_cfg,
+ vs.pp("conv_out"),
+ )?;
+ Ok(Self {
+ conv_in,
+ up_blocks,
+ mid_block,
+ conv_norm_out,
+ conv_out,
+ config,
+ })
+ }
+}
+
+impl Decoder {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut xs = self.mid_block.forward(&self.conv_in.forward(xs)?, None)?;
+ for up_block in self.up_blocks.iter() {
+ xs = up_block.forward(&xs)?
+ }
+ let xs = self.conv_norm_out.forward(&xs)?;
+ let xs = nn::ops::silu(&xs)?;
+ self.conv_out.forward(&xs)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct AutoEncoderKLConfig {
+ pub block_out_channels: Vec<usize>,
+ pub layers_per_block: usize,
+ pub latent_channels: usize,
+ pub norm_num_groups: usize,
+}
+
+impl Default for AutoEncoderKLConfig {
+ fn default() -> Self {
+ Self {
+ block_out_channels: vec![64],
+ layers_per_block: 1,
+ latent_channels: 4,
+ norm_num_groups: 32,
+ }
+ }
+}
+
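+/// Diagonal Gaussian over the latent space, parameterized by per-element mean and
+/// standard deviation. Sampling uses the reparameterization trick:
+/// `sample = mean + std * eps` with `eps ~ N(0, 1)`.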
+pub struct DiagonalGaussianDistribution {
+ mean: Tensor,
+ std: Tensor,
+}
+
+impl DiagonalGaussianDistribution {
+ pub fn new(parameters: &Tensor) -> Result<Self> {
+ let mut parameters = parameters.chunk(2, 1)?.into_iter();
+ let mean = parameters.next().unwrap();
+ let logvar = parameters.next().unwrap();
+ let std = (logvar * 0.5)?.exp()?;
+ Ok(DiagonalGaussianDistribution { mean, std })
+ }
+
+ pub fn sample(&self) -> Result<Tensor> {
+ let sample = self.mean.randn_like(0., 1.);
+ &self.mean + &self.std * sample
+ }
+}
+
+// https://github.com/huggingface/diffusers/blob/970e30606c2944e3286f56e8eb6d3dc6d1eb85f7/src/diffusers/models/vae.py#L485
+// This implementation is specific to the config used in stable-diffusion-v1-5
+// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
+#[derive(Debug)]
+pub struct AutoEncoderKL {
+ encoder: Encoder,
+ decoder: Decoder,
+ quant_conv: nn::Conv2d,
+ post_quant_conv: nn::Conv2d,
+ pub config: AutoEncoderKLConfig,
+}
+
+impl AutoEncoderKL {
+ pub fn new(
+ vs: nn::VarBuilder,
+ in_channels: usize,
+ out_channels: usize,
+ config: AutoEncoderKLConfig,
+ ) -> Result<Self> {
+ let latent_channels = config.latent_channels;
+ let encoder_cfg = EncoderConfig {
+ block_out_channels: config.block_out_channels.clone(),
+ layers_per_block: config.layers_per_block,
+ norm_num_groups: config.norm_num_groups,
+ double_z: true,
+ };
+ let encoder = Encoder::new(vs.pp("encoder"), in_channels, latent_channels, encoder_cfg)?;
+ let decoder_cfg = DecoderConfig {
+ block_out_channels: config.block_out_channels.clone(),
+ layers_per_block: config.layers_per_block,
+ norm_num_groups: config.norm_num_groups,
+ };
+ let decoder = Decoder::new(vs.pp("decoder"), latent_channels, out_channels, decoder_cfg)?;
+ let conv_cfg = Default::default();
+ let quant_conv = nn::conv2d(
+ 2 * latent_channels,
+ 2 * latent_channels,
+ 1,
+ conv_cfg,
+ vs.pp("quant_conv"),
+ )?;
+ let post_quant_conv = nn::conv2d(
+ latent_channels,
+ latent_channels,
+ 1,
+ conv_cfg,
+ vs.pp("post_quant_conv"),
+ )?;
+ Ok(Self {
+ encoder,
+ decoder,
+ quant_conv,
+ post_quant_conv,
+ config,
+ })
+ }
+
+ /// Returns the distribution in the latent space.
+ pub fn encode(&self, xs: &Tensor) -> Result<DiagonalGaussianDistribution> {
+ let xs = self.encoder.forward(xs)?;
+ let parameters = self.quant_conv.forward(&xs)?;
+ DiagonalGaussianDistribution::new(&parameters)
+ }
+
+    /// Decodes latent values, e.g. samples drawn from the encoded distribution, back to the image space.
+ pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = self.post_quant_conv.forward(xs)?;
+ self.decoder.forward(&xs)
+ }
+}
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
new file mode 100644
index 00000000..539ae89b
--- /dev/null
+++ b/candle-transformers/src/models/t5.rs
@@ -0,0 +1,841 @@
+// T5 Text Model
+// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
+
+use candle::{DType, Device, Module, Result, Tensor, D};
+use candle_nn::{Activation, VarBuilder};
+use serde::Deserialize;
+use std::sync::Arc;
+
+#[derive(Debug)]
+struct Embedding {
+ inner: candle_nn::Embedding,
+ span: tracing::Span,
+}
+
+impl Embedding {
+ fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
+ let inner = candle_nn::embedding(d1, d2, vb)?;
+ let span = tracing::span!(tracing::Level::TRACE, "embedding");
+ Ok(Self { inner, span })
+ }
+
+ fn embeddings(&self) -> &Tensor {
+ self.inner.embeddings()
+ }
+}
+
+impl Module for Embedding {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(xs)
+ }
+}
+
+#[derive(Debug)]
+struct Linear {
+ inner: candle_nn::Linear,
+ span: tracing::Span,
+}
+
+impl Linear {
+ fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
+ let inner = candle_nn::linear_no_bias(d1, d2, vb)?;
+ let span = tracing::span!(tracing::Level::TRACE, "linear");
+ Ok(Self { inner, span })
+ }
+}
+
+impl Module for Linear {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(xs)
+ }
+}
+
+fn default_relative_attention_max_distance() -> usize {
+ 128
+}
+
+fn default_is_decoder() -> bool {
+ false
+}
+
+fn default_use_cache() -> bool {
+ true
+}
+
+fn default_tie_word_embeddings() -> bool {
+ true
+}
+
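+/// Upper-triangular causal mask: entry `(i, j)` is 1 when `j > i`, i.e. the future
+/// positions that get masked out. For `size = 3` this is
+/// `[[0, 1, 1], [0, 0, 1], [0, 0, 0]]`.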
+fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
+ let mask: Vec<_> = (0..size)
+ .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
+ .collect();
+ Tensor::from_slice(&mask, (size, size), device)
+}
+
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+ let shape = mask.shape();
+ let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
+ let m = mask.where_cond(&on_true, on_false)?;
+ Ok(m)
+}
+
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+pub struct Config {
+ vocab_size: usize,
+ d_model: usize,
+ d_kv: usize,
+ d_ff: usize,
+ num_layers: usize,
+ num_decoder_layers: Option<usize>,
+ num_heads: usize,
+ relative_attention_num_buckets: usize,
+ #[serde(default = "default_relative_attention_max_distance")]
+ relative_attention_max_distance: usize,
+ dropout_rate: f64,
+ layer_norm_epsilon: f64,
+ initializer_factor: f64,
+ #[serde(default)]
+ feed_forward_proj: Activation,
+ #[serde(default = "default_tie_word_embeddings")]
+ tie_word_embeddings: bool,
+ #[serde(default = "default_is_decoder")]
+ is_decoder: bool,
+ is_encoder_decoder: bool,
+ #[serde(default = "default_use_cache")]
+ pub use_cache: bool,
+ pub pad_token_id: usize,
+ pub eos_token_id: usize,
+}
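+
+// Since `Config` derives `serde::Deserialize`, it can be loaded directly from a
+// Hugging Face `config.json`. A hedged sketch (assumes the caller depends on
+// `serde_json`):
+//
+//     let config: Config = serde_json::from_str(&std::fs::read_to_string("config.json")?)?;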
+
+impl Default for Config {
+ fn default() -> Self {
+ Self {
+ vocab_size: 32128,
+ d_model: 512,
+ d_kv: 64,
+ d_ff: 2048,
+ num_layers: 6,
+ num_decoder_layers: None,
+ num_heads: 8,
+ relative_attention_num_buckets: 32,
+ relative_attention_max_distance: 128,
+ dropout_rate: 0.1,
+ layer_norm_epsilon: 1e-6,
+ initializer_factor: 1.0,
+ feed_forward_proj: Activation::Relu,
+ tie_word_embeddings: true,
+ is_decoder: false,
+ is_encoder_decoder: true,
+ use_cache: true,
+ pad_token_id: 0,
+ eos_token_id: 1,
+ }
+ }
+}
+
+impl Config {
+ // https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L184
+ pub fn musicgen_small() -> Self {
+ Self {
+ d_ff: 3072,
+ d_kv: 64,
+ d_model: 768,
+ dropout_rate: 0.1,
+ eos_token_id: 1,
+ feed_forward_proj: Activation::Relu,
+ tie_word_embeddings: true,
+ initializer_factor: 1.0,
+ is_decoder: false,
+ is_encoder_decoder: true,
+ layer_norm_epsilon: 1e-6,
+ num_decoder_layers: Some(12),
+ num_heads: 12,
+ num_layers: 12,
+ pad_token_id: 0,
+ relative_attention_max_distance: 128,
+ relative_attention_num_buckets: 32,
+ use_cache: true,
+ vocab_size: 32128,
+ }
+ }
+}
+
+#[derive(Debug)]
+struct T5LayerNorm {
+ weight: Tensor,
+ variance_epsilon: f64,
+ span: tracing::Span,
+}
+
+impl T5LayerNorm {
+ fn load(h: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+ let weight = vb.get(h, "weight")?;
+ Ok(Self {
+ weight,
+ variance_epsilon: eps,
+ span: tracing::span!(tracing::Level::TRACE, "layer-norm"),
+ })
+ }
+}
+
+impl Module for T5LayerNorm {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let dtype = xs.dtype();
+ let xs_f32 = xs.to_dtype(DType::F32)?;
+ // variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+ let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
+ let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
+ let xs = xs.to_dtype(dtype)?;
+ let xs = xs.broadcast_mul(&self.weight)?;
+ Ok(xs)
+ }
+}
+
+#[derive(Debug)]
+struct T5DenseActDense {
+ wi: Linear,
+ wo: Linear,
+ act: Activation,
+ span: tracing::Span,
+}
+
+impl T5DenseActDense {
+ fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let wi = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
+ let wo = Linear::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
+ Ok(Self {
+ wi,
+ wo,
+ act: Activation::Relu,
+ span: tracing::span!(tracing::Level::TRACE, "dense-act-dense"),
+ })
+ }
+}
+
+impl Module for T5DenseActDense {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let xs = self.wi.forward(xs)?;
+ let xs = self.act.forward(&xs)?;
+ let xs = self.wo.forward(&xs)?;
+ Ok(xs)
+ }
+}
+
+#[derive(Debug)]
+struct T5DenseGatedActDense {
+ wi_0: Linear,
+ wi_1: Linear,
+ wo: Linear,
+ act: Activation,
+ span: tracing::Span,
+}
+
+impl T5DenseGatedActDense {
+ fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let wi_0 = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?;
+ let wi_1 = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?;
+ let wo = Linear::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
+ Ok(Self {
+ wi_0,
+ wi_1,
+ wo,
+ act: Activation::NewGelu,
+ span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"),
+ })
+ }
+}
+
+impl Module for T5DenseGatedActDense {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let hidden_gelu = self.act.forward(&self.wi_0.forward(xs)?)?;
+ let hidden_linear = self.wi_1.forward(xs)?;
+ let xs = hidden_gelu.broadcast_mul(&hidden_linear)?;
+ let xs = self.wo.forward(&xs)?;
+ Ok(xs)
+ }
+}
+
+#[derive(Debug)]
+struct T5LayerFF {
+ dense_act: Option<T5DenseActDense>,
+ gated_dense_act: Option<T5DenseGatedActDense>,
+ layer_norm: T5LayerNorm,
+ span: tracing::Span,
+}
+
+impl T5LayerFF {
+ fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let layer_norm =
+ T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
+ let (dense_act, gated_dense_act) = if cfg.feed_forward_proj == Activation::NewGelu {
+ (
+ None,
+ Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),
+ )
+ } else {
+ (
+ Some(T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?),
+ None,
+ )
+ };
+ Ok(Self {
+ dense_act,
+ gated_dense_act,
+ layer_norm,
+ span: tracing::span!(tracing::Level::TRACE, "layer-ff"),
+ })
+ }
+}
+
+impl Module for T5LayerFF {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let ys = self.layer_norm.forward(xs)?;
+ let ys = match &self.dense_act {
+ Some(dense_act) => dense_act.forward(&ys)?,
+ None => self.gated_dense_act.as_ref().unwrap().forward(&ys)?,
+ };
+ let xs = (xs + ys)?;
+ Ok(xs)
+ }
+}
+
+#[derive(Debug)]
+struct T5Attention {
+ q: Linear,
+ k: Linear,
+ v: Linear,
+ o: Linear,
+ n_heads: usize,
+ d_kv: usize,
+ relative_attention_bias: Option<Embedding>,
+ relative_attention_num_buckets: usize,
+ relative_attention_max_distance: usize,
+ inner_dim: usize,
+ use_cache: bool,
+ kv_cache: Option<(Tensor, Tensor)>,
+ span: tracing::Span,
+ span_cache: tracing::Span,
+ span_mm: tracing::Span,
+ span_sm: tracing::Span,
+}
+
+impl T5Attention {
+ fn load(
+ has_relative_attention_bias: bool,
+ decoder: bool,
+ vb: VarBuilder,
+ cfg: &Config,
+ ) -> Result<Self> {
+ let inner_dim = cfg.num_heads * cfg.d_kv;
+ let q = Linear::new(cfg.d_model, inner_dim, vb.pp("q"))?;
+ let k = Linear::new(cfg.d_model, inner_dim, vb.pp("k"))?;
+ let v = Linear::new(cfg.d_model, inner_dim, vb.pp("v"))?;
+ let o = Linear::new(inner_dim, cfg.d_model, vb.pp("o"))?;
+ let relative_attention_bias = if has_relative_attention_bias {
+ let emb = Embedding::new(
+ cfg.relative_attention_num_buckets,
+ cfg.num_heads,
+ vb.pp("relative_attention_bias"),
+ )?;
+ Some(emb)
+ } else {
+ None
+ };
+ Ok(Self {
+ q,
+ k,
+ v,
+ o,
+ n_heads: cfg.num_heads,
+ d_kv: cfg.d_kv,
+ relative_attention_bias,
+ relative_attention_num_buckets: cfg.relative_attention_num_buckets,
+ relative_attention_max_distance: cfg.relative_attention_max_distance,
+ inner_dim,
+ use_cache: cfg.use_cache && decoder,
+ kv_cache: None,
+ span: tracing::span!(tracing::Level::TRACE, "attention"),
+ span_cache: tracing::span!(tracing::Level::TRACE, "attention-cache"),
+ span_mm: tracing::span!(tracing::Level::TRACE, "attention-mm"),
+ span_sm: tracing::span!(tracing::Level::TRACE, "attention-sm"),
+ })
+ }
+
+ fn forward(
+ &mut self,
+ xs: &Tensor,
+ position_bias: Option<&Tensor>,
+ key_value_states: Option<&Tensor>,
+ mask: Option<&Tensor>,
+ ) -> Result<(Tensor, Option<Tensor>)> {
+ // Performs self-attention (when key_value_states is None) or cross-attention
+ // over the encoder output provided via key_value_states.
+ let _enter = self.span.enter();
+ let kv_input = match key_value_states {
+ None => xs,
+ Some(key_value_states) => key_value_states,
+ };
+ let (b_sz, q_len) = (xs.dim(0)?, xs.dim(1)?);
+ let kv_len = kv_input.dim(1)?;
+ let q = self.q.forward(xs)?;
+ let k = self.k.forward(kv_input)?;
+ let v = self.v.forward(kv_input)?;
+ let q = q
+ .reshape((b_sz, q_len, self.n_heads, self.d_kv))?
+ .transpose(1, 2)?
+ .contiguous()?;
+ let mut k = k
+ .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
+ .transpose(1, 2)?
+ .contiguous()?;
+ let mut v = v
+ .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
+ .transpose(1, 2)?
+ .contiguous()?;
+
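+ // When the KV cache is enabled (decoder only), previously computed keys and
+ // values are concatenated along the time dimension so that subsequent calls
+ // only need to provide the newly generated positions.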
+ if self.use_cache {
+ let _enter = self.span_cache.enter();
+ if let Some((kv_cache_k, kv_cache_v)) = &self.kv_cache {
+ k = Tensor::cat(&[kv_cache_k, &k], 2)?.contiguous()?;
+ v = Tensor::cat(&[kv_cache_v, &v], 2)?.contiguous()?;
+ };
+ self.kv_cache = Some((k.clone(), v.clone()));
+ };
+ // TODO: Use flash_attn.
+ let scores = {
+ let _enter = self.span_mm.enter();
+ q.matmul(&k.t()?)?
+ };
+ let scores = match mask {
+ None => scores,
+ Some(mask) => masked_fill(
+ &scores,
+ &mask
+ .unsqueeze(0)?
+ .unsqueeze(0)?
+ .repeat((b_sz, self.n_heads))?,
+ f32::NEG_INFINITY,
+ )?,
+ };
+
+ let (scores, position_bias) = match position_bias {
+ Some(position_bias) => (
+ scores.broadcast_add(position_bias)?,
+ Some(position_bias.clone()),
+ ),
+ None => match &self.relative_attention_bias {
+ None => (scores, None),
+ Some(relative_attention_bias) => {
+ // This only handles the bidirectional case.
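+ // Relative positions are mapped to buckets as in the reference
+ // implementation: half of the buckets cover each direction, the first
+ // max_exact of them are exact offsets, and the remaining ones grow
+ // logarithmically up to relative_attention_max_distance.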
+ let kv_len = k.dim(2)?;
+ let (q_start, q_end) = match self.use_cache {
+ true => ((kv_len - q_len) as u32, kv_len as u32),
+ false => (0_u32, kv_len as u32),
+ };
+ let num_buckets = self.relative_attention_num_buckets as u32 / 2;
+ let max_exact = num_buckets / 2;
+ let relative_position = (q_start..q_end)
+ .map(|i| {
+ (0..kv_len as u32)
+ .map(|j| {
+ if i < j {
+ if j - i < max_exact {
+ j - i + num_buckets
+ } else {
+ let b = f32::log(
+ (j - i) as f32 / max_exact as f32,
+ self.relative_attention_max_distance as f32
+ / max_exact as f32,
+ ) * (num_buckets - max_exact) as f32;
+ u32::min(
+ max_exact + num_buckets + b as u32,
+ self.relative_attention_num_buckets as u32 - 1,
+ )
+ }
+ } else if i - j < max_exact {
+ i - j
+ } else {
+ let b = f32::log(
+ (i - j) as f32 / max_exact as f32,
+ self.relative_attention_max_distance as f32
+ / max_exact as f32,
+ ) * (num_buckets - max_exact) as f32;
+ max_exact + b as u32
+ }
+ })
+ .collect::<Vec<u32>>()
+ })
+ .collect::<Vec<Vec<_>>>();
+ let relative_buckets = Tensor::new(relative_position, q.device())?;
+ let position_bias = relative_attention_bias
+ .forward(&relative_buckets)?
+ .permute((2, 0, 1))?
+ .unsqueeze(0)?;
+ (scores.broadcast_add(&position_bias)?, Some(position_bias))
+ // TODO: position_bias_masked?
+ }
+ },
+ };
+
+ let attn_weights = {
+ let _enter = self.span_sm.enter();
+ candle_nn::ops::softmax(&scores, D::Minus1)?
+ };
+ let attn_output = attn_weights.matmul(&v)?;
+ let attn_output = attn_output
+ .transpose(1, 2)?
+ .reshape((b_sz, q_len, self.inner_dim))?;
+ let attn_output = self.o.forward(&attn_output)?;
+ Ok((attn_output, position_bias))
+ }
+
+ fn clear_kv_cache(&mut self) {
+ self.kv_cache = None
+ }
+}
+
+#[derive(Debug)]
+struct T5LayerSelfAttention {
+ self_attention: T5Attention,
+ layer_norm: T5LayerNorm,
+ span: tracing::Span,
+}
+
+impl T5LayerSelfAttention {
+ fn load(h: bool, d: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let self_attention = T5Attention::load(h, d, vb.pp("SelfAttention"), cfg)?;
+ let layer_norm =
+ T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
+ Ok(Self {
+ self_attention,
+ layer_norm,
+ span: tracing::span!(tracing::Level::TRACE, "self-attn"),
+ })
+ }
+
+ fn forward(
+ &mut self,
+ xs: &Tensor,
+ position_bias: Option<&Tensor>,
+ mask: Option<&Tensor>,
+ ) -> Result<(Tensor, Option<Tensor>)> {
+ let _enter = self.span.enter();
+ let normed_xs = self.layer_norm.forward(xs)?;
+ let (ys, position_bias) =
+ self.self_attention
+ .forward(&normed_xs, position_bias, None, mask)?;
+ let ys = (xs + ys)?;
+ Ok((ys, position_bias))
+ }
+
+ fn clear_kv_cache(&mut self) {
+ self.self_attention.clear_kv_cache()
+ }
+}
+
+#[derive(Debug)]
+struct T5LayerCrossAttention {
+ cross_attention: T5Attention,
+ layer_norm: T5LayerNorm,
+ span: tracing::Span,
+}
+
+impl T5LayerCrossAttention {
+ fn load(decoder: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let cross_attention = T5Attention::load(false, decoder, vb.pp("EncDecAttention"), cfg)?;
+ let layer_norm =
+ T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
+ Ok(Self {
+ cross_attention,
+ layer_norm,
+ span: tracing::span!(tracing::Level::TRACE, "cross-attn"),
+ })
+ }
+
+ fn forward(
+ &mut self,
+ hidden_states: &Tensor,
+ position_bias: Option<&Tensor>,
+ key_value_states: &Tensor,
+ ) -> Result<(Tensor, Option<Tensor>)> {
+ let _enter = self.span.enter();
+ let normed_hidden_states = self.layer_norm.forward(hidden_states)?;
+ let (ys, position_bias) = self.cross_attention.forward(
+ &normed_hidden_states,
+ position_bias,
+ Some(key_value_states),
+ None,
+ )?;
+ let ys = (hidden_states + ys)?;
+ Ok((ys, position_bias))
+ }
+
+ fn clear_kv_cache(&mut self) {
+ self.cross_attention.clear_kv_cache()
+ }
+}
+
+#[derive(Debug)]
+struct T5Block {
+ self_attn: T5LayerSelfAttention,
+ cross_attn: Option<T5LayerCrossAttention>,
+ ff: T5LayerFF,
+ span: tracing::Span,
+}
+
+impl T5Block {
+ fn load(
+ has_relative_attention_bias: bool,
+ decoder: bool,
+ vb: VarBuilder,
+ cfg: &Config,
+ ) -> Result<Self> {
+ let vb = vb.pp("layer");
+ let self_attn =
+ T5LayerSelfAttention::load(has_relative_attention_bias, decoder, vb.pp("0"), cfg)?;
+ let cross_attn = if cfg.is_decoder {
+ Some(T5LayerCrossAttention::load(decoder, vb.pp("1"), cfg)?)
+ } else {
+ None
+ };
+ let ff_i = if cross_attn.is_some() { 2 } else { 1 };
+ let ff = T5LayerFF::load(vb.pp(&ff_i.to_string()), cfg)?;
+ Ok(Self {
+ self_attn,
+ cross_attn,
+ ff,
+ span: tracing::span!(tracing::Level::TRACE, "block"),
+ })
+ }
+
+ fn forward(
+ &mut self,
+ xs: &Tensor,
+ position_bias: Option<&Tensor>,
+ encoder_hidden_states: Option<&Tensor>,
+ ) -> Result<(Tensor, Option<Tensor>)> {
+ let _enter = self.span.enter();
+ // TODO: Cache masks
+ let mask = match self.cross_attn.is_some() {
+ true => {
+ let mask_len = xs.dim(1)?;
+ // If the input sequence length is 1 there is no need for a mask; this also
+ // helps avoid shape issues when using the KV cache in the decoder.
+ if mask_len <= 1 {
+ None
+ } else {
+ Some(get_mask(mask_len, xs.device())?)
+ }
+ }
+ false => None,
+ };
+ let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias, mask.as_ref())?;
+ // TODO: clamp for f16?
+ if let Some(cross_attn) = &mut self.cross_attn {
+ (xs, _) = cross_attn.forward(&xs, None, encoder_hidden_states.unwrap())?;
+ // TODO: clamp for f16?
+ }
+ let xs = self.ff.forward(&xs)?;
+ // TODO: clamp for f16?
+ Ok((xs, position_bias))
+ }
+
+ fn clear_kv_cache(&mut self) {
+ self.self_attn.clear_kv_cache();
+ self.cross_attn.iter_mut().for_each(|c| c.clear_kv_cache());
+ }
+}
+
+#[derive(Debug)]
+struct T5Stack {
+ block: Vec<T5Block>,
+ shared: Arc<Embedding>,
+ final_layer_norm: T5LayerNorm,
+ span: tracing::Span,
+}
+
+impl T5Stack {
+ fn load(decoder: bool, vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> {
+ let block = (0..cfg.num_layers)
+ .map(|i| T5Block::load(i == 0, decoder, vb.pp(&format!("block.{i}")), cfg))
+ .collect::<Result<Vec<_>>>()?;
+ let final_layer_norm = T5LayerNorm::load(
+ cfg.d_model,
+ cfg.layer_norm_epsilon,
+ vb.pp("final_layer_norm"),
+ )?;
+ Ok(Self {
+ block,
+ shared: shared.clone(),
+ final_layer_norm,
+ span: tracing::span!(tracing::Level::TRACE, "stack"),
+ })
+ }
+
+ fn forward(
+ &mut self,
+ input_ids: &Tensor,
+ encoder_hidden_states: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let input_embeds = self.shared.as_ref().forward(input_ids)?;
+ let mut hidden_states = input_embeds;
+ let mut position_bias = None;
+ for block in self.block.iter_mut() {
+ (hidden_states, position_bias) = block.forward(
+ &hidden_states,
+ position_bias.as_ref(),
+ encoder_hidden_states,
+ )?
+ }
+ self.final_layer_norm.forward(&hidden_states)
+ }
+
+ fn clear_kv_cache(&mut self) {
+ self.block.iter_mut().for_each(|b| b.clear_kv_cache())
+ }
+}
+
+#[derive(Debug)]
+pub struct T5EncoderModel {
+ encoder: T5Stack,
+ device: Device,
+ span: tracing::Span,
+}
+
+impl T5EncoderModel {
+ pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
+ let shared = Arc::new(shared);
+ let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
+ Ok(Self {
+ encoder,
+ device: vb.device().clone(),
+ span: tracing::span!(tracing::Level::TRACE, "encoder"),
+ })
+ }
+
+ pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.encoder.forward(input_ids, None)
+ }
+
+ pub fn device(&self) -> &Device {
+ &self.device
+ }
+
+ pub fn clear_kv_cache(&mut self) {
+ self.encoder.clear_kv_cache()
+ }
+}
+
+#[derive(Debug)]
+pub struct T5ForConditionalGeneration {
+ encoder: T5Stack,
+ decoder: T5Stack,
+ d_model: usize,
+ tie_word_embeddings: bool,
+ lm_head: Option<Linear>,
+ shared: Arc<Embedding>,
+ device: Device,
+ span_decode: tracing::Span,
+ span_decode_head: tracing::Span,
+}
+
+impl T5ForConditionalGeneration {
+ pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ assert!(cfg.is_encoder_decoder);
+ let d_model = cfg.d_model;
+ let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
+ let shared = Arc::new(shared);
+
+ let mut encoder_cfg = cfg.clone();
+ encoder_cfg.is_decoder = false;
+ encoder_cfg.use_cache = false;
+ encoder_cfg.is_encoder_decoder = false;
+ let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, &encoder_cfg)?;
+
+ let mut decoder_cfg = cfg.clone();
+ decoder_cfg.is_decoder = true;
+ decoder_cfg.is_encoder_decoder = false;
+ decoder_cfg.num_layers = cfg.num_decoder_layers.unwrap_or(cfg.num_layers);
+ let decoder = T5Stack::load(true, vb.pp("decoder"), &shared, &decoder_cfg)?;
+
+ let tie_word_embeddings = cfg.tie_word_embeddings;
+ let lm_head = if tie_word_embeddings {
+ None
+ } else {
+ Some(Linear::new(cfg.d_model, cfg.vocab_size, vb.pp("lm_head"))?)
+ };
+
+ Ok(Self {
+ encoder,
+ decoder,
+ d_model,
+ tie_word_embeddings,
+ lm_head,
+ shared,
+ device: vb.device().clone(),
+ span_decode: tracing::span!(tracing::Level::TRACE, "decode"),
+ span_decode_head: tracing::span!(tracing::Level::TRACE, "decode-head"),
+ })
+ }
+
+ pub fn encode(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+ self.encoder.forward(input_ids, None)
+ }
+
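+ // Runs one decoding step and returns the logits for the last position. With
+ // the decoder KV cache enabled, callers typically pass only the newly
+ // generated token(s) after the first call.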
+ pub fn decode(
+ &mut self,
+ decoder_input_ids: &Tensor,
+ encoder_output: &Tensor,
+ ) -> Result<Tensor> {
+ let _enter = self.span_decode.enter();
+ let decoder_output = self
+ .decoder
+ .forward(decoder_input_ids, Some(encoder_output))?;
+
+ let scaling_factor = if self.tie_word_embeddings {
+ // Rescale output before projecting on vocab
+ // See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+ (self.d_model as f64).sqrt()
+ } else {
+ 1.0
+ };
+ let sequence_output = ((decoder_output
+ .narrow(1, decoder_output.dim(1)? - 1, 1)?
+ .squeeze(1)?)
+ * scaling_factor)?;
+ let output = {
+ let _enter = self.span_decode_head.enter();
+ match self.lm_head {
+ None => sequence_output.matmul(&self.shared.embeddings().t()?)?,
+ Some(ref lm_head) => lm_head.forward(&sequence_output)?,
+ }
+ };
+
+ // TODO: Rescale output before projecting on vocab? * (self.model_dim**-0.5)
+ Ok(output)
+ }
+
+ pub fn forward(&mut self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> {
+ let encoder_output = self.encode(input_ids)?;
+ self.decode(decoder_input_ids, &encoder_output)
+ }
+
+ pub fn device(&self) -> &Device {
+ &self.device
+ }
+
+ pub fn clear_kv_cache(&mut self) {
+ self.encoder.clear_kv_cache();
+ self.decoder.clear_kv_cache();
+ }
+}
diff --git a/candle-transformers/src/models/whisper/audio.rs b/candle-transformers/src/models/whisper/audio.rs
new file mode 100644
index 00000000..4e01de32
--- /dev/null
+++ b/candle-transformers/src/models/whisper/audio.rs
@@ -0,0 +1,210 @@
+// Audio processing code, adapted from whisper.cpp
+// https://github.com/ggerganov/whisper.cpp
+
+pub trait Float: num_traits::Float + num_traits::FloatConst + num_traits::NumAssign {}
+
+impl Float for f32 {}
+impl Float for f64 {}
+
+// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2357
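+// Recursive radix-2 Cooley-Tukey FFT over real-valued input. The result is
+// returned as interleaved (re, im) pairs, i.e. a vector of length 2 * n.
+// Inputs whose length is odd fall back to the naive DFT below.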
+fn fft<T: Float>(inp: &[T]) -> Vec<T> {
+ let n = inp.len();
+ let zero = T::zero();
+ if n == 1 {
+ return vec![inp[0], zero];
+ }
+ if n % 2 == 1 {
+ return dft(inp);
+ }
+ let mut out = vec![zero; n * 2];
+
+ let mut even = Vec::with_capacity(n / 2);
+ let mut odd = Vec::with_capacity(n / 2);
+
+ for (i, &inp) in inp.iter().enumerate() {
+ if i % 2 == 0 {
+ even.push(inp)
+ } else {
+ odd.push(inp);
+ }
+ }
+
+ let even_fft = fft(&even);
+ let odd_fft = fft(&odd);
+
+ let two_pi = T::PI() + T::PI();
+ let n_t = T::from(n).unwrap();
+ for k in 0..n / 2 {
+ let k_t = T::from(k).unwrap();
+ let theta = two_pi * k_t / n_t;
+ let re = theta.cos();
+ let im = -theta.sin();
+
+ let re_odd = odd_fft[2 * k];
+ let im_odd = odd_fft[2 * k + 1];
+
+ out[2 * k] = even_fft[2 * k] + re * re_odd - im * im_odd;
+ out[2 * k + 1] = even_fft[2 * k + 1] + re * im_odd + im * re_odd;
+
+ out[2 * (k + n / 2)] = even_fft[2 * k] - re * re_odd + im * im_odd;
+ out[2 * (k + n / 2) + 1] = even_fft[2 * k + 1] - re * im_odd - im * re_odd;
+ }
+ out
+}
+
+// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2337
+fn dft<T: Float>(inp: &[T]) -> Vec<T> {
+ let zero = T::zero();
+ let n = inp.len();
+ let two_pi = T::PI() + T::PI();
+
+ let mut out = Vec::with_capacity(2 * n);
+ let n_t = T::from(n).unwrap();
+ for k in 0..n {
+ let k_t = T::from(k).unwrap();
+ let mut re = zero;
+ let mut im = zero;
+
+ for (j, &inp) in inp.iter().enumerate() {
+ let j_t = T::from(j).unwrap();
+ let angle = two_pi * k_t * j_t / n_t;
+ re += inp * angle.cos();
+ im -= inp * angle.sin();
+ }
+
+ out.push(re);
+ out.push(im);
+ }
+ out
+}
+
+#[allow(clippy::too_many_arguments)]
+// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2414
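+// Computes the log-mel spectrogram for the frames assigned to thread `ith`
+// (frames i with i % n_threads == ith): Hann window, FFT, power spectrum,
+// mel filterbank projection, then log10 with a 1e-10 floor.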
+fn log_mel_spectrogram_w<T: Float>(
+ ith: usize,
+ hann: &[T],
+ samples: &[T],
+ filters: &[T],
+ fft_size: usize,
+ fft_step: usize,
+ speed_up: bool,
+ n_len: usize,
+ n_mel: usize,
+ n_threads: usize,
+) -> Vec<T> {
+ let n_fft = if speed_up {
+ 1 + fft_size / 4
+ } else {
+ 1 + fft_size / 2
+ };
+
+ let zero = T::zero();
+ let half = T::from(0.5).unwrap();
+ let mut fft_in = vec![zero; fft_size];
+ let mut mel = vec![zero; n_len * n_mel];
+
+ for i in (ith..n_len).step_by(n_threads) {
+ let offset = i * fft_step;
+
+ // apply Hanning window
+ for j in 0..fft_size {
+ fft_in[j] = if offset + j < samples.len() {
+ hann[j] * samples[offset + j]
+ } else {
+ zero
+ }
+ }
+
+ // FFT -> mag^2
+ let mut fft_out: Vec<T> = fft(&fft_in);
+
+ for j in 0..fft_size {
+ fft_out[j] = fft_out[2 * j] * fft_out[2 * j] + fft_out[2 * j + 1] * fft_out[2 * j + 1];
+ }
+ for j in 1..fft_size / 2 {
+ let v = fft_out[fft_size - j];
+ fft_out[j] += v;
+ }
+
+ if speed_up {
+ // scaling down in the frequency domain results in a speed-up in the time domain
+ for j in 0..n_fft {
+ fft_out[j] = half * (fft_out[2 * j] + fft_out[2 * j + 1]);
+ }
+ }
+
+ // mel spectrogram
+ for j in 0..n_mel {
+ let mut sum = zero;
+ for k in 0..n_fft {
+ sum += fft_out[k] * filters[j * n_fft + k];
+ }
+ mel[j * n_len + i] = T::max(sum, T::from(1e-10).unwrap()).log10();
+ }
+ }
+ mel
+}
+
+fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
+ samples: &[T],
+ filters: &[T],
+ fft_size: usize,
+ fft_step: usize,
+ n_mel: usize,
+ speed_up: bool,
+) -> Vec<T> {
+ let zero = T::zero();
+ let two_pi = T::PI() + T::PI();
+ let half = T::from(0.5).unwrap();
+ let one = T::from(1.0).unwrap();
+ let four = T::from(4.0).unwrap();
+ let fft_size_t = T::from(fft_size).unwrap();
+
+ let hann: Vec<T> = (0..fft_size)
+ .map(|i| half * (one - ((two_pi * T::from(i).unwrap()) / fft_size_t).cos()))
+ .collect();
+ let n_len = samples.len() / fft_step;
+
+ // pad audio with at least one extra chunk of zeros
+ let pad = 100 * super::CHUNK_LENGTH / 2;
+ let n_len = if n_len % pad != 0 {
+ (n_len / pad + 1) * pad
+ } else {
+ n_len
+ };
+ let n_len = n_len + pad;
+ let samples = {
+ let mut samples_padded = samples.to_vec();
+ let to_add = n_len * fft_step - samples.len();
+ samples_padded.extend(std::iter::repeat(zero).take(to_add));
+ samples_padded
+ };
+
+ // Use a single thread for now.
+ let mut mel = log_mel_spectrogram_w(
+ 0, &hann, &samples, filters, fft_size, fft_step, speed_up, n_len, n_mel, 1,
+ );
+ let mmax = mel
+ .iter()
+ .max_by(|&u, &v| u.partial_cmp(v).unwrap_or(std::cmp::Ordering::Greater))
+ .copied()
+ .unwrap_or(zero)
+ - T::from(8).unwrap();
+ for m in mel.iter_mut() {
+ let v = T::max(*m, mmax);
+ *m = v / four + one
+ }
+ mel
+}
+
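+// Converts raw PCM samples into the log-mel spectrogram expected by the
+// Whisper encoder, using the N_FFT / HOP_LENGTH / N_MELS constants from the
+// parent module.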
+pub fn pcm_to_mel<T: Float + std::fmt::Display>(samples: &[T], filters: &[T]) -> Vec<T> {
+ log_mel_spectrogram_(
+ samples,
+ filters,
+ super::N_FFT,
+ super::HOP_LENGTH,
+ super::N_MELS,
+ false,
+ )
+}
diff --git a/candle-transformers/src/models/whisper/mod.rs b/candle-transformers/src/models/whisper/mod.rs
new file mode 100644
index 00000000..7dc8107b
--- /dev/null
+++ b/candle-transformers/src/models/whisper/mod.rs
@@ -0,0 +1,26 @@
+pub mod audio;
+pub mod model;
+
+pub const DTYPE: candle::DType = candle::DType::F32;
+
+// Audio parameters.
+pub const SAMPLE_RATE: usize = 16000;
+pub const N_FFT: usize = 400;
+pub const N_MELS: usize = 80;
+pub const HOP_LENGTH: usize = 160;
+pub const CHUNK_LENGTH: usize = 30;
+pub const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
+pub const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input
+
+pub const NO_SPEECH_THRESHOLD: f64 = 0.6;
+pub const LOGPROB_THRESHOLD: f64 = -1.0;
+pub const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
+pub const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
+
+// Tokenizer dependent bits.
+pub const SOT_TOKEN: &str = "<|startoftranscript|>";
+pub const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
+pub const TRANSLATE_TOKEN: &str = "<|translate|>";
+pub const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
+pub const EOT_TOKEN: &str = "<|endoftext|>";
+pub const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
diff --git a/candle-transformers/src/models/whisper/model.rs b/candle-transformers/src/models/whisper/model.rs
new file mode 100644
index 00000000..d2eda796
--- /dev/null
+++ b/candle-transformers/src/models/whisper/model.rs
@@ -0,0 +1,416 @@
+use candle::{Device, IndexOp, Result, Tensor, D};
+use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
+use serde::Deserialize;
+
+// The names in comments correspond to the original implementation:
+// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+pub struct Config {
+ pub num_mel_bins: usize, // n_mels
+ pub max_source_positions: usize, // n_audio_ctx
+ pub d_model: usize, // n_audio_state
+ pub encoder_attention_heads: usize, // n_audio_head
+ pub encoder_layers: usize, // n_audio_layer
+ pub vocab_size: usize, // n_vocab
+ pub max_target_positions: usize, // n_text_ctx
+ // pub n_text_state: usize,
+ pub decoder_attention_heads: usize, // n_text_head
+ pub decoder_layers: usize, // n_text_layer
+ pub suppress_tokens: Vec<u32>,
+}
+
+fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
+ let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
+ Ok(Embedding::new(embeddings, hidden_size))
+}
+
+// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
+// model.
+#[derive(Debug)]
+pub struct Linear {
+ inner: candle_nn::Linear,
+ span: tracing::Span,
+}
+
+impl Linear {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(x)
+ }
+}
+
+fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
+ let span = tracing::span!(tracing::Level::TRACE, "linear");
+ let inner = candle_nn::linear(size1, size2, vb)?;
+ Ok(Linear { inner, span })
+}
+
+fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
+ let span = tracing::span!(tracing::Level::TRACE, "linear");
+ let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
+ Ok(Linear { inner, span })
+}
+
+fn conv1d(
+ in_channels: usize,
+ out_channels: usize,
+ kernel_size: usize,
+ config: Conv1dConfig,
+ vb: VarBuilder,
+) -> Result<Conv1d> {
+ let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?;
+ let bias = vb.get(out_channels, "bias")?;
+ Ok(Conv1d::new(weight, Some(bias), config))
+}
+
+fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> {
+ let weight = vb.get(size, "weight")?;
+ let bias = vb.get(size, "bias")?;
+ Ok(LayerNorm::new(weight, bias, 1e-5))
+}
+
+// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62
+struct MultiHeadAttention {
+ query: Linear,
+ key: Linear,
+ value: Linear,
+ out: Linear,
+ n_head: usize,
+ span: tracing::Span,
+ softmax_span: tracing::Span,
+ matmul_span: tracing::Span,
+ kv_cache: Option<(Tensor, Tensor)>,
+}
+
+impl MultiHeadAttention {
+ fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "multi-head-attn");
+ let softmax_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-softmax");
+ let matmul_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-matmul");
+ let query = linear(n_state, n_state, vb.pp("q_proj"))?;
+ let value = linear(n_state, n_state, vb.pp("v_proj"))?;
+ let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
+ let out = linear(n_state, n_state, vb.pp("out_proj"))?;
+ Ok(Self {
+ query,
+ key,
+ value,
+ out,
+ n_head,
+ span,
+ softmax_span,
+ matmul_span,
+ kv_cache: None,
+ })
+ }
+
+ fn forward(
+ &mut self,
+ x: &Tensor,
+ xa: Option<&Tensor>,
+ mask: Option<&Tensor>,
+ flush_cache: bool,
+ ) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let q = self.query.forward(x)?;
+ let (k, v) = match xa {
+ None => {
+ let k = self.key.forward(x)?;
+ let v = self.value.forward(x)?;
+ (k, v)
+ }
+ Some(x) => {
+ if flush_cache {
+ self.kv_cache = None;
+ }
+ if let Some((k, v)) = &self.kv_cache {
+ (k.clone(), v.clone())
+ } else {
+ let k = self.key.forward(x)?;
+ let v = self.value.forward(x)?;
+ self.kv_cache = Some((k.clone(), v.clone()));
+ (k, v)
+ }
+ }
+ };
+ let wv = self.qkv_attention(&q, &k, &v, mask)?;
+ let out = self.out.forward(&wv)?;
+ Ok(out)
+ }
+
+ fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
+ let (n_batch, n_ctx, n_state) = x.dims3()?;
+ let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
+ x.reshape(target_dims)?.transpose(1, 2)
+ }
+
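+ // Scaled dot-product attention. The (head_dim)^-0.25 factor is applied to
+ // both q and k so that their product ends up scaled by the usual
+ // 1/sqrt(head_dim).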
+ fn qkv_attention(
+ &self,
+ q: &Tensor,
+ k: &Tensor,
+ v: &Tensor,
+ mask: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let (_, n_ctx, n_state) = q.dims3()?;
+ let scale = ((n_state / self.n_head) as f64).powf(-0.25);
+ let q = (self.reshape_head(q)? * scale)?;
+ let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
+ let v = self.reshape_head(v)?.contiguous()?;
+ let mut qk = {
+ let _enter = self.matmul_span.enter();
+ q.matmul(&k)?
+ };
+ if let Some(mask) = mask {
+ let mask = mask.i((0..n_ctx, 0..n_ctx))?;
+ qk = qk.broadcast_add(&mask)?
+ }
+ let w = {
+ let _enter = self.softmax_span.enter();
+ candle_nn::ops::softmax_last_dim(&qk)?
+ };
+ let wv = {
+ let _enter = self.matmul_span.enter();
+ w.matmul(&v)?
+ }
+ .transpose(1, 2)?
+ .flatten_from(2)?;
+ Ok(wv)
+ }
+}
+
+// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111
+struct ResidualAttentionBlock {
+ attn: MultiHeadAttention,
+ attn_ln: LayerNorm,
+ cross_attn: Option<(MultiHeadAttention, LayerNorm)>,
+ mlp_linear1: Linear,
+ mlp_linear2: Linear,
+ mlp_ln: LayerNorm,
+ span: tracing::Span,
+}
+
+impl ResidualAttentionBlock {
+ fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "residual-attn");
+ let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?;
+ let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?;
+ let cross_attn = if ca {
+ let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?;
+ let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?;
+ Some((cross_attn, cross_attn_ln))
+ } else {
+ None
+ };
+ let n_mlp = n_state * 4;
+ let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?;
+ let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?;
+ let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?;
+ Ok(Self {
+ attn,
+ attn_ln,
+ cross_attn,
+ mlp_linear1,
+ mlp_linear2,
+ mlp_ln,
+ span,
+ })
+ }
+
+ fn forward(
+ &mut self,
+ x: &Tensor,
+ xa: Option<&Tensor>,
+ mask: Option<&Tensor>,
+ flush_kv_cache: bool,
+ ) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let attn = self
+ .attn
+ .forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?;
+ let mut x = (x + attn)?;
+ if let Some((attn, ln)) = &mut self.cross_attn {
+ x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?;
+ }
+ let mlp = self.mlp_linear2.forward(
+ &self
+ .mlp_linear1
+ .forward(&self.mlp_ln.forward(&x)?)?
+ .gelu()?,
+ )?;
+ x + mlp
+ }
+}
+
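+// Fixed sinusoidal positional embeddings for the audio encoder: the first half
+// of the channels holds sines and the second half cosines, with geometrically
+// spaced timescales between 1 and 10_000.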
+fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
+ let max_timescale = 10000f32;
+ let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32;
+ let inv_timescales: Vec<_> = (0..channels / 2)
+ .map(|i| (i as f32 * (-log_timescale_increment)).exp())
+ .collect();
+ let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
+ let arange = Tensor::arange(0, length as u32, &Device::Cpu)?
+ .to_dtype(candle::DType::F32)?
+ .unsqueeze(1)?;
+ let sh = (length, channels / 2);
+ let scaled_time = (arange.broadcast_as(sh)? * inv_timescales.broadcast_as(sh)?)?;
+ let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?;
+ Ok(sincos)
+}
+
+// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143
+pub struct AudioEncoder {
+ conv1: Conv1d,
+ conv2: Conv1d,
+ positional_embedding: Tensor,
+ blocks: Vec<ResidualAttentionBlock>,
+ ln_post: LayerNorm,
+ span: tracing::Span,
+ conv1_span: tracing::Span,
+ conv2_span: tracing::Span,
+}
+
+impl AudioEncoder {
+ fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "audio-encoder");
+ let conv1_span = tracing::span!(tracing::Level::TRACE, "conv1");
+ let conv2_span = tracing::span!(tracing::Level::TRACE, "conv2");
+ let n_state = cfg.d_model;
+ let n_head = cfg.encoder_attention_heads;
+ let n_ctx = cfg.max_source_positions;
+ let cfg1 = Conv1dConfig {
+ padding: 1,
+ stride: 1,
+ groups: 1,
+ dilation: 1,
+ };
+ let cfg2 = Conv1dConfig {
+ padding: 1,
+ stride: 2,
+ groups: 1,
+ dilation: 1,
+ };
+ let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
+ let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
+ let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?;
+ let blocks = (0..cfg.encoder_layers)
+ .map(|i| {
+ ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}")))
+ })
+ .collect::<Result<Vec<_>>>()?;
+ let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?;
+ Ok(Self {
+ conv1,
+ conv2,
+ positional_embedding,
+ blocks,
+ ln_post,
+ conv1_span,
+ conv2_span,
+ span,
+ })
+ }
+
+ pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let x = {
+ let _enter = self.conv1_span.enter();
+ self.conv1.forward(x)?.gelu()?
+ };
+ let x = {
+ let _enter = self.conv2_span.enter();
+ self.conv2.forward(&x)?.gelu()?
+ };
+ let x = x.transpose(1, 2)?;
+ let (_bsize, seq_len, _hidden) = x.dims3()?;
+ let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
+ let mut x = x.broadcast_add(&positional_embedding)?;
+ for block in self.blocks.iter_mut() {
+ x = block.forward(&x, None, None, flush_kv_cache)?
+ }
+ let x = self.ln_post.forward(&x)?;
+ Ok(x)
+ }
+}
+
+// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176
+pub struct TextDecoder {
+ token_embedding: Embedding,
+ positional_embedding: Tensor,
+ blocks: Vec<ResidualAttentionBlock>,
+ ln: LayerNorm,
+ mask: Tensor,
+ span: tracing::Span,
+ span_final: tracing::Span,
+}
+
+impl TextDecoder {
+ fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "text-decoder");
+ let span_final = tracing::span!(tracing::Level::TRACE, "text-decoder-final");
+ let n_state = cfg.d_model;
+ let n_head = cfg.decoder_attention_heads;
+ let n_ctx = cfg.max_target_positions;
+ let token_embedding = embedding(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?;
+ let positional_embedding = vb.get((n_ctx, n_state), "embed_positions.weight")?;
+ let blocks = (0..cfg.decoder_layers)
+ .map(|i| {
+ ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(&format!("layers.{i}")))
+ })
+ .collect::<Result<Vec<_>>>()?;
+ let ln = layer_norm(n_state, vb.pp("layer_norm"))?;
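+ // Causal mask: position i may only attend to positions j <= i, so entries
+ // above the diagonal are set to -inf before the softmax.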
+ let mask: Vec<_> = (0..n_ctx)
+ .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
+ .collect();
+ let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?;
+ Ok(Self {
+ token_embedding,
+ positional_embedding,
+ blocks,
+ ln,
+ mask,
+ span,
+ span_final,
+ })
+ }
+
+ pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let last = x.dim(D::Minus1)?;
+ let token_embedding = self.token_embedding.forward(x)?;
+ let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
+ let mut x = token_embedding.broadcast_add(&positional_embedding)?;
+ for block in self.blocks.iter_mut() {
+ x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?;
+ }
+ self.ln.forward(&x)
+ }
+
+ pub fn final_linear(&self, x: &Tensor) -> Result<Tensor> {
+ let b_size = x.dim(0)?;
+ let w = self.token_embedding.embeddings().broadcast_left(b_size)?;
+ let logits = {
+ let _enter = self.span_final.enter();
+ x.matmul(&w.t()?)?
+ };
+ Ok(logits)
+ }
+}
+
+// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221
+pub struct Whisper {
+ pub encoder: AudioEncoder,
+ pub decoder: TextDecoder,
+ pub config: Config,
+}
+
+impl Whisper {
+ pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> {
+ let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?;
+ let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?;
+ Ok(Self {
+ encoder,
+ decoder,
+ config,
+ })
+ }
+}
diff --git a/candle-transformers/src/models/wuerstchen/attention_processor.rs b/candle-transformers/src/models/wuerstchen/attention_processor.rs
new file mode 100644
index 00000000..0b90cb9d
--- /dev/null
+++ b/candle-transformers/src/models/wuerstchen/attention_processor.rs
@@ -0,0 +1,118 @@
+use candle::{Module, Result, Tensor};
+use candle_nn::{linear, Linear, VarBuilder};
+
+// A simplified version of:
+// https://github.com/huggingface/diffusers/blob/119ad2c3dc8a8fb8446a83f4bf6f20929487b47f/src/diffusers/models/attention_processor.py#L38
+#[derive(Debug)]
+pub struct Attention {
+ to_q: Linear,
+ to_k: Linear,
+ to_v: Linear,
+ to_out: Linear,
+ heads: usize,
+ scale: f64,
+ use_flash_attn: bool,
+}
+
+#[cfg(feature = "flash-attn")]
+fn flash_attn(
+ q: &Tensor,
+ k: &Tensor,
+ v: &Tensor,
+ softmax_scale: f32,
+ causal: bool,
+) -> Result<Tensor> {
+ candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
+}
+
+#[cfg(not(feature = "flash-attn"))]
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
+ unimplemented!("compile with '--features flash-attn'")
+}
+
+impl Attention {
+ pub fn new(
+ query_dim: usize,
+ heads: usize,
+ dim_head: usize,
+ use_flash_attn: bool,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let inner_dim = dim_head * heads;
+ let scale = 1.0 / f64::sqrt(dim_head as f64);
+ let to_q = linear(query_dim, inner_dim, vb.pp("to_q"))?;
+ let to_k = linear(query_dim, inner_dim, vb.pp("to_k"))?;
+ let to_v = linear(query_dim, inner_dim, vb.pp("to_v"))?;
+ let to_out = linear(inner_dim, query_dim, vb.pp("to_out.0"))?;
+ Ok(Self {
+ to_q,
+ to_k,
+ to_v,
+ to_out,
+ scale,
+ heads,
+ use_flash_attn,
+ })
+ }
+
+ fn batch_to_head_dim(&self, xs: &Tensor) -> Result<Tensor> {
+ let (b_size, seq_len, dim) = xs.dims3()?;
+ xs.reshape((b_size / self.heads, self.heads, seq_len, dim))?
+ .permute((0, 2, 1, 3))?
+ .reshape((b_size / self.heads, seq_len, dim * self.heads))
+ }
+
+ fn head_to_batch_dim(&self, xs: &Tensor) -> Result<Tensor> {
+ let (b_size, seq_len, dim) = xs.dims3()?;
+ xs.reshape((b_size, seq_len, self.heads, dim / self.heads))?
+ .permute((0, 2, 1, 3))?
+ .reshape((b_size * self.heads, seq_len, dim / self.heads))
+ }
+
+ fn get_attention_scores(&self, query: &Tensor, key: &Tensor) -> Result<Tensor> {
+ let attn_probs = (query.matmul(&key.t()?)? * self.scale)?;
+ candle_nn::ops::softmax_last_dim(&attn_probs)
+ }
+
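+ // Attention over a NCHW feature map: spatial positions are flattened into a
+ // sequence, queries come from xs and keys/values from encoder_hidden_states.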
+ pub fn forward(&self, xs: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> {
+ let (b_size, channel, h, w) = xs.dims4()?;
+ let xs = xs.reshape((b_size, channel, h * w))?.t()?;
+
+ let query = self.to_q.forward(&xs)?;
+ let key = self.to_k.forward(encoder_hidden_states)?;
+ let value = self.to_v.forward(encoder_hidden_states)?;
+
+ let query = self.head_to_batch_dim(&query)?;
+ let key = self.head_to_batch_dim(&key)?;
+ let value = self.head_to_batch_dim(&value)?;
+
+ let xs = if self.use_flash_attn {
+ let init_dtype = query.dtype();
+ let q = query
+ .to_dtype(candle::DType::F16)?
+ .unsqueeze(0)?
+ .transpose(1, 2)?;
+ let k = key
+ .to_dtype(candle::DType::F16)?
+ .unsqueeze(0)?
+ .transpose(1, 2)?;
+ let v = value
+ .to_dtype(candle::DType::F16)?
+ .unsqueeze(0)?
+ .transpose(1, 2)?;
+ flash_attn(&q, &k, &v, self.scale as f32, false)?
+ .transpose(1, 2)?
+ .squeeze(0)?
+ .to_dtype(init_dtype)?
+ } else {
+ let attn_prs = self.get_attention_scores(&query, &key)?;
+ attn_prs.matmul(&value)?
+ };
+ let xs = self.batch_to_head_dim(&xs)?;
+
+ self.to_out
+ .forward(&xs)?
+ .t()?
+ .reshape((b_size, channel, h, w))
+ }
+}
diff --git a/candle-transformers/src/models/wuerstchen/common.rs b/candle-transformers/src/models/wuerstchen/common.rs
new file mode 100644
index 00000000..c89ec919
--- /dev/null
+++ b/candle-transformers/src/models/wuerstchen/common.rs
@@ -0,0 +1,203 @@
+use candle::{DType, Module, Result, Tensor, D};
+use candle_nn::VarBuilder;
+
+// https://github.com/huggingface/diffusers/blob/19edca82f1ff194c07317369a92b470dbae97f34/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py#L22
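+// Layer norm over the channel dimension of NCHW tensors: the input is permuted
+// to channels-last, normalized over the last dimension (without learned weight
+// or bias), then permuted back.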
+#[derive(Debug)]
+pub struct WLayerNorm {
+ eps: f64,
+}
+
+impl WLayerNorm {
+ pub fn new(_size: usize) -> Result<Self> {
+ Ok(Self { eps: 1e-6 })
+ }
+}
+
+impl Module for WLayerNorm {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = xs.permute((0, 2, 3, 1))?;
+
+ let x_dtype = xs.dtype();
+ let internal_dtype = match x_dtype {
+ DType::F16 | DType::BF16 => DType::F32,
+ d => d,
+ };
+
+ let hidden_size = xs.dim(D::Minus1)?;
+ let xs = xs.to_dtype(internal_dtype)?;
+ let mean_x = (xs.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+ let xs = xs.broadcast_sub(&mean_x)?;
+ let norm_x = (xs.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+ xs.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
+ .to_dtype(x_dtype)?
+ .permute((0, 3, 1, 2))
+ }
+}
+
+#[derive(Debug)]
+pub struct LayerNormNoWeights {
+ eps: f64,
+}
+
+impl LayerNormNoWeights {
+ pub fn new(_size: usize) -> Result<Self> {
+ Ok(Self { eps: 1e-6 })
+ }
+}
+
+impl Module for LayerNormNoWeights {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let x_dtype = xs.dtype();
+ let internal_dtype = match x_dtype {
+ DType::F16 | DType::BF16 => DType::F32,
+ d => d,
+ };
+ let hidden_size = xs.dim(D::Minus1)?;
+ let xs = xs.to_dtype(internal_dtype)?;
+ let mean_x = (xs.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+ let xs = xs.broadcast_sub(&mean_x)?;
+ let norm_x = (xs.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+ xs.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
+ .to_dtype(x_dtype)
+ }
+}
+
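+// FiLM-style conditioning on the timestep embedding: the embedding is mapped
+// to per-channel (scale, shift) pairs and the input is transformed as
+// xs * (1 + scale) + shift.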
+#[derive(Debug)]
+pub struct TimestepBlock {
+ mapper: candle_nn::Linear,
+}
+
+impl TimestepBlock {
+ pub fn new(c: usize, c_timestep: usize, vb: VarBuilder) -> Result<Self> {
+ let mapper = candle_nn::linear(c_timestep, c * 2, vb.pp("mapper"))?;
+ Ok(Self { mapper })
+ }
+
+ pub fn forward(&self, xs: &Tensor, t: &Tensor) -> Result<Tensor> {
+ let ab = self
+ .mapper
+ .forward(t)?
+ .unsqueeze(2)?
+ .unsqueeze(3)?
+ .chunk(2, 1)?;
+ xs.broadcast_mul(&(&ab[0] + 1.)?)?.broadcast_add(&ab[1])
+ }
+}
+
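+// Global response normalization (as in ConvNeXt V2): features are rescaled by
+// their spatial L2 norm divided by its mean over channels, then modulated by
+// the learned gamma/beta with a residual connection.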
+#[derive(Debug)]
+pub struct GlobalResponseNorm {
+ gamma: Tensor,
+ beta: Tensor,
+}
+
+impl GlobalResponseNorm {
+ pub fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
+ let gamma = vb.get((1, 1, 1, dim), "gamma")?;
+ let beta = vb.get((1, 1, 1, dim), "beta")?;
+ Ok(Self { gamma, beta })
+ }
+}
+
+impl Module for GlobalResponseNorm {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let agg_norm = xs.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
+ let stand_div_norm =
+ agg_norm.broadcast_div(&(agg_norm.mean_keepdim(D::Minus1)? + 1e-6)?)?;
+ xs.broadcast_mul(&stand_div_norm)?
+ .broadcast_mul(&self.gamma)?
+ .broadcast_add(&self.beta)?
+ + xs
+ }
+}
+
+#[derive(Debug)]
+pub struct ResBlock {
+ depthwise: candle_nn::Conv2d,
+ norm: WLayerNorm,
+ channelwise_lin1: candle_nn::Linear,
+ channelwise_grn: GlobalResponseNorm,
+ channelwise_lin2: candle_nn::Linear,
+}
+
+impl ResBlock {
+ pub fn new(c: usize, c_skip: usize, ksize: usize, vb: VarBuilder) -> Result<Self> {
+ let cfg = candle_nn::Conv2dConfig {
+ padding: ksize / 2,
+ groups: c,
+ ..Default::default()
+ };
+ let depthwise = candle_nn::conv2d(c + c_skip, c, ksize, cfg, vb.pp("depthwise"))?;
+ let norm = WLayerNorm::new(c)?;
+ let channelwise_lin1 = candle_nn::linear(c, c * 4, vb.pp("channelwise.0"))?;
+ let channelwise_grn = GlobalResponseNorm::new(c * 4, vb.pp("channelwise.2"))?;
+ let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?;
+ Ok(Self {
+ depthwise,
+ norm,
+ channelwise_lin1,
+ channelwise_grn,
+ channelwise_lin2,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, x_skip: Option<&Tensor>) -> Result<Tensor> {
+ let x_res = xs;
+ let xs = match x_skip {
+ None => xs.clone(),
+ Some(x_skip) => Tensor::cat(&[xs, x_skip], 1)?,
+ };
+ let xs = xs
+ .apply(&self.depthwise)?
+ .apply(&self.norm)?
+ .permute((0, 2, 3, 1))?;
+ let xs = xs
+ .apply(&self.channelwise_lin1)?
+ .gelu_erf()?
+ .apply(&self.channelwise_grn)?
+ .apply(&self.channelwise_lin2)?
+ .permute((0, 3, 1, 2))?;
+ xs + x_res
+ }
+}
+use super::attention_processor::Attention;
+#[derive(Debug)]
+pub struct AttnBlock {
+ self_attn: bool,
+ norm: WLayerNorm,
+ attention: Attention,
+ kv_mapper_lin: candle_nn::Linear,
+}
+
+impl AttnBlock {
+ pub fn new(
+ c: usize,
+ c_cond: usize,
+ nhead: usize,
+ self_attn: bool,
+ use_flash_attn: bool,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let norm = WLayerNorm::new(c)?;
+ let attention = Attention::new(c, nhead, c / nhead, use_flash_attn, vb.pp("attention"))?;
+ let kv_mapper_lin = candle_nn::linear(c_cond, c, vb.pp("kv_mapper.1"))?;
+ Ok(Self {
+ self_attn,
+ norm,
+ attention,
+ kv_mapper_lin,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, kv: &Tensor) -> Result<Tensor> {
+ let kv = candle_nn::ops::silu(kv)?.apply(&self.kv_mapper_lin)?;
+ let norm_xs = self.norm.forward(xs)?;
+ let kv = if self.self_attn {
+ let (b_size, channel, _, _) = xs.dims4()?;
+ let norm_xs = norm_xs.reshape((b_size, channel, ()))?.transpose(1, 2)?;
+ Tensor::cat(&[&norm_xs, &kv], 1)?.contiguous()?
+ } else {
+ kv
+ };
+ xs + self.attention.forward(&norm_xs, &kv)
+ }
+}
diff --git a/candle-transformers/src/models/wuerstchen/ddpm.rs b/candle-transformers/src/models/wuerstchen/ddpm.rs
new file mode 100644
index 00000000..9e69b868
--- /dev/null
+++ b/candle-transformers/src/models/wuerstchen/ddpm.rs
@@ -0,0 +1,103 @@
+use candle::{Result, Tensor};
+
+#[derive(Debug, Clone)]
+pub struct DDPMWSchedulerConfig {
+ scaler: f64,
+ s: f64,
+}
+
+impl Default for DDPMWSchedulerConfig {
+ fn default() -> Self {
+ Self {
+ scaler: 1f64,
+ s: 0.008f64,
+ }
+ }
+}
+
+pub struct DDPMWScheduler {
+ init_alpha_cumprod: f64,
+ init_noise_sigma: f64,
+ timesteps: Vec<f64>,
+ pub config: DDPMWSchedulerConfig,
+}
+
+impl DDPMWScheduler {
+ pub fn new(inference_steps: usize, config: DDPMWSchedulerConfig) -> Result<Self> {
+ let init_alpha_cumprod = (config.s / (1. + config.s) * std::f64::consts::PI * 0.5)
+ .cos()
+ .powi(2);
+ let timesteps = (0..=inference_steps)
+ .map(|i| 1. - i as f64 / inference_steps as f64)
+ .collect::<Vec<_>>();
+ Ok(Self {
+ init_alpha_cumprod,
+ init_noise_sigma: 1.0,
+ timesteps,
+ config,
+ })
+ }
+
+ pub fn timesteps(&self) -> &[f64] {
+ &self.timesteps
+ }
+
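+ // Continuous cosine noise schedule: cos^2(((t + s) / (1 + s)) * pi / 2),
+ // normalized by its value at t = 0, optionally warped by `scaler`, and
+ // clamped to [1e-4, 0.9999].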
+ fn alpha_cumprod(&self, t: f64) -> f64 {
+ let scaler = self.config.scaler;
+ let s = self.config.s;
+ let t = if scaler > 1. {
+ 1. - (1. - t).powf(scaler)
+ } else if scaler < 1. {
+ t.powf(scaler)
+ } else {
+ t
+ };
+ let alpha_cumprod = ((t + s) / (1. + s) * std::f64::consts::PI * 0.5)
+ .cos()
+ .powi(2)
+ / self.init_alpha_cumprod;
+ alpha_cumprod.clamp(0.0001, 0.9999)
+ }
+
+ fn previous_timestep(&self, ts: f64) -> f64 {
+ let index = self
+ .timesteps
+ .iter()
+ .enumerate()
+ .map(|(idx, v)| (idx, (v - ts).abs()))
+ .min_by(|x, y| x.1.total_cmp(&y.1))
+ .unwrap()
+ .0;
+ self.timesteps[index + 1]
+ }
+
+ /// Ensures interchangeability with schedulers that need to scale the denoising model input
+ /// depending on the current timestep.
+ pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor {
+ sample
+ }
+
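+ // Standard DDPM ancestral sampling step: compute the posterior mean from the
+ // predicted noise, then add Gaussian noise scaled by the posterior standard
+ // deviation (the noise is skipped at the final step).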
+ pub fn step(&self, model_output: &Tensor, ts: f64, sample: &Tensor) -> Result<Tensor> {
+ let prev_t = self.previous_timestep(ts);
+
+ let alpha_cumprod = self.alpha_cumprod(ts);
+ let alpha_cumprod_prev = self.alpha_cumprod(prev_t);
+ let alpha = alpha_cumprod / alpha_cumprod_prev;
+
+ let mu = (sample - model_output * ((1. - alpha) / (1. - alpha_cumprod).sqrt()))?;
+ let mu = (mu * (1. / alpha).sqrt())?;
+
+ let std_noise = mu.randn_like(0., 1.)?;
+ let std =
+ std_noise * ((1. - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt();
+ if prev_t == 0. {
+ Ok(mu)
+ } else {
+ mu + std
+ }
+ }
+
+ pub fn init_noise_sigma(&self) -> f64 {
+ self.init_noise_sigma
+ }
+}
diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs
new file mode 100644
index 00000000..64a48c8a
--- /dev/null
+++ b/candle-transformers/src/models/wuerstchen/diffnext.rs
@@ -0,0 +1,396 @@
+use super::common::{AttnBlock, GlobalResponseNorm, LayerNormNoWeights, TimestepBlock, WLayerNorm};
+use candle::{DType, Module, Result, Tensor, D};
+use candle_nn::VarBuilder;
+
+#[derive(Debug)]
+pub struct ResBlockStageB {
+ depthwise: candle_nn::Conv2d,
+ norm: WLayerNorm,
+ channelwise_lin1: candle_nn::Linear,
+ channelwise_grn: GlobalResponseNorm,
+ channelwise_lin2: candle_nn::Linear,
+}
+
+impl ResBlockStageB {
+ pub fn new(c: usize, c_skip: usize, ksize: usize, vb: VarBuilder) -> Result<Self> {
+ let cfg = candle_nn::Conv2dConfig {
+ groups: c,
+ padding: ksize / 2,
+ ..Default::default()
+ };
+ let depthwise = candle_nn::conv2d(c, c, ksize, cfg, vb.pp("depthwise"))?;
+ let norm = WLayerNorm::new(c)?;
+ let channelwise_lin1 = candle_nn::linear(c + c_skip, c * 4, vb.pp("channelwise.0"))?;
+ let channelwise_grn = GlobalResponseNorm::new(4 * c, vb.pp("channelwise.2"))?;
+ let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?;
+ Ok(Self {
+ depthwise,
+ norm,
+ channelwise_lin1,
+ channelwise_grn,
+ channelwise_lin2,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, x_skip: Option<&Tensor>) -> Result<Tensor> {
+ let x_res = xs;
+ let xs = xs.apply(&self.depthwise)?.apply(&self.norm)?;
+ let xs = match x_skip {
+ None => xs.clone(),
+ Some(x_skip) => Tensor::cat(&[&xs, x_skip], 1)?,
+ };
+ let xs = xs
+ .permute((0, 2, 3, 1))?
+ .contiguous()?
+ .apply(&self.channelwise_lin1)?
+ .gelu()?
+ .apply(&self.channelwise_grn)?
+ .apply(&self.channelwise_lin2)?
+ .permute((0, 3, 1, 2))?;
+ xs + x_res
+ }
+}
+
+#[derive(Debug)]
+struct SubBlock {
+ res_block: ResBlockStageB,
+ ts_block: TimestepBlock,
+ attn_block: Option<AttnBlock>,
+}
+
+#[derive(Debug)]
+struct DownBlock {
+ layer_norm: Option<WLayerNorm>,
+ conv: Option<candle_nn::Conv2d>,
+ sub_blocks: Vec<SubBlock>,
+}
+
+#[derive(Debug)]
+struct UpBlock {
+ sub_blocks: Vec<SubBlock>,
+ layer_norm: Option<WLayerNorm>,
+ conv: Option<candle_nn::ConvTranspose2d>,
+}
+
+#[derive(Debug)]
+pub struct WDiffNeXt {
+ clip_mapper: candle_nn::Linear,
+ effnet_mappers: Vec<Option<candle_nn::Conv2d>>,
+ seq_norm: LayerNormNoWeights,
+ embedding_conv: candle_nn::Conv2d,
+ embedding_ln: WLayerNorm,
+ down_blocks: Vec<DownBlock>,
+ up_blocks: Vec<UpBlock>,
+ clf_ln: WLayerNorm,
+ clf_conv: candle_nn::Conv2d,
+ c_r: usize,
+ patch_size: usize,
+}
+
+impl WDiffNeXt {
+ #[allow(clippy::too_many_arguments)]
+ pub fn new(
+ c_in: usize,
+ c_out: usize,
+ c_r: usize,
+ c_cond: usize,
+ clip_embd: usize,
+ patch_size: usize,
+ use_flash_attn: bool,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ const C_HIDDEN: [usize; 4] = [320, 640, 1280, 1280];
+ const BLOCKS: [usize; 4] = [4, 4, 14, 4];
+ const NHEAD: [usize; 4] = [1, 10, 20, 20];
+ const INJECT_EFFNET: [bool; 4] = [false, true, true, true];
+ const EFFNET_EMBD: usize = 16;
+
+ let clip_mapper = candle_nn::linear(clip_embd, c_cond, vb.pp("clip_mapper"))?;
+ let mut effnet_mappers = Vec::with_capacity(2 * INJECT_EFFNET.len());
+ let vb_e = vb.pp("effnet_mappers");
+ for (i, &inject) in INJECT_EFFNET.iter().enumerate() {
+ let c = if inject {
+ Some(candle_nn::conv2d(
+ EFFNET_EMBD,
+ c_cond,
+ 1,
+ Default::default(),
+ vb_e.pp(i),
+ )?)
+ } else {
+ None
+ };
+ effnet_mappers.push(c)
+ }
+ for (i, &inject) in INJECT_EFFNET.iter().rev().enumerate() {
+ let c = if inject {
+ Some(candle_nn::conv2d(
+ EFFNET_EMBD,
+ c_cond,
+ 1,
+ Default::default(),
+ vb_e.pp(i + INJECT_EFFNET.len()),
+ )?)
+ } else {
+ None
+ };
+ effnet_mappers.push(c)
+ }
+ let seq_norm = LayerNormNoWeights::new(c_cond)?;
+ let embedding_ln = WLayerNorm::new(C_HIDDEN[0])?;
+ let embedding_conv = candle_nn::conv2d(
+ c_in * patch_size * patch_size,
+ C_HIDDEN[0],
+ 1,
+ Default::default(),
+ vb.pp("embedding.1"),
+ )?;
+
+ let mut down_blocks = Vec::with_capacity(C_HIDDEN.len());
+ for (i, &c_hidden) in C_HIDDEN.iter().enumerate() {
+ let vb = vb.pp("down_blocks").pp(i);
+ let (layer_norm, conv, start_layer_i) = if i > 0 {
+ let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
+ let cfg = candle_nn::Conv2dConfig {
+ stride: 2,
+ ..Default::default()
+ };
+ let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp("0.1"))?;
+ (Some(layer_norm), Some(conv), 1)
+ } else {
+ (None, None, 0)
+ };
+ let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
+ let mut layer_i = start_layer_i;
+ for _j in 0..BLOCKS[i] {
+ let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 };
+ let res_block = ResBlockStageB::new(c_hidden, c_skip, 3, vb.pp(layer_i))?;
+ layer_i += 1;
+ let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
+ layer_i += 1;
+ let attn_block = if i == 0 {
+ None
+ } else {
+ let attn_block = AttnBlock::new(
+ c_hidden,
+ c_cond,
+ NHEAD[i],
+ true,
+ use_flash_attn,
+ vb.pp(layer_i),
+ )?;
+ layer_i += 1;
+ Some(attn_block)
+ };
+ let sub_block = SubBlock {
+ res_block,
+ ts_block,
+ attn_block,
+ };
+ sub_blocks.push(sub_block)
+ }
+ let down_block = DownBlock {
+ layer_norm,
+ conv,
+ sub_blocks,
+ };
+ down_blocks.push(down_block)
+ }
+
+ let mut up_blocks = Vec::with_capacity(C_HIDDEN.len());
+ for (i, &c_hidden) in C_HIDDEN.iter().enumerate().rev() {
+ let vb = vb.pp("up_blocks").pp(C_HIDDEN.len() - 1 - i);
+ let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
+ let mut layer_i = 0;
+ for j in 0..BLOCKS[i] {
+ let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 };
+ let c_skip_res = if i < BLOCKS.len() - 1 && j == 0 {
+ c_hidden + c_skip
+ } else {
+ c_skip
+ };
+ let res_block = ResBlockStageB::new(c_hidden, c_skip_res, 3, vb.pp(layer_i))?;
+ layer_i += 1;
+ let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
+ layer_i += 1;
+ let attn_block = if i == 0 {
+ None
+ } else {
+ let attn_block = AttnBlock::new(
+ c_hidden,
+ c_cond,
+ NHEAD[i],
+ true,
+ use_flash_attn,
+ vb.pp(layer_i),
+ )?;
+ layer_i += 1;
+ Some(attn_block)
+ };
+ let sub_block = SubBlock {
+ res_block,
+ ts_block,
+ attn_block,
+ };
+ sub_blocks.push(sub_block)
+ }
+ let (layer_norm, conv) = if i > 0 {
+ let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
+ let cfg = candle_nn::ConvTranspose2dConfig {
+ stride: 2,
+ ..Default::default()
+ };
+ let conv = candle_nn::conv_transpose2d(
+ c_hidden,
+ C_HIDDEN[i - 1],
+ 2,
+ cfg,
+ vb.pp(layer_i).pp(1),
+ )?;
+ (Some(layer_norm), Some(conv))
+ } else {
+ (None, None)
+ };
+ let up_block = UpBlock {
+ layer_norm,
+ conv,
+ sub_blocks,
+ };
+ up_blocks.push(up_block)
+ }
+
+ let clf_ln = WLayerNorm::new(C_HIDDEN[0])?;
+ let clf_conv = candle_nn::conv2d(
+ C_HIDDEN[0],
+ 2 * c_out * patch_size * patch_size,
+ 1,
+ Default::default(),
+ vb.pp("clf.1"),
+ )?;
+ Ok(Self {
+ clip_mapper,
+ effnet_mappers,
+ seq_norm,
+ embedding_conv,
+ embedding_ln,
+ down_blocks,
+ up_blocks,
+ clf_ln,
+ clf_conv,
+ c_r,
+ patch_size,
+ })
+ }
+
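+ // Sinusoidal embedding of the continuous timestep r (scaled by 10_000),
+ // padded with a single zero when c_r is odd.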
+ fn gen_r_embedding(&self, r: &Tensor) -> Result<Tensor> {
+ const MAX_POSITIONS: usize = 10000;
+ let r = (r * MAX_POSITIONS as f64)?;
+ let half_dim = self.c_r / 2;
+ let emb = (MAX_POSITIONS as f64).ln() / (half_dim - 1) as f64;
+ let emb = (Tensor::arange(0u32, half_dim as u32, r.device())?.to_dtype(DType::F32)?
+ * -emb)?
+ .exp()?;
+ let emb = r.unsqueeze(1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
+ let emb = Tensor::cat(&[emb.sin()?, emb.cos()?], 1)?;
+ let emb = if self.c_r % 2 == 1 {
+ emb.pad_with_zeros(D::Minus1, 0, 1)?
+ } else {
+ emb
+ };
+ emb.to_dtype(r.dtype())
+ }
+
+ fn gen_c_embeddings(&self, clip: &Tensor) -> Result<Tensor> {
+ clip.apply(&self.clip_mapper)?.apply(&self.seq_norm)
+ }
+
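+ // U-Net style denoiser: the input is pixel-unshuffled into patches, processed
+ // by the down/up blocks with EfficientNet and CLIP conditioning, and the
+ // classifier head predicts (a, b) so that the output is (x_in - a) / b.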
+ pub fn forward(
+ &self,
+ xs: &Tensor,
+ r: &Tensor,
+ effnet: &Tensor,
+ clip: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ const EPS: f64 = 1e-3;
+
+ let r_embed = self.gen_r_embedding(r)?;
+ let clip = match clip {
+ None => None,
+ Some(clip) => Some(self.gen_c_embeddings(clip)?),
+ };
+ let x_in = xs;
+
+ let mut xs = xs
+ .apply(&|xs: &_| candle_nn::ops::pixel_unshuffle(xs, self.patch_size))?
+ .apply(&self.embedding_conv)?
+ .apply(&self.embedding_ln)?;
+
+ let mut level_outputs = Vec::new();
+ for (i, down_block) in self.down_blocks.iter().enumerate() {
+ if let Some(ln) = &down_block.layer_norm {
+ xs = xs.apply(ln)?
+ }
+ if let Some(conv) = &down_block.conv {
+ xs = xs.apply(conv)?
+ }
+ let skip = match &self.effnet_mappers[i] {
+ None => None,
+ Some(m) => {
+ let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?;
+ Some(m.forward(&effnet)?)
+ }
+ };
+ for block in down_block.sub_blocks.iter() {
+ xs = block.res_block.forward(&xs, skip.as_ref())?;
+ xs = block.ts_block.forward(&xs, &r_embed)?;
+ if let Some(attn_block) = &block.attn_block {
+ xs = attn_block.forward(&xs, clip.as_ref().unwrap())?;
+ }
+ }
+ level_outputs.push(xs.clone())
+ }
+ level_outputs.reverse();
+ let mut xs = level_outputs[0].clone();
+
+ for (i, up_block) in self.up_blocks.iter().enumerate() {
+ let effnet_c = match &self.effnet_mappers[self.down_blocks.len() + i] {
+ None => None,
+ Some(m) => {
+ let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?;
+ Some(m.forward(&effnet)?)
+ }
+ };
+ for (j, block) in up_block.sub_blocks.iter().enumerate() {
+ let skip = if j == 0 && i > 0 {
+ Some(&level_outputs[i])
+ } else {
+ None
+ };
+ let skip = match (skip, effnet_c.as_ref()) {
+ (Some(skip), Some(effnet_c)) => Some(Tensor::cat(&[skip, effnet_c], 1)?),
+ (None, Some(skip)) | (Some(skip), None) => Some(skip.clone()),
+ (None, None) => None,
+ };
+ xs = block.res_block.forward(&xs, skip.as_ref())?;
+ xs = block.ts_block.forward(&xs, &r_embed)?;
+ if let Some(attn_block) = &block.attn_block {
+ xs = attn_block.forward(&xs, clip.as_ref().unwrap())?;
+ }
+ }
+ if let Some(ln) = &up_block.layer_norm {
+ xs = xs.apply(ln)?
+ }
+ if let Some(conv) = &up_block.conv {
+ xs = xs.apply(conv)?
+ }
+ }
+
+ let ab = xs
+ .apply(&self.clf_ln)?
+ .apply(&self.clf_conv)?
+ .apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, self.patch_size))?
+ .chunk(2, 1)?;
+ let b = ((candle_nn::ops::sigmoid(&ab[1])? * (1. - EPS * 2.))? + EPS)?;
+ (x_in - &ab[0])? / b
+ }
+}
diff --git a/candle-transformers/src/models/wuerstchen/mod.rs b/candle-transformers/src/models/wuerstchen/mod.rs
new file mode 100644
index 00000000..7b076f06
--- /dev/null
+++ b/candle-transformers/src/models/wuerstchen/mod.rs
@@ -0,0 +1,6 @@
+pub mod attention_processor;
+pub mod common;
+pub mod ddpm;
+pub mod diffnext;
+pub mod paella_vq;
+pub mod prior;
diff --git a/candle-transformers/src/models/wuerstchen/paella_vq.rs b/candle-transformers/src/models/wuerstchen/paella_vq.rs
new file mode 100644
index 00000000..4a69cca0
--- /dev/null
+++ b/candle-transformers/src/models/wuerstchen/paella_vq.rs
@@ -0,0 +1,211 @@
+use super::common::LayerNormNoWeights;
+use candle::{Module, Result, Tensor};
+use candle_nn::VarBuilder;
+
+#[derive(Debug)]
+pub struct MixingResidualBlock {
+ norm1: LayerNormNoWeights,
+ depthwise_conv: candle_nn::Conv2d,
+ norm2: LayerNormNoWeights,
+ channelwise_lin1: candle_nn::Linear,
+ channelwise_lin2: candle_nn::Linear,
+ gammas: Vec<f32>,
+}
+
+impl MixingResidualBlock {
+ pub fn new(inp: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> {
+ let norm1 = LayerNormNoWeights::new(inp)?;
+ let norm2 = LayerNormNoWeights::new(inp)?;
+ let cfg = candle_nn::Conv2dConfig {
+ groups: inp,
+ ..Default::default()
+ };
+ let depthwise_conv = candle_nn::conv2d(inp, inp, 3, cfg, vb.pp("depthwise.1"))?;
+ let channelwise_lin1 = candle_nn::linear(inp, embed_dim, vb.pp("channelwise.0"))?;
+ let channelwise_lin2 = candle_nn::linear(embed_dim, inp, vb.pp("channelwise.2"))?;
+ let gammas = vb.get(6, "gammas")?.to_vec1::<f32>()?;
+ Ok(Self {
+ norm1,
+ depthwise_conv,
+ norm2,
+ channelwise_lin1,
+ channelwise_lin2,
+ gammas,
+ })
+ }
+}
+
+impl Module for MixingResidualBlock {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let mods = &self.gammas;
+ let x_temp = xs
+ .permute((0, 2, 3, 1))?
+ .apply(&self.norm1)?
+ .permute((0, 3, 1, 2))?
+ .affine(1. + mods[0] as f64, mods[1] as f64)?;
+ let x_temp = candle_nn::ops::replication_pad2d(&x_temp, 1)?;
+ let xs = (xs + x_temp.apply(&self.depthwise_conv)? * mods[2] as f64)?;
+ let x_temp = xs
+ .permute((0, 2, 3, 1))?
+ .apply(&self.norm2)?
+ .permute((0, 3, 1, 2))?
+ .affine(1. + mods[3] as f64, mods[4] as f64)?;
+ let x_temp = x_temp
+ .permute((0, 2, 3, 1))?
+ .contiguous()?
+ .apply(&self.channelwise_lin1)?
+ .gelu()?
+ .apply(&self.channelwise_lin2)?
+ .permute((0, 3, 1, 2))?;
+ xs + x_temp * mods[5] as f64
+ }
+}
+
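+/// Latent autoencoder: `encode` maps images into a `LATENT_CHANNELS`-channel latent
+/// space and `decode` maps latents back to image space. The vector-quantization step
+/// is not implemented (see the TODO in `decode`).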
+#[derive(Debug)]
+pub struct PaellaVQ {
+ in_block_conv: candle_nn::Conv2d,
+ out_block_conv: candle_nn::Conv2d,
+ down_blocks: Vec<(Option<candle_nn::Conv2d>, MixingResidualBlock)>,
+ down_blocks_conv: candle_nn::Conv2d,
+ down_blocks_bn: candle_nn::BatchNorm,
+ up_blocks_conv: candle_nn::Conv2d,
+ up_blocks: Vec<(Vec<MixingResidualBlock>, Option<candle_nn::ConvTranspose2d>)>,
+}
+
+impl PaellaVQ {
+ pub fn new(vb: VarBuilder) -> Result<Self> {
+ const IN_CHANNELS: usize = 3;
+ const OUT_CHANNELS: usize = 3;
+ const LATENT_CHANNELS: usize = 4;
+ const EMBED_DIM: usize = 384;
+ const BOTTLENECK_BLOCKS: usize = 12;
+ const C_LEVELS: [usize; 2] = [EMBED_DIM / 2, EMBED_DIM];
+
+ let in_block_conv = candle_nn::conv2d(
+ IN_CHANNELS * 4,
+ C_LEVELS[0],
+ 1,
+ Default::default(),
+ vb.pp("in_block.1"),
+ )?;
+ let out_block_conv = candle_nn::conv2d(
+ C_LEVELS[0],
+ OUT_CHANNELS * 4,
+ 1,
+ Default::default(),
+ vb.pp("out_block.0"),
+ )?;
+
+ let mut down_blocks = Vec::new();
+ let vb_d = vb.pp("down_blocks");
+ let mut d_idx = 0;
+ for (i, &c_level) in C_LEVELS.iter().enumerate() {
+ let conv_block = if i > 0 {
+ let cfg = candle_nn::Conv2dConfig {
+ padding: 1,
+ stride: 2,
+ ..Default::default()
+ };
+ let block = candle_nn::conv2d(C_LEVELS[i - 1], c_level, 4, cfg, vb_d.pp(d_idx))?;
+ d_idx += 1;
+ Some(block)
+ } else {
+ None
+ };
+ let res_block = MixingResidualBlock::new(c_level, c_level * 4, vb_d.pp(d_idx))?;
+ d_idx += 1;
+ down_blocks.push((conv_block, res_block))
+ }
+ let vb_d = vb_d.pp(d_idx);
+ let down_blocks_conv = candle_nn::conv2d_no_bias(
+ C_LEVELS[1],
+ LATENT_CHANNELS,
+ 1,
+ Default::default(),
+ vb_d.pp(0),
+ )?;
+ let down_blocks_bn = candle_nn::batch_norm(LATENT_CHANNELS, 1e-5, vb_d.pp(1))?;
+
+ let mut up_blocks = Vec::new();
+ let vb_u = vb.pp("up_blocks");
+ let mut u_idx = 0;
+ let up_blocks_conv = candle_nn::conv2d(
+ LATENT_CHANNELS,
+ C_LEVELS[1],
+ 1,
+ Default::default(),
+ vb_u.pp(u_idx).pp(0),
+ )?;
+ u_idx += 1;
+ for (i, &c_level) in C_LEVELS.iter().rev().enumerate() {
+ let mut res_blocks = Vec::new();
+ let n_bottleneck_blocks = if i == 0 { BOTTLENECK_BLOCKS } else { 1 };
+ for _j in 0..n_bottleneck_blocks {
+ let res_block = MixingResidualBlock::new(c_level, c_level * 4, vb_u.pp(u_idx))?;
+ u_idx += 1;
+ res_blocks.push(res_block)
+ }
+ let conv_block = if i < C_LEVELS.len() - 1 {
+ let cfg = candle_nn::ConvTranspose2dConfig {
+ padding: 1,
+ stride: 2,
+ ..Default::default()
+ };
+ let block = candle_nn::conv_transpose2d(
+ c_level,
+ C_LEVELS[C_LEVELS.len() - i - 2],
+ 4,
+ cfg,
+ vb_u.pp(u_idx),
+ )?;
+ u_idx += 1;
+ Some(block)
+ } else {
+ None
+ };
+ up_blocks.push((res_blocks, conv_block))
+ }
+ Ok(Self {
+ in_block_conv,
+ down_blocks,
+ down_blocks_conv,
+ down_blocks_bn,
+ up_blocks,
+ up_blocks_conv,
+ out_block_conv,
+ })
+ }
+
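+    /// Space-to-depth by a factor of 2, an input projection, the downsampling blocks,
+    /// and a final projection to `LATENT_CHANNELS` channels followed by batch norm.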
+ pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
+ let mut xs = candle_nn::ops::pixel_unshuffle(xs, 2)?.apply(&self.in_block_conv)?;
+ for down_block in self.down_blocks.iter() {
+ if let Some(conv) = &down_block.0 {
+ xs = xs.apply(conv)?
+ }
+ xs = xs.apply(&down_block.1)?
+ }
+ xs.apply(&self.down_blocks_conv)?
+ .apply(&self.down_blocks_bn)
+ }
+
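+    /// Inverse of `encode` (without quantization): project the latent back up, run the
+    /// residual / upsampling blocks, then an output projection and depth-to-space by 2.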
+ pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
+ // TODO: quantizer if we want to support `force_not_quantize=False`.
+ let mut xs = xs.apply(&self.up_blocks_conv)?;
+ for up_block in self.up_blocks.iter() {
+ for b in up_block.0.iter() {
+ xs = xs.apply(b)?;
+ }
+ if let Some(conv) = &up_block.1 {
+ xs = xs.apply(conv)?
+ }
+ }
+ xs.apply(&self.out_block_conv)?
+ .apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, 2))
+ }
+}
+
+impl Module for PaellaVQ {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ self.decode(&self.encode(xs)?)
+ }
+}
diff --git a/candle-transformers/src/models/wuerstchen/prior.rs b/candle-transformers/src/models/wuerstchen/prior.rs
new file mode 100644
index 00000000..97ccf0e2
--- /dev/null
+++ b/candle-transformers/src/models/wuerstchen/prior.rs
@@ -0,0 +1,103 @@
+use super::common::{AttnBlock, ResBlock, TimestepBlock};
+use candle::{DType, Result, Tensor, D};
+use candle_nn::VarBuilder;
+
+#[derive(Debug)]
+struct Block {
+ res_block: ResBlock,
+ ts_block: TimestepBlock,
+ attn_block: AttnBlock,
+}
+
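+/// Diffusion prior: a stack of residual / timestep / attention blocks that denoises
+/// the latent `xs` given the timestep ratio `r` and a conditioning tensor `c`.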
+#[derive(Debug)]
+pub struct WPrior {
+ projection: candle_nn::Conv2d,
+ cond_mapper_lin1: candle_nn::Linear,
+ cond_mapper_lin2: candle_nn::Linear,
+ blocks: Vec<Block>,
+ out_ln: super::common::WLayerNorm,
+ out_conv: candle_nn::Conv2d,
+ c_r: usize,
+}
+
+impl WPrior {
+ #[allow(clippy::too_many_arguments)]
+ pub fn new(
+ c_in: usize,
+ c: usize,
+ c_cond: usize,
+ c_r: usize,
+ depth: usize,
+ nhead: usize,
+ use_flash_attn: bool,
+ vb: VarBuilder,
+ ) -> Result<Self> {
+ let projection = candle_nn::conv2d(c_in, c, 1, Default::default(), vb.pp("projection"))?;
+ let cond_mapper_lin1 = candle_nn::linear(c_cond, c, vb.pp("cond_mapper.0"))?;
+ let cond_mapper_lin2 = candle_nn::linear(c, c, vb.pp("cond_mapper.2"))?;
+ let out_ln = super::common::WLayerNorm::new(c)?;
+ let out_conv = candle_nn::conv2d(c, c_in * 2, 1, Default::default(), vb.pp("out.1"))?;
+ let mut blocks = Vec::with_capacity(depth);
+ for index in 0..depth {
+ let res_block = ResBlock::new(c, 0, 3, vb.pp(format!("blocks.{}", 3 * index)))?;
+ let ts_block = TimestepBlock::new(c, c_r, vb.pp(format!("blocks.{}", 3 * index + 1)))?;
+ let attn_block = AttnBlock::new(
+ c,
+ c,
+ nhead,
+ true,
+ use_flash_attn,
+ vb.pp(format!("blocks.{}", 3 * index + 2)),
+ )?;
+ blocks.push(Block {
+ res_block,
+ ts_block,
+ attn_block,
+ })
+ }
+ Ok(Self {
+ projection,
+ cond_mapper_lin1,
+ cond_mapper_lin2,
+ blocks,
+ out_ln,
+ out_conv,
+ c_r,
+ })
+ }
+
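+    /// Sinusoidal embedding of the timestep ratio `r` (scaled by `MAX_POSITIONS`), in
+    /// the style of transformer positional embeddings, zero-padded when `c_r` is odd.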
+ pub fn gen_r_embedding(&self, r: &Tensor) -> Result<Tensor> {
+ const MAX_POSITIONS: usize = 10000;
+ let r = (r * MAX_POSITIONS as f64)?;
+ let half_dim = self.c_r / 2;
+ let emb = (MAX_POSITIONS as f64).ln() / (half_dim - 1) as f64;
+ let emb = (Tensor::arange(0u32, half_dim as u32, r.device())?.to_dtype(DType::F32)?
+ * -emb)?
+ .exp()?;
+ let emb = r.unsqueeze(1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
+ let emb = Tensor::cat(&[emb.sin()?, emb.cos()?], 1)?;
+ let emb = if self.c_r % 2 == 1 {
+ emb.pad_with_zeros(D::Minus1, 0, 1)?
+ } else {
+ emb
+ };
+ emb.to_dtype(r.dtype())
+ }
+
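+    /// Projects `xs`, maps the conditioning `c` through a leaky-ReLU MLP, runs the
+    /// blocks, and returns `(xs - a) / (|b - 1| + 1e-5)` where `a` and `b` are the two
+    /// halves of the output convolution.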
+ pub fn forward(&self, xs: &Tensor, r: &Tensor, c: &Tensor) -> Result<Tensor> {
+ let x_in = xs;
+ let mut xs = xs.apply(&self.projection)?;
+ let c_embed = c
+ .apply(&self.cond_mapper_lin1)?
+ .apply(&|xs: &_| candle_nn::ops::leaky_relu(xs, 0.2))?
+ .apply(&self.cond_mapper_lin2)?;
+ let r_embed = self.gen_r_embedding(r)?;
+ for block in self.blocks.iter() {
+ xs = block.res_block.forward(&xs, None)?;
+ xs = block.ts_block.forward(&xs, &r_embed)?;
+ xs = block.attn_block.forward(&xs, &c_embed)?;
+ }
+ let ab = xs.apply(&self.out_ln)?.apply(&self.out_conv)?.chunk(2, 1)?;
+ (x_in - &ab[0])? / ((&ab[1] - 1.)?.abs()? + 1e-5)
+ }
+}
diff --git a/candle-transformers/src/object_detection.rs b/candle-transformers/src/object_detection.rs
new file mode 100644
index 00000000..ce579316
--- /dev/null
+++ b/candle-transformers/src/object_detection.rs
@@ -0,0 +1,52 @@
+/// A bounding box around an object.
+#[derive(Debug, Clone)]
+pub struct Bbox<D> {
+ pub xmin: f32,
+ pub ymin: f32,
+ pub xmax: f32,
+ pub ymax: f32,
+ pub confidence: f32,
+ pub data: D,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub struct KeyPoint {
+ pub x: f32,
+ pub y: f32,
+ pub mask: f32,
+}
+
+/// Intersection over union of two bounding boxes.
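+/// Coordinates are treated as inclusive pixel indices, hence the `+ 1.` in the
+/// width/height computations.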
+pub fn iou<D>(b1: &Bbox<D>, b2: &Bbox<D>) -> f32 {
+ let b1_area = (b1.xmax - b1.xmin + 1.) * (b1.ymax - b1.ymin + 1.);
+ let b2_area = (b2.xmax - b2.xmin + 1.) * (b2.ymax - b2.ymin + 1.);
+ let i_xmin = b1.xmin.max(b2.xmin);
+ let i_xmax = b1.xmax.min(b2.xmax);
+ let i_ymin = b1.ymin.max(b2.ymin);
+ let i_ymax = b1.ymax.min(b2.ymax);
+ let i_area = (i_xmax - i_xmin + 1.).max(0.) * (i_ymax - i_ymin + 1.).max(0.);
+ i_area / (b1_area + b2_area - i_area)
+}
+
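+/// Greedy per-class non-maximum suppression, performed in place: boxes are sorted by
+/// decreasing confidence and any box whose IoU with an already-kept box exceeds
+/// `threshold` is dropped.
+///
+/// A minimal usage sketch; the module path assumes this file is exposed as
+/// `candle_transformers::object_detection`:
+///
+/// ```no_run
+/// # use candle_transformers::object_detection::{Bbox, non_maximum_suppression};
+/// let b = Bbox { xmin: 0., ymin: 0., xmax: 10., ymax: 10., confidence: 0.9, data: () };
+/// let mut per_class: Vec<Vec<Bbox<()>>> = vec![vec![b.clone(), b]];
+/// non_maximum_suppression(&mut per_class, 0.5);
+/// assert_eq!(per_class[0].len(), 1); // overlapping duplicates are dropped
+/// ```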
+pub fn non_maximum_suppression<D>(bboxes: &mut [Vec<Bbox<D>>], threshold: f32) {
+    // Each class is processed independently of the others.
+ for bboxes_for_class in bboxes.iter_mut() {
+ bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());
+ let mut current_index = 0;
+ for index in 0..bboxes_for_class.len() {
+ let mut drop = false;
+ for prev_index in 0..current_index {
+ let iou = iou(&bboxes_for_class[prev_index], &bboxes_for_class[index]);
+ if iou > threshold {
+ drop = true;
+ break;
+ }
+ }
+ if !drop {
+ bboxes_for_class.swap(current_index, index);
+ current_index += 1;
+ }
+ }
+ bboxes_for_class.truncate(current_index);
+ }
+}