Diffstat (limited to 'candle-transformers/src')
42 files changed, 12392 insertions, 17 deletions
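Editor's note on the headline API change in the first hunk below: `LogitsProcessor::new` now takes a third `top_p` argument and `sample` dispatches between argmax, plain multinomial, and top-p (nucleus) sampling. A minimal usage sketch follows; the seed, temperature, top-p value and toy logits tensor are made up for illustration, while the paths and signatures come from the diff itself:

    use candle::{Device, Result, Tensor};
    use candle_transformers::generation::LogitsProcessor;

    fn main() -> Result<()> {
        // Toy 1-D logits over a 4-token vocabulary; in practice these come from a model forward pass.
        let logits = Tensor::new(&[1.0f32, 2.0, 0.5, 3.0], &Device::Cpu)?;
        // temperature == None (or ~0) falls back to argmax; top_p in (0, 1) enables nucleus sampling.
        let mut processor = LogitsProcessor::new(42, Some(0.8), Some(0.95));
        let next_token = processor.sample(&logits)?;
        println!("sampled token id: {next_token}");
        Ok(())
    }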
diff --git a/candle-transformers/src/generation/mod.rs b/candle-transformers/src/generation/mod.rs
index b1d20168..b1a567c3 100644
--- a/candle-transformers/src/generation/mod.rs
+++ b/candle-transformers/src/generation/mod.rs
@@ -1,35 +1,82 @@
-use candle::{DType, Error, Result, Tensor, D};
+use candle::{DType, Error, Result, Tensor};
 use rand::{distributions::Distribution, SeedableRng};
 
 pub struct LogitsProcessor {
     rng: rand::rngs::StdRng,
     temperature: Option<f64>,
+    top_p: Option<f64>,
 }
 
 impl LogitsProcessor {
-    pub fn new(seed: u64, temperature: Option<f64>) -> Self {
+    pub fn new(seed: u64, temperature: Option<f64>, top_p: Option<f64>) -> Self {
+        let temperature = if temperature.map_or(true, |v| v < 1e-7) {
+            None
+        } else {
+            temperature
+        };
         Self {
             rng: rand::rngs::StdRng::seed_from_u64(seed),
             temperature,
+            top_p,
+        }
+    }
+
+    fn sample_argmax(&mut self, logits: Tensor) -> Result<u32> {
+        let logits_v: Vec<f32> = logits.to_vec1()?;
+        let next_token = logits_v
+            .iter()
+            .enumerate()
+            .max_by(|(_, u), (_, v)| u.total_cmp(v))
+            .map(|(i, _)| i as u32)
+            .unwrap();
+        Ok(next_token)
+    }
+
+    fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {
+        let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
+        let next_token = distr.sample(&mut self.rng) as u32;
+        Ok(next_token)
+    }
+
+    fn sample_topp(&mut self, prs: &mut Vec<f32>, top_p: f32) -> Result<u32> {
+        // top-p sampling (or "nucleus sampling") samples from the smallest set of
+        // tokens that exceed probability top_p. This way we never sample tokens that
+        // have very low probabilities and are less likely to go "off the rails".
+        let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
+
+        // Sort by descending probability.
+        argsort_indices.sort_by(|&i, &j| prs[j].partial_cmp(&prs[i]).unwrap());
+
+        // Clamp smaller probabilities to zero.
+        let mut cumsum = 0.;
+        for index in &argsort_indices {
+            if cumsum >= top_p {
+                prs[*index] = 0.0;
+            } else {
+                cumsum += prs[*index];
+            }
         }
+        // Sample with clamped probabilities.
+        self.sample_multinomial(prs)
     }
 
     pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
         let logits = logits.to_dtype(DType::F32)?;
-        let temperature = self.temperature.unwrap_or(0.);
-        let next_token = if temperature > 0. {
-            let prs = candle_nn::ops::softmax(&(&logits / temperature)?, D::Minus1)?;
-            let prs: Vec<f32> = prs.to_vec1()?;
-            let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
-            distr.sample(&mut self.rng) as u32
-        } else {
-            let logits_v: Vec<f32> = logits.to_vec1()?;
-            logits_v
-                .iter()
-                .enumerate()
-                .max_by(|(_, u), (_, v)| u.total_cmp(v))
-                .map(|(i, _)| i as u32)
-                .unwrap()
+        let next_token = match self.temperature {
+            None => self.sample_argmax(logits)?,
+            Some(temperature) => {
+                let logits = &(&logits / temperature)?;
+                let prs = candle_nn::ops::softmax_last_dim(logits)?;
+                let mut prs: Vec<f32> = prs.to_vec1()?;
+                let top_p = self.top_p.unwrap_or(1.);
+                if top_p <= 0.0 || top_p >= 1.0 {
+                    // simply sample from the predicted probability distribution
+                    self.sample_multinomial(&prs)?
+                } else {
+                    // top-p (nucleus) sampling, clamping the least likely tokens to zero
+                    self.sample_topp(&mut prs, top_p as f32)?
+ } + } }; Ok(next_token) } diff --git a/candle-transformers/src/lib.rs b/candle-transformers/src/lib.rs index a8890dc8..b83e5056 100644 --- a/candle-transformers/src/lib.rs +++ b/candle-transformers/src/lib.rs @@ -1,4 +1,5 @@ pub mod generation; pub mod models; +pub mod object_detection; pub mod pipelines; pub mod utils; diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs new file mode 100644 index 00000000..3f164a3a --- /dev/null +++ b/candle-transformers/src/models/bert.rs @@ -0,0 +1,568 @@ +use candle::{DType, Device, Result, Tensor}; +use candle_nn::{Embedding, Module, VarBuilder}; +use serde::Deserialize; + +pub const DTYPE: DType = DType::F32; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)] +#[serde(rename_all = "lowercase")] +enum HiddenAct { + Gelu, + Relu, +} + +struct HiddenActLayer { + act: HiddenAct, + span: tracing::Span, +} + +impl HiddenActLayer { + fn new(act: HiddenAct) -> Self { + let span = tracing::span!(tracing::Level::TRACE, "hidden-act"); + Self { act, span } + } + + fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> { + let _enter = self.span.enter(); + match self.act { + // TODO: The all-MiniLM-L6-v2 model uses "gelu" whereas this is "gelu_new", this explains some + // small numerical difference. + // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213 + HiddenAct::Gelu => xs.gelu(), + HiddenAct::Relu => xs.relu(), + } + } +} + +#[derive(Debug)] +pub struct Linear { + weight: Tensor, + bias: Option<Tensor>, + span: tracing::Span, +} + +impl Linear { + pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self { + let span = tracing::span!(tracing::Level::TRACE, "linear"); + Self { weight, bias, span } + } + + pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> { + let _enter = self.span.enter(); + let w = match x.dims() { + &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?, + _ => self.weight.t()?, + }; + let x = x.matmul(&w)?; + match &self.bias { + None => Ok(x), + Some(bias) => x.broadcast_add(bias), + } + } +} + +#[derive(Debug)] +pub struct LayerNorm { + weight: Tensor, + bias: Tensor, + eps: f64, + span: tracing::Span, +} + +impl LayerNorm { + pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self { + let span = tracing::span!(tracing::Level::TRACE, "layer-norm"); + Self { + weight, + bias, + eps, + span, + } + } + + pub fn forward(&self, x: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let x_dtype = x.dtype(); + let internal_dtype = match x_dtype { + DType::F16 | DType::BF16 => DType::F32, + d => d, + }; + let (_bsize, _seq_len, hidden_size) = x.dims3()?; + let x = x.to_dtype(internal_dtype)?; + let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?; + let x = x.broadcast_sub(&mean_x)?; + let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?; + let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?; + let x = x_normed + .to_dtype(x_dtype)? + .broadcast_mul(&self.weight)? 
+ .broadcast_add(&self.bias)?; + Ok(x) + } +} +#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)] +#[serde(rename_all = "lowercase")] +enum PositionEmbeddingType { + #[default] + Absolute, +} + +// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1 +#[derive(Debug, Clone, PartialEq, Deserialize)] +pub struct Config { + vocab_size: usize, + hidden_size: usize, + num_hidden_layers: usize, + num_attention_heads: usize, + intermediate_size: usize, + hidden_act: HiddenAct, + hidden_dropout_prob: f64, + max_position_embeddings: usize, + type_vocab_size: usize, + initializer_range: f64, + layer_norm_eps: f64, + pad_token_id: usize, + #[serde(default)] + position_embedding_type: PositionEmbeddingType, + #[serde(default)] + use_cache: bool, + classifier_dropout: Option<f64>, + model_type: Option<String>, +} + +impl Default for Config { + fn default() -> Self { + Self { + vocab_size: 30522, + hidden_size: 768, + num_hidden_layers: 12, + num_attention_heads: 12, + intermediate_size: 3072, + hidden_act: HiddenAct::Gelu, + hidden_dropout_prob: 0.1, + max_position_embeddings: 512, + type_vocab_size: 2, + initializer_range: 0.02, + layer_norm_eps: 1e-12, + pad_token_id: 0, + position_embedding_type: PositionEmbeddingType::Absolute, + use_cache: true, + classifier_dropout: None, + model_type: Some("bert".to_string()), + } + } +} + +impl Config { + fn _all_mini_lm_l6_v2() -> Self { + // https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/blob/main/config.json + Self { + vocab_size: 30522, + hidden_size: 384, + num_hidden_layers: 6, + num_attention_heads: 12, + intermediate_size: 1536, + hidden_act: HiddenAct::Gelu, + hidden_dropout_prob: 0.1, + max_position_embeddings: 512, + type_vocab_size: 2, + initializer_range: 0.02, + layer_norm_eps: 1e-12, + pad_token_id: 0, + position_embedding_type: PositionEmbeddingType::Absolute, + use_cache: true, + classifier_dropout: None, + model_type: Some("bert".to_string()), + } + } +} + +fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> { + let embeddings = vb.get((vocab_size, hidden_size), "weight")?; + Ok(Embedding::new(embeddings, hidden_size)) +} + +fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> { + let weight = vb.get((size2, size1), "weight")?; + let bias = vb.get(size2, "bias")?; + Ok(Linear::new(weight, Some(bias))) +} + +struct Dropout { + #[allow(dead_code)] + pr: f64, +} + +impl Dropout { + fn new(pr: f64) -> Self { + Self { pr } + } + + fn forward(&self, x: &Tensor) -> Result<Tensor> { + // TODO + Ok(x.clone()) + } +} + +fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> { + let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) { + (Ok(weight), Ok(bias)) => (weight, bias), + (Err(err), _) | (_, Err(err)) => { + if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) { + (weight, bias) + } else { + return Err(err); + } + } + }; + Ok(LayerNorm::new(weight, bias, eps)) +} + +// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180 +struct BertEmbeddings { + word_embeddings: Embedding, + position_embeddings: Option<Embedding>, + token_type_embeddings: Embedding, + layer_norm: LayerNorm, + dropout: Dropout, + span: tracing::Span, +} + +impl BertEmbeddings { + fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let 
word_embeddings = embedding( + config.vocab_size, + config.hidden_size, + vb.pp("word_embeddings"), + )?; + let position_embeddings = embedding( + config.max_position_embeddings, + config.hidden_size, + vb.pp("position_embeddings"), + )?; + let token_type_embeddings = embedding( + config.type_vocab_size, + config.hidden_size, + vb.pp("token_type_embeddings"), + )?; + let layer_norm = layer_norm( + config.hidden_size, + config.layer_norm_eps, + vb.pp("LayerNorm"), + )?; + Ok(Self { + word_embeddings, + position_embeddings: Some(position_embeddings), + token_type_embeddings, + layer_norm, + dropout: Dropout::new(config.hidden_dropout_prob), + span: tracing::span!(tracing::Level::TRACE, "embeddings"), + }) + } + + fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let (_bsize, seq_len) = input_ids.dims2()?; + let input_embeddings = self.word_embeddings.forward(input_ids)?; + let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?; + let mut embeddings = (&input_embeddings + token_type_embeddings)?; + if let Some(position_embeddings) = &self.position_embeddings { + // TODO: Proper absolute positions? + let position_ids = (0..seq_len as u32).collect::<Vec<_>>(); + let position_ids = Tensor::new(&position_ids[..], input_ids.device())?; + embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)? + } + let embeddings = self.layer_norm.forward(&embeddings)?; + let embeddings = self.dropout.forward(&embeddings)?; + Ok(embeddings) + } +} + +struct BertSelfAttention { + query: Linear, + key: Linear, + value: Linear, + dropout: Dropout, + num_attention_heads: usize, + attention_head_size: usize, + span: tracing::Span, + span_softmax: tracing::Span, +} + +impl BertSelfAttention { + fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let attention_head_size = config.hidden_size / config.num_attention_heads; + let all_head_size = config.num_attention_heads * attention_head_size; + let dropout = Dropout::new(config.hidden_dropout_prob); + let hidden_size = config.hidden_size; + let query = linear(hidden_size, all_head_size, vb.pp("query"))?; + let value = linear(hidden_size, all_head_size, vb.pp("value"))?; + let key = linear(hidden_size, all_head_size, vb.pp("key"))?; + Ok(Self { + query, + key, + value, + dropout, + num_attention_heads: config.num_attention_heads, + attention_head_size, + span: tracing::span!(tracing::Level::TRACE, "self-attn"), + span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"), + }) + } + + fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> { + let mut new_x_shape = xs.dims().to_vec(); + new_x_shape.pop(); + new_x_shape.push(self.num_attention_heads); + new_x_shape.push(self.attention_head_size); + let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?; + xs.contiguous() + } + + fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let query_layer = self.query.forward(hidden_states)?; + let key_layer = self.key.forward(hidden_states)?; + let value_layer = self.value.forward(hidden_states)?; + + let query_layer = self.transpose_for_scores(&query_layer)?; + let key_layer = self.transpose_for_scores(&key_layer)?; + let value_layer = self.transpose_for_scores(&value_layer)?; + + let attention_scores = query_layer.matmul(&key_layer.t()?)?; + let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?; + let attention_probs = { + let _enter_sm = 
self.span_softmax.enter(); + candle_nn::ops::softmax(&attention_scores, candle::D::Minus1)? + }; + let attention_probs = self.dropout.forward(&attention_probs)?; + + let context_layer = attention_probs.matmul(&value_layer)?; + let context_layer = context_layer.transpose(1, 2)?.contiguous()?; + let context_layer = context_layer.flatten_from(candle::D::Minus2)?; + Ok(context_layer) + } +} + +struct BertSelfOutput { + dense: Linear, + layer_norm: LayerNorm, + dropout: Dropout, + span: tracing::Span, +} + +impl BertSelfOutput { + fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?; + let layer_norm = layer_norm( + config.hidden_size, + config.layer_norm_eps, + vb.pp("LayerNorm"), + )?; + let dropout = Dropout::new(config.hidden_dropout_prob); + Ok(Self { + dense, + layer_norm, + dropout, + span: tracing::span!(tracing::Level::TRACE, "self-out"), + }) + } + + fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let hidden_states = self.dense.forward(hidden_states)?; + let hidden_states = self.dropout.forward(&hidden_states)?; + self.layer_norm.forward(&(hidden_states + input_tensor)?) + } +} + +// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L392 +struct BertAttention { + self_attention: BertSelfAttention, + self_output: BertSelfOutput, + span: tracing::Span, +} + +impl BertAttention { + fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let self_attention = BertSelfAttention::load(vb.pp("self"), config)?; + let self_output = BertSelfOutput::load(vb.pp("output"), config)?; + Ok(Self { + self_attention, + self_output, + span: tracing::span!(tracing::Level::TRACE, "attn"), + }) + } + + fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let self_outputs = self.self_attention.forward(hidden_states)?; + let attention_output = self.self_output.forward(&self_outputs, hidden_states)?; + Ok(attention_output) + } +} + +// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L441 +struct BertIntermediate { + dense: Linear, + intermediate_act: HiddenActLayer, + span: tracing::Span, +} + +impl BertIntermediate { + fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let dense = linear(config.hidden_size, config.intermediate_size, vb.pp("dense"))?; + Ok(Self { + dense, + intermediate_act: HiddenActLayer::new(config.hidden_act), + span: tracing::span!(tracing::Level::TRACE, "inter"), + }) + } + + fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let hidden_states = self.dense.forward(hidden_states)?; + let ys = self.intermediate_act.forward(&hidden_states)?; + Ok(ys) + } +} + +// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L456 +struct BertOutput { + dense: Linear, + layer_norm: LayerNorm, + dropout: Dropout, + span: tracing::Span, +} + +impl BertOutput { + fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let dense = linear(config.intermediate_size, config.hidden_size, vb.pp("dense"))?; + let layer_norm = layer_norm( + config.hidden_size, + config.layer_norm_eps, + vb.pp("LayerNorm"), + )?; + let dropout = Dropout::new(config.hidden_dropout_prob); + Ok(Self { + dense, + 
layer_norm, + dropout, + span: tracing::span!(tracing::Level::TRACE, "out"), + }) + } + + fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let hidden_states = self.dense.forward(hidden_states)?; + let hidden_states = self.dropout.forward(&hidden_states)?; + self.layer_norm.forward(&(hidden_states + input_tensor)?) + } +} + +// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L470 +struct BertLayer { + attention: BertAttention, + intermediate: BertIntermediate, + output: BertOutput, + span: tracing::Span, +} + +impl BertLayer { + fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let attention = BertAttention::load(vb.pp("attention"), config)?; + let intermediate = BertIntermediate::load(vb.pp("intermediate"), config)?; + let output = BertOutput::load(vb.pp("output"), config)?; + Ok(Self { + attention, + intermediate, + output, + span: tracing::span!(tracing::Level::TRACE, "layer"), + }) + } + + fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let attention_output = self.attention.forward(hidden_states)?; + // TODO: Support cross-attention? + // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523 + // TODO: Support something similar to `apply_chunking_to_forward`? + let intermediate_output = self.intermediate.forward(&attention_output)?; + let layer_output = self + .output + .forward(&intermediate_output, &attention_output)?; + Ok(layer_output) + } +} + +// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556 +struct BertEncoder { + layers: Vec<BertLayer>, + span: tracing::Span, +} + +impl BertEncoder { + fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let layers = (0..config.num_hidden_layers) + .map(|index| BertLayer::load(vb.pp(&format!("layer.{index}")), config)) + .collect::<Result<Vec<_>>>()?; + let span = tracing::span!(tracing::Level::TRACE, "encoder"); + Ok(BertEncoder { layers, span }) + } + + fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let mut hidden_states = hidden_states.clone(); + // Use a loop rather than a fold as it's easier to modify when adding debug/... + for layer in self.layers.iter() { + hidden_states = layer.forward(&hidden_states)? 
+ } + Ok(hidden_states) + } +} + +// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874 +pub struct BertModel { + embeddings: BertEmbeddings, + encoder: BertEncoder, + pub device: Device, + span: tracing::Span, +} + +impl BertModel { + pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> { + let (embeddings, encoder) = match ( + BertEmbeddings::load(vb.pp("embeddings"), config), + BertEncoder::load(vb.pp("encoder"), config), + ) { + (Ok(embeddings), Ok(encoder)) => (embeddings, encoder), + (Err(err), _) | (_, Err(err)) => { + if let Some(model_type) = &config.model_type { + if let (Ok(embeddings), Ok(encoder)) = ( + BertEmbeddings::load(vb.pp(&format!("{model_type}.embeddings")), config), + BertEncoder::load(vb.pp(&format!("{model_type}.encoder")), config), + ) { + (embeddings, encoder) + } else { + return Err(err); + } + } else { + return Err(err); + } + } + }; + Ok(Self { + embeddings, + encoder, + device: vb.device().clone(), + span: tracing::span!(tracing::Level::TRACE, "model"), + }) + } + + pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?; + let sequence_output = self.encoder.forward(&embedding_output)?; + Ok(sequence_output) + } +} diff --git a/candle-transformers/src/models/bigcode.rs b/candle-transformers/src/models/bigcode.rs new file mode 100644 index 00000000..1e63956b --- /dev/null +++ b/candle-transformers/src/models/bigcode.rs @@ -0,0 +1,359 @@ +use candle::{DType, Device, IndexOp, Result, Tensor, D}; +use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder}; + +fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> { + let weight = vb.get((size2, size1), "weight")?; + let bias = if bias { + Some(vb.get(size2, "bias")?) 
+ } else { + None + }; + Ok(Linear::new(weight, bias)) +} + +fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> { + let embeddings = vb.get((vocab_size, hidden_size), "weight")?; + Ok(Embedding::new(embeddings, hidden_size)) +} + +fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> { + let weight = vb.get(size, "weight")?; + let bias = vb.get(size, "bias")?; + Ok(LayerNorm::new(weight, bias, eps)) +} + +fn make_causal_mask(t: usize, device: &Device) -> Result<Tensor> { + let mask: Vec<_> = (0..t) + .flat_map(|i| (0..t).map(move |j| u8::from(j <= i))) + .collect(); + let mask = Tensor::from_slice(&mask, (t, t), device)?; + Ok(mask) +} + +#[derive(Debug)] +pub struct Config { + pub vocab_size: usize, + // max_position_embeddings aka n_positions + pub max_position_embeddings: usize, + // num_hidden_layers aka n_layer + pub num_hidden_layers: usize, + // hidden_size aka n_embd + pub hidden_size: usize, + pub layer_norm_epsilon: f64, + pub n_inner: Option<usize>, + // num_attention_heads aka n_head + pub num_attention_heads: usize, + pub multi_query: bool, + pub use_cache: bool, +} + +impl Config { + #[allow(dead_code)] + pub fn starcoder_1b() -> Self { + Self { + vocab_size: 49152, + max_position_embeddings: 8192, + num_hidden_layers: 24, + hidden_size: 2048, + layer_norm_epsilon: 1e-5, + n_inner: Some(8192), + num_attention_heads: 16, + multi_query: true, + use_cache: true, + } + } + + #[allow(dead_code)] + pub fn starcoder_3b() -> Self { + Self { + vocab_size: 49152, + max_position_embeddings: 8192, + num_hidden_layers: 36, + hidden_size: 2816, + layer_norm_epsilon: 1e-5, + n_inner: Some(11264), + num_attention_heads: 22, + multi_query: true, + use_cache: true, + } + } + + #[allow(dead_code)] + pub fn starcoder_7b() -> Self { + Self { + vocab_size: 49152, + max_position_embeddings: 8192, + num_hidden_layers: 42, + hidden_size: 4096, + layer_norm_epsilon: 1e-5, + n_inner: Some(16384), + num_attention_heads: 32, + multi_query: true, + use_cache: true, + } + } + + #[allow(dead_code)] + pub fn starcoder() -> Self { + Self { + vocab_size: 49152, + max_position_embeddings: 8192, + num_hidden_layers: 40, + hidden_size: 6144, + layer_norm_epsilon: 1e-5, + n_inner: Some(24576), + num_attention_heads: 48, + multi_query: true, + use_cache: true, + } + } +} + +struct Attention { + c_attn: Linear, + c_proj: Linear, + kv_cache: Option<Tensor>, + use_cache: bool, + embed_dim: usize, + kv_dim: usize, + num_heads: usize, + head_dim: usize, + multi_query: bool, +} + +impl Attention { + pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let hidden_size = cfg.hidden_size; + let head_dim = hidden_size / cfg.num_attention_heads; + let kv_heads = if cfg.multi_query { + 1 + } else { + cfg.num_attention_heads + }; + let kv_dim = kv_heads * head_dim; + let c_attn = linear(hidden_size, hidden_size + 2 * kv_dim, true, vb.pp("c_attn"))?; + let c_proj = linear(hidden_size, hidden_size, true, vb.pp("c_proj"))?; + Ok(Self { + c_proj, + c_attn, + embed_dim: hidden_size, + kv_cache: None, + use_cache: cfg.use_cache, + kv_dim, + head_dim, + num_heads: cfg.num_attention_heads, + multi_query: cfg.multi_query, + }) + } + + fn attn( + &self, + query: &Tensor, + key: &Tensor, + value: &Tensor, + attention_mask: &Tensor, + ) -> Result<Tensor> { + if query.dtype() != DType::F32 { + // If we start supporting f16 models, we may need the upcasting scaling bits. 
+ // https://github.com/huggingface/transformers/blob/a0042379269bea9182c1f87e6b2eee4ba4c8cce8/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L133 + candle::bail!("upcasting is not supported {:?}", query.dtype()) + } + let scale_factor = 1f64 / (self.head_dim as f64).sqrt(); + let initial_query_shape = query.shape(); + let key_len = key.dim(D::Minus1)?; + let (query, key, attn_shape, attn_view) = if self.multi_query { + let (b_sz, query_len, _) = query.dims3()?; + let query = query.reshape((b_sz, query_len * self.num_heads, self.head_dim))?; + let attn_shape = (b_sz, query_len, self.num_heads, key_len); + let attn_view = (b_sz, query_len * self.num_heads, key_len); + (query, key.clone(), attn_shape, attn_view) + } else { + let (b_sz, _num_heads, query_len, _head_dim) = query.dims4()?; + let query = query.reshape((b_sz, query_len * self.num_heads, self.head_dim))?; + let key = key.reshape((b_sz * self.num_heads, self.head_dim, key_len))?; + let attn_shape = (b_sz, self.num_heads, query_len, key_len); + let attn_view = (b_sz * self.num_heads, query_len, key_len); + (query, key, attn_shape, attn_view) + }; + + let attn_weights = + (query.matmul(&key.contiguous()?)? * scale_factor)?.reshape(attn_shape)?; + let attention_mask = attention_mask.broadcast_as(attn_shape)?; + let mask_value = + Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?; + let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?; + let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?; + let value = value.contiguous()?; + let attn_output = if self.multi_query { + attn_weights + .reshape(attn_view)? + .matmul(&value)? + .reshape(initial_query_shape)? + } else { + attn_weights.matmul(&value)? + }; + Ok(attn_output) + } + + fn forward(&mut self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> { + let qkv = self.c_attn.forward(hidden_states)?; + let (query, key_value) = if self.multi_query { + let query = qkv.i((.., .., ..self.embed_dim))?; + let key_value = qkv.i((.., .., self.embed_dim..self.embed_dim + 2 * self.kv_dim))?; + (query, key_value) + } else { + let mut dims = qkv.dims().to_vec(); + dims.pop(); + dims.push(self.embed_dim); + dims.push(self.head_dim * 3); + let qkv = qkv.reshape(dims)?.transpose(1, 2)?; + let query = qkv.i((.., .., .., ..self.head_dim))?; + let key_value = qkv.i((.., .., .., self.head_dim..3 * self.head_dim))?; + (query, key_value) + }; + let mut key_value = key_value; + if self.use_cache { + if let Some(kv_cache) = &self.kv_cache { + // TODO: we could trim the tensors to MAX_SEQ_LEN so that this would work for + // arbitrarily large sizes. + key_value = Tensor::cat(&[kv_cache, &key_value], D::Minus2)?.contiguous()?; + } + self.kv_cache = Some(key_value.clone()) + } + + let key = key_value.narrow(D::Minus1, 0, self.head_dim)?; + let value = key_value.narrow(D::Minus1, self.head_dim, self.head_dim)?; + let attn_output = self.attn(&query, &key.t()?, &value, attention_mask)?; + let attn_output = if self.multi_query { + attn_output + } else { + attn_output + .transpose(1, 2)? + .reshape(hidden_states.shape())? 
+ }; + let attn_output = self.c_proj.forward(&attn_output)?; + Ok(attn_output) + } +} + +struct Mlp { + c_fc: Linear, + c_proj: Linear, +} + +impl Mlp { + fn load(inner_dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> { + let c_fc = linear(cfg.hidden_size, inner_dim, true, vb.pp("c_fc"))?; + let c_proj = linear(inner_dim, cfg.hidden_size, true, vb.pp("c_proj"))?; + Ok(Self { c_fc, c_proj }) + } + + fn forward(&mut self, hidden_states: &Tensor) -> Result<Tensor> { + let hidden_states = self.c_fc.forward(hidden_states)?.gelu()?; + let hidden_states = self.c_proj.forward(&hidden_states)?; + Ok(hidden_states) + } +} + +// TODO: Add cross-attention? +struct Block { + ln_1: LayerNorm, + attn: Attention, + ln_2: LayerNorm, + mlp: Mlp, +} + +impl Block { + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let hidden_size = cfg.hidden_size; + let inner_dim = cfg.n_inner.unwrap_or(4 * hidden_size); + let ln_1 = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb.pp("ln_1"))?; + let attn = Attention::load(vb.pp("attn"), cfg)?; + let ln_2 = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb.pp("ln_2"))?; + let mlp = Mlp::load(inner_dim, vb.pp("mlp"), cfg)?; + Ok(Self { + ln_1, + attn, + ln_2, + mlp, + }) + } + + fn forward(&mut self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> { + let residual = hidden_states; + let hidden_states = self.ln_1.forward(hidden_states)?; + let attn_outputs = self.attn.forward(&hidden_states, attention_mask)?; + let hidden_states = (&attn_outputs + residual)?; + let residual = &hidden_states; + let hidden_states = self.ln_2.forward(&hidden_states)?; + let hidden_states = self.mlp.forward(&hidden_states)?; + let hidden_states = (&hidden_states + residual)?; + Ok(hidden_states) + } +} + +pub struct GPTBigCode { + wte: Embedding, + wpe: Embedding, + blocks: Vec<Block>, + ln_f: LayerNorm, + lm_head: Linear, + bias: Tensor, + config: Config, +} + +impl GPTBigCode { + pub fn config(&self) -> &Config { + &self.config + } + + pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> { + let hidden_size = cfg.hidden_size; + let vb_t = vb.pp("transformer"); + let wte = embedding(cfg.vocab_size, hidden_size, vb_t.pp("wte"))?; + let wpe = embedding(cfg.max_position_embeddings, hidden_size, vb_t.pp("wpe"))?; + let blocks = (0..cfg.num_hidden_layers) + .map(|i| Block::load(vb_t.pp(&format!("h.{i}")), &cfg)) + .collect::<Result<Vec<_>>>()?; + let ln_f = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb_t.pp("ln_f"))?; + let lm_head = linear(hidden_size, cfg.vocab_size, false, vb_t.pp("wte"))?; + let bias = make_causal_mask(cfg.max_position_embeddings, vb.device())?; + Ok(Self { + wte, + wpe, + blocks, + lm_head, + ln_f, + bias, + config: cfg, + }) + } + + pub fn forward(&mut self, input_ids: &Tensor, past_len: usize) -> Result<Tensor> { + let dev = input_ids.device(); + let (b_sz, seq_len) = input_ids.dims2()?; + + let key_len = past_len + seq_len; + let attention_mask = self.bias.i((past_len..key_len, ..key_len))?.unsqueeze(0)?; + // MQA models: (batch_size, query_length, n_heads, key_length) + // MHA models: (batch_size, n_heads, query_length, key_length) + let seq_len_dim = if self.config.multi_query { 2 } else { 1 }; + let attention_mask = attention_mask.unsqueeze(seq_len_dim)?; + + let position_ids = Tensor::arange(past_len as u32, (past_len + seq_len) as u32, dev)?; + let position_ids = position_ids.unsqueeze(0)?.broadcast_as((b_sz, seq_len))?; + let input_embeds = self.wte.forward(input_ids)?; + let position_embeds = 
self.wpe.forward(&position_ids)?; + + let mut hidden_states = (&input_embeds + &position_embeds)?; + for block in self.blocks.iter_mut() { + hidden_states = block.forward(&hidden_states, &attention_mask)?; + } + let hidden_states = self.ln_f.forward(&hidden_states)?; + let hidden_states = hidden_states + .reshape((b_sz, seq_len, self.config.hidden_size))? + .narrow(1, seq_len - 1, 1)?; + let logits = self.lm_head.forward(&hidden_states)?.squeeze(1)?; + Ok(logits) + } +} diff --git a/candle-transformers/src/models/dinov2.rs b/candle-transformers/src/models/dinov2.rs new file mode 100644 index 00000000..0edc8494 --- /dev/null +++ b/candle-transformers/src/models/dinov2.rs @@ -0,0 +1,279 @@ +use candle::{IndexOp, Result, Tensor, D}; +use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; + +const IMG_SIZE: usize = 518; +const PATCH_SIZE: usize = 14; +const NUM_CLASSES: usize = 1000; + +fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> { + if bias { + candle_nn::linear(in_dim, out_dim, vb) + } else { + candle_nn::linear_no_bias(in_dim, out_dim, vb) + } +} + +#[derive(Debug)] +struct Attention { + qkv: Linear, + proj: Linear, + num_heads: usize, + scale: f64, +} + +impl Attention { + fn new( + vb: VarBuilder, + dim: usize, + num_heads: usize, + qkv_bias: bool, + proj_bias: bool, + ) -> Result<Self> { + let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?; + let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?; + let scale = 1. / ((dim / num_heads) as f64).sqrt(); + Ok(Self { + qkv, + proj, + num_heads, + scale, + }) + } +} + +impl Module for Attention { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let (b, n, c) = xs.dims3()?; + let qkv = self + .qkv + .forward(xs)? + .reshape((b, n, 3, self.num_heads, c / self.num_heads))? + .transpose(1, 2)? // 02134 + .transpose(0, 1)? // 20134 + .transpose(2, 3)?; // 20314 + let q = (qkv.i(0)? 
* self.scale)?; + let k = qkv.i(1)?; + let v = qkv.i(2)?; + let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?; + let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?; + self.proj.forward(&attn) + } +} + +#[derive(Debug)] +struct LayerScale { + gamma: Tensor, +} + +impl LayerScale { + fn new(vb: VarBuilder, dim: usize) -> Result<Self> { + let gamma = vb.get(dim, "gamma")?; + Ok(Self { gamma }) + } +} + +impl Module for LayerScale { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + xs.broadcast_mul(&self.gamma) + } +} + +#[derive(Debug)] +struct Mlp { + fc1: Linear, + fc2: Linear, +} + +impl Mlp { + fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> { + let out_features = in_features; + let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?; + let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?; + Ok(Self { fc1, fc2 }) + } +} + +impl Module for Mlp { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let xs = self.fc1.forward(xs)?.gelu()?; + self.fc2.forward(&xs) + } +} + +#[derive(Debug)] +struct Block { + norm1: LayerNorm, + attn: Attention, + ls1: LayerScale, + norm2: LayerNorm, + mlp: Mlp, + ls2: LayerScale, +} + +impl Block { + fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> { + let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?; + let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?; + let ls1 = LayerScale::new(vb.pp("ls1"), dim)?; + let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?; + let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?; + let ls2 = LayerScale::new(vb.pp("ls2"), dim)?; + Ok(Self { + norm1, + attn, + ls1, + norm2, + mlp, + ls2, + }) + } +} + +impl Module for Block { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let residual = xs; + let xs = self + .ls1 + .forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?; + let xs = (xs + residual)?; + let residual = &xs; + let xs = self + .ls2 + .forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?; + xs + residual + } +} + +#[derive(Debug)] +struct PatchEmbed { + proj: candle_nn::Conv2d, + patch_size: (usize, usize), + num_patches: usize, +} + +impl PatchEmbed { + fn new( + vb: VarBuilder, + img_size: usize, + patch_size: usize, + in_chans: usize, + embed_dim: usize, + ) -> Result<Self> { + let config = candle_nn::Conv2dConfig { + stride: patch_size, + ..Default::default() + }; + let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?; + let num_patches = (img_size / patch_size) * (img_size / patch_size); + Ok(Self { + proj, + patch_size: (patch_size, patch_size), + num_patches, + }) + } +} + +impl Module for PatchEmbed { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let (_b, _c, h, w) = xs.dims4()?; + let (patch_h, patch_w) = self.patch_size; + if (h % patch_h) != 0 { + candle::bail!("image height {h} is not a multiple of patch height {patch_h}") + } + if (w % patch_w) != 0 { + candle::bail!("image width {w} is not a multiple of patch width {patch_w}") + } + let xs = self.proj.forward(xs)?; + let (b, c, h, w) = xs.dims4()?; + // flatten embeddings. 
+ xs.reshape((b, c, h * w))?.transpose(1, 2) + } +} + +#[derive(Debug)] +pub struct DinoVisionTransformer { + patch_embed: PatchEmbed, + cls_token: Tensor, + pos_embed: Tensor, + blocks: Vec<Block>, + norm: LayerNorm, + head: Linear, +} + +impl DinoVisionTransformer { + pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> { + let patch_embed = + PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?; + let cls_token = vb.get((1, 1, embed_dim), "cls_token")?; + let num_tokens = 1; + let pos_embed = vb.get( + (1, patch_embed.num_patches + num_tokens, embed_dim), + "pos_embed", + )?; + let head = linear(vb.pp("head"), 2 * embed_dim, NUM_CLASSES, true)?; + let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?; + let vb_b = vb.pp("blocks"); + let blocks = (0..depth) + .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads)) + .collect::<Result<Vec<_>>>()?; + Ok(Self { + patch_embed, + cls_token, + pos_embed, + blocks, + norm, + head, + }) + } + + fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> { + let npatch = xs.dim(1)? - 1; + let n = self.pos_embed.dim(1)? - 1; + let sqrt_n = (n as f64).sqrt(); + if npatch == n && w == h { + return Ok(xs.clone()); + } + let class_pos_embed = self.pos_embed.i((.., ..1))?; + let patch_pos_embed = self.pos_embed.i((.., 1..))?; + let dim = xs.dim(D::Minus1)?; + let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1); + let patch_pos_embed = patch_pos_embed + .reshape((1, sqrt_n as usize, sqrt_n as usize, dim))? + .transpose(2, 3)? + .transpose(1, 2)?; + // This uses bicubic interpolation in the original implementation. + let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?; + let el_count = patch_pos_embed.shape().elem_count(); + let patch_pos_embed = + patch_pos_embed + .transpose(1, 2)? + .transpose(2, 3)? + .reshape((1, el_count / dim, dim))?; + Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1) + } + + fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> { + let (_b, _nc, w, h) = xs.dims4()?; + let xs = self.patch_embed.forward(xs)?; + let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?; + &xs + &self.interpolate_pos_encoding(&xs, w, h)? + } +} + +impl Module for DinoVisionTransformer { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let mut xs = self.prepare_tokens_with_mask(xs)?; + for blk in self.blocks.iter() { + xs = blk.forward(&xs)? + } + let xs = self.norm.forward(&xs)?; + let xs_norm_clstoken = xs.i((.., 0))?; + let xs_norm_patchtokens = xs.i((.., 1..))?.mean(1)?; + let xs = Tensor::cat(&[xs_norm_clstoken, xs_norm_patchtokens], D::Minus1)?; + self.head.forward(&xs) + } +} + +pub fn vit_small(vb: VarBuilder) -> Result<DinoVisionTransformer> { + DinoVisionTransformer::new(vb, 12, 384, 6) +} diff --git a/candle-transformers/src/models/efficientnet.rs b/candle-transformers/src/models/efficientnet.rs new file mode 100644 index 00000000..ab51c76d --- /dev/null +++ b/candle-transformers/src/models/efficientnet.rs @@ -0,0 +1,331 @@ +use candle::{Result, Tensor, D}; +use candle_nn as nn; +use nn::{Module, VarBuilder}; + +// Based on the Python version from torchvision. 
+// https://github.com/pytorch/vision/blob/0d75d9e5516f446c9c0ef93bd4ed9fea13992d06/torchvision/models/efficientnet.py#L47 +#[derive(Debug, Clone, Copy)] +pub struct MBConvConfig { + expand_ratio: f64, + kernel: usize, + stride: usize, + input_channels: usize, + out_channels: usize, + num_layers: usize, +} + +fn make_divisible(v: f64, divisor: usize) -> usize { + let min_value = divisor; + let new_v = usize::max( + min_value, + (v + divisor as f64 * 0.5) as usize / divisor * divisor, + ); + if (new_v as f64) < 0.9 * v { + new_v + divisor + } else { + new_v + } +} + +fn bneck_confs(width_mult: f64, depth_mult: f64) -> Vec<MBConvConfig> { + let bneck_conf = |e, k, s, i, o, n| { + let input_channels = make_divisible(i as f64 * width_mult, 8); + let out_channels = make_divisible(o as f64 * width_mult, 8); + let num_layers = (n as f64 * depth_mult).ceil() as usize; + MBConvConfig { + expand_ratio: e, + kernel: k, + stride: s, + input_channels, + out_channels, + num_layers, + } + }; + vec![ + bneck_conf(1., 3, 1, 32, 16, 1), + bneck_conf(6., 3, 2, 16, 24, 2), + bneck_conf(6., 5, 2, 24, 40, 2), + bneck_conf(6., 3, 2, 40, 80, 3), + bneck_conf(6., 5, 1, 80, 112, 3), + bneck_conf(6., 5, 2, 112, 192, 4), + bneck_conf(6., 3, 1, 192, 320, 1), + ] +} + +impl MBConvConfig { + pub fn b0() -> Vec<Self> { + bneck_confs(1.0, 1.0) + } + pub fn b1() -> Vec<Self> { + bneck_confs(1.0, 1.1) + } + pub fn b2() -> Vec<Self> { + bneck_confs(1.1, 1.2) + } + pub fn b3() -> Vec<Self> { + bneck_confs(1.2, 1.4) + } + pub fn b4() -> Vec<Self> { + bneck_confs(1.4, 1.8) + } + pub fn b5() -> Vec<Self> { + bneck_confs(1.6, 2.2) + } + pub fn b6() -> Vec<Self> { + bneck_confs(1.8, 2.6) + } + pub fn b7() -> Vec<Self> { + bneck_confs(2.0, 3.1) + } +} + +/// Conv2D with same padding. +#[derive(Debug)] +struct Conv2DSame { + conv2d: nn::Conv2d, + s: usize, + k: usize, +} + +impl Conv2DSame { + fn new( + vb: VarBuilder, + i: usize, + o: usize, + k: usize, + stride: usize, + groups: usize, + bias: bool, + ) -> Result<Self> { + let conv_config = nn::Conv2dConfig { + stride, + groups, + ..Default::default() + }; + let conv2d = if bias { + nn::conv2d(i, o, k, conv_config, vb)? + } else { + nn::conv2d_no_bias(i, o, k, conv_config, vb)? 
+ }; + Ok(Self { + conv2d, + s: stride, + k, + }) + } +} + +impl Module for Conv2DSame { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let s = self.s; + let k = self.k; + let (_, _, ih, iw) = xs.dims4()?; + let oh = (ih + s - 1) / s; + let ow = (iw + s - 1) / s; + let pad_h = usize::max((oh - 1) * s + k - ih, 0); + let pad_w = usize::max((ow - 1) * s + k - iw, 0); + if pad_h > 0 || pad_w > 0 { + let xs = xs.pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?; + let xs = xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?; + self.conv2d.forward(&xs) + } else { + self.conv2d.forward(xs) + } + } +} + +#[derive(Debug)] +struct ConvNormActivation { + conv2d: Conv2DSame, + bn2d: nn::BatchNorm, + activation: bool, +} + +impl ConvNormActivation { + fn new( + vb: VarBuilder, + i: usize, + o: usize, + k: usize, + stride: usize, + groups: usize, + ) -> Result<Self> { + let conv2d = Conv2DSame::new(vb.pp("0"), i, o, k, stride, groups, false)?; + let bn2d = nn::batch_norm(o, 1e-3, vb.pp("1"))?; + Ok(Self { + conv2d, + bn2d, + activation: true, + }) + } + + fn no_activation(self) -> Self { + Self { + activation: false, + ..self + } + } +} + +impl Module for ConvNormActivation { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let xs = self.conv2d.forward(xs)?; + let xs = self.bn2d.forward(&xs)?; + if self.activation { + swish(&xs) + } else { + Ok(xs) + } + } +} + +#[derive(Debug)] +struct SqueezeExcitation { + fc1: Conv2DSame, + fc2: Conv2DSame, +} + +impl SqueezeExcitation { + fn new(vb: VarBuilder, in_channels: usize, squeeze_channels: usize) -> Result<Self> { + let fc1 = Conv2DSame::new(vb.pp("fc1"), in_channels, squeeze_channels, 1, 1, 1, true)?; + let fc2 = Conv2DSame::new(vb.pp("fc2"), squeeze_channels, in_channels, 1, 1, 1, true)?; + Ok(Self { fc1, fc2 }) + } +} + +impl Module for SqueezeExcitation { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let residual = xs; + // equivalent to adaptive_avg_pool2d([1, 1]) + let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?; + let xs = self.fc1.forward(&xs)?; + let xs = swish(&xs)?; + let xs = self.fc2.forward(&xs)?; + let xs = nn::ops::sigmoid(&xs)?; + residual.broadcast_mul(&xs) + } +} + +#[derive(Debug)] +struct MBConv { + expand_cna: Option<ConvNormActivation>, + depthwise_cna: ConvNormActivation, + squeeze_excitation: SqueezeExcitation, + project_cna: ConvNormActivation, + config: MBConvConfig, +} + +impl MBConv { + fn new(vb: VarBuilder, c: MBConvConfig) -> Result<Self> { + let vb = vb.pp("block"); + let exp = make_divisible(c.input_channels as f64 * c.expand_ratio, 8); + let expand_cna = if exp != c.input_channels { + Some(ConvNormActivation::new( + vb.pp("0"), + c.input_channels, + exp, + 1, + 1, + 1, + )?) + } else { + None + }; + let start_index = if expand_cna.is_some() { 1 } else { 0 }; + let depthwise_cna = + ConvNormActivation::new(vb.pp(start_index), exp, exp, c.kernel, c.stride, exp)?; + let squeeze_channels = usize::max(1, c.input_channels / 4); + let squeeze_excitation = + SqueezeExcitation::new(vb.pp(start_index + 1), exp, squeeze_channels)?; + let project_cna = + ConvNormActivation::new(vb.pp(start_index + 2), exp, c.out_channels, 1, 1, 1)? 
+ .no_activation(); + Ok(Self { + expand_cna, + depthwise_cna, + squeeze_excitation, + project_cna, + config: c, + }) + } +} + +impl Module for MBConv { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let use_res_connect = + self.config.stride == 1 && self.config.input_channels == self.config.out_channels; + let ys = match &self.expand_cna { + Some(expand_cna) => expand_cna.forward(xs)?, + None => xs.clone(), + }; + let ys = self.depthwise_cna.forward(&ys)?; + let ys = self.squeeze_excitation.forward(&ys)?; + let ys = self.project_cna.forward(&ys)?; + if use_res_connect { + ys + xs + } else { + Ok(ys) + } + } +} + +fn swish(s: &Tensor) -> Result<Tensor> { + s * nn::ops::sigmoid(s)? +} + +#[derive(Debug)] +pub struct EfficientNet { + init_cna: ConvNormActivation, + blocks: Vec<MBConv>, + final_cna: ConvNormActivation, + classifier: nn::Linear, +} + +impl EfficientNet { + pub fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> { + let f_p = p.pp("features"); + let first_in_c = configs[0].input_channels; + let last_out_c = configs.last().unwrap().out_channels; + let final_out_c = 4 * last_out_c; + let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?; + let nconfigs = configs.len(); + let mut blocks = vec![]; + for (index, cnf) in configs.into_iter().enumerate() { + let f_p = f_p.pp(index + 1); + for r_index in 0..cnf.num_layers { + let cnf = if r_index == 0 { + cnf + } else { + MBConvConfig { + input_channels: cnf.out_channels, + stride: 1, + ..cnf + } + }; + blocks.push(MBConv::new(f_p.pp(r_index), cnf)?) + } + } + let final_cna = + ConvNormActivation::new(f_p.pp(nconfigs + 1), last_out_c, final_out_c, 1, 1, 1)?; + let classifier = nn::linear(final_out_c, nclasses, p.pp("classifier.1"))?; + Ok(Self { + init_cna, + blocks, + final_cna, + classifier, + }) + } +} + +impl Module for EfficientNet { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let mut xs = self.init_cna.forward(xs)?; + for block in self.blocks.iter() { + xs = block.forward(&xs)? + } + let xs = self.final_cna.forward(&xs)?; + // Equivalent to adaptive_avg_pool2d([1, 1]) -> squeeze(-1) -> squeeze(-1) + let xs = xs.mean(D::Minus1)?.mean(D::Minus1)?; + self.classifier.forward(&xs) + } +} diff --git a/candle-transformers/src/models/falcon.rs b/candle-transformers/src/models/falcon.rs new file mode 100644 index 00000000..6ede136a --- /dev/null +++ b/candle-transformers/src/models/falcon.rs @@ -0,0 +1,484 @@ +use candle::{DType, Device, Result, Tensor, D}; +use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder}; + +const MAX_SEQ_LEN: usize = 5000; + +fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> { + let weight = vb.get((size2, size1), "weight")?; + let bias = if bias { + Some(vb.get(size2, "bias")?) 
+ } else { + None + }; + Ok(Linear::new(weight, bias)) +} + +fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> { + let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) { + (Ok(weight), Ok(bias)) => (weight, bias), + (Err(err), _) | (_, Err(err)) => { + if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) { + (weight, bias) + } else { + return Err(err); + } + } + }; + Ok(LayerNorm::new(weight, bias, eps)) +} + +fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> { + let embeddings = vb.get((vocab_size, hidden_size), "weight")?; + Ok(Embedding::new(embeddings, hidden_size)) +} + +// https://raw.githubusercontent.com/huggingface/transformers/030c863aaa0165e98352b61697430bf69bf33755/src/transformers/models/falcon/configuration_falcon.py +#[derive(Debug)] +pub struct Config { + pub vocab_size: usize, + pub hidden_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub layer_norm_epsilon: f64, + pub initializer_range: f64, + pub use_cache: bool, + pub bos_token_id: u32, + pub eos_token_id: u32, + pub hidden_dropout: f64, + pub attention_dropout: f64, + pub n_head_kv: Option<usize>, + pub alibi: bool, + pub new_decoder_architecture: bool, + pub multi_query: bool, + pub parallel_attn: bool, + pub bias: bool, +} + +impl Default for Config { + fn default() -> Self { + Self { + vocab_size: 65024, + hidden_size: 4544, + num_hidden_layers: 32, + num_attention_heads: 71, + layer_norm_epsilon: 1e-5, + initializer_range: 0.02, + use_cache: true, + bos_token_id: 11, + eos_token_id: 11, + hidden_dropout: 0.0, + attention_dropout: 0.0, + n_head_kv: None, + alibi: false, + new_decoder_architecture: false, + multi_query: true, + parallel_attn: true, + bias: false, + } + } +} + +impl Config { + pub fn validate(&self) -> Result<()> { + if self.alibi { + candle::bail!("alibi is not supported"); + } + if self.new_decoder_architecture { + candle::bail!("new_decoder_architecture is not supported"); + } + if self.n_head_kv.is_some() { + candle::bail!("n_head_kv is not supported"); + } + Ok(()) + } + + // https://huggingface.co/tiiuae/falcon-7b/blob/main/config.json + pub fn falcon7b() -> Self { + // This is currently on par with the defaults, the defaults come from the Python default + // arguments for the config initialization whereas the following come from the json config. 
+ Self { + vocab_size: 65024, + hidden_size: 4544, + num_hidden_layers: 32, + num_attention_heads: 71, + layer_norm_epsilon: 1e-5, + initializer_range: 0.02, + use_cache: true, + bos_token_id: 11, + eos_token_id: 11, + hidden_dropout: 0., + attention_dropout: 0., + n_head_kv: None, + alibi: false, + new_decoder_architecture: false, + multi_query: true, + parallel_attn: true, + bias: false, + } + } + + fn head_dim(&self) -> usize { + self.hidden_size / self.num_attention_heads + } + + fn rotary(&self) -> bool { + !self.alibi + } +} + +fn rotate_half(x: &Tensor) -> Result<Tensor> { + let l = x.dim(D::Minus1)?; + let x1 = x.narrow(D::Minus1, 0, l / 2)?; + let x2 = x.narrow(D::Minus1, l / 2, l - l / 2)?; + let x21 = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?; + Ok(x21) +} + +#[derive(Debug)] +struct FalconRotaryEmbedding { + inv_freq: Tensor, + cache: Option<(usize, Tensor, Tensor)>, +} + +impl FalconRotaryEmbedding { + fn load(device: &Device, cfg: &Config) -> Result<Self> { + let head_dim = cfg.head_dim(); + let inv_freq: Vec<_> = (0..head_dim) + .step_by(2) + .map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32)) + .collect(); + Ok(Self { + inv_freq: Tensor::new(inv_freq.as_slice(), device)?, + cache: None, + }) + } + + fn cos_sin( + &mut self, + seq_len: usize, + device: &Device, + dtype: DType, + ) -> Result<(Tensor, Tensor)> { + match &self.cache { + Some((s, cos, sin)) if *s == seq_len => { + return Ok((cos.clone(), sin.clone())); + } + _ => {} + } + let t = Tensor::arange(0, seq_len as u32, device)?.to_dtype(dtype)?; + let inv_freq = self.inv_freq.to_dtype(dtype)?; + let freqs = t.unsqueeze(1)?.matmul(&inv_freq.unsqueeze(0)?)?; + let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?; + let cos = emb.cos()?; + let sin = emb.sin()?; + self.cache = Some((seq_len, cos.clone(), sin.clone())); + Ok((cos, sin)) + } + + fn forward( + &mut self, + query: &Tensor, + key: &Tensor, + past_kv_len: usize, + ) -> Result<(Tensor, Tensor)> { + let (_batch, seq_len, _head_dim) = query.dims3()?; + let (cos, sin) = self.cos_sin(MAX_SEQ_LEN, query.device(), query.dtype())?; + let cos = cos.narrow(0, past_kv_len, seq_len)?; + let sin = sin.narrow(0, past_kv_len, seq_len)?; + let qs = (query.broadcast_mul(&cos)? + &rotate_half(query)?.broadcast_mul(&sin)?)?; + let ks = (key.broadcast_mul(&cos)? 
+ &rotate_half(key)?.broadcast_mul(&sin)?)?; + Ok((qs, ks)) + } +} + +fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> { + let shape = mask.shape(); + let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; + let m = mask.where_cond(&on_true, on_false)?; + Ok(m) +} + +#[derive(Debug)] +struct FalconAttention { + query_key_value: Linear, + dense: Linear, + maybe_rotary: Option<FalconRotaryEmbedding>, + kv_cache: Option<(Tensor, Tensor)>, + inv_norm_factor: f64, + multi_query: bool, + use_cache: bool, + num_heads: usize, + head_dim: usize, + n_head_kv: usize, +} + +impl FalconAttention { + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let maybe_rotary = if cfg.rotary() { + let rotary = FalconRotaryEmbedding::load(vb.device(), cfg)?; + Some(rotary) + } else { + None + }; + let head_dim = cfg.head_dim(); + let hidden_size = cfg.hidden_size; + let qkv_out_dim = if cfg.multi_query { + hidden_size + 2 * head_dim + } else { + 3 * hidden_size + }; + let query_key_value = linear(hidden_size, qkv_out_dim, cfg.bias, vb.pp("query_key_value"))?; + let dense = linear(hidden_size, hidden_size, cfg.bias, vb.pp("dense"))?; + Ok(Self { + query_key_value, + dense, + maybe_rotary, + kv_cache: None, + inv_norm_factor: 1. / (head_dim as f64).sqrt(), + multi_query: cfg.multi_query, + use_cache: cfg.use_cache, + num_heads: cfg.num_attention_heads, + n_head_kv: cfg.n_head_kv.unwrap_or(1), + head_dim, + }) + } + + fn split_heads(&self, fused_qkv: &Tensor) -> Result<(Tensor, Tensor, Tensor)> { + let (b_sz, seq_len, _) = fused_qkv.dims3()?; + if !self.multi_query { + let fused_qkv = fused_qkv.reshape((b_sz, seq_len, self.num_heads, 3, self.head_dim))?; + let q = fused_qkv.narrow(D::Minus2, 0, 1)?.squeeze(D::Minus2)?; + let k = fused_qkv.narrow(D::Minus2, 1, 1)?.squeeze(D::Minus2)?; + let v = fused_qkv.narrow(D::Minus2, 2, 1)?.squeeze(D::Minus2)?; + Ok((q, k, v)) + } else { + let fused_qkv = + fused_qkv.reshape((b_sz, seq_len, self.num_heads + 2, self.head_dim))?; + let d = fused_qkv.dim(D::Minus2)?; + let q = fused_qkv.narrow(D::Minus2, 0, d - 2)?; + let k = fused_qkv.narrow(D::Minus2, d - 2, 1)?; + let v = fused_qkv.narrow(D::Minus2, d - 1, 1)?; + Ok((q, k, v)) + } + } + + fn forward(&mut self, x: &Tensor, mask: &Tensor, past_kv_len: usize) -> Result<Tensor> { + let fused_qkv = self.query_key_value.forward(x)?; + let head_dim = self.head_dim; + let (query, key, value) = self.split_heads(&fused_qkv)?; + let (b_sz, seq_len, _, _) = query.dims4()?; + let query = query + .transpose(1, 2)? + .reshape((b_sz * self.num_heads, seq_len, head_dim))?; + let key = key + .transpose(1, 2)? + .reshape((b_sz * self.n_head_kv, seq_len, head_dim))?; + let value = value + .transpose(1, 2)? + .reshape((b_sz * self.n_head_kv, seq_len, head_dim))?; + let (query, key) = if let Some(r) = &mut self.maybe_rotary { + r.forward(&query, &key, past_kv_len)? + } else { + (query, key) + }; + let (mut key, mut value) = (key, value); + let mask = masked_fill(&mask.to_dtype(DType::F32)?, mask, -1e9)?.to_dtype(query.dtype())?; + if self.use_cache { + if let Some((cache_k, cache_v)) = &self.kv_cache { + // TODO: we could trim the tensors to MAX_SEQ_LEN so that this would work for + // arbitrarily large sizes. 
+ key = Tensor::cat(&[cache_k, &key], 1)?.contiguous()?; + value = Tensor::cat(&[cache_v, &value], 1)?.contiguous()?; + } + self.kv_cache = Some((key.clone(), value.clone())) + } + let query = query.reshape((b_sz * self.num_heads, seq_len, head_dim))?; + let all_len = past_kv_len + seq_len; + let key = key.reshape((b_sz * self.n_head_kv, all_len, head_dim))?; + let value = value.reshape((b_sz * self.n_head_kv, all_len, head_dim))?; + + let (key, value) = if self.n_head_kv == 1 { + ( + key.broadcast_as((b_sz * self.num_heads, all_len, head_dim))?, + value.broadcast_as((b_sz * self.num_heads, all_len, head_dim))?, + ) + } else { + (key, value) + }; + + // Only handle the case where alibi is None here, and non-flash attention. + let attention_scores = (query.matmul(&key.t()?)? * self.inv_norm_factor)?; + let attention_scores = candle_nn::ops::softmax( + &attention_scores + .broadcast_add(&mask.squeeze(1)?)? + .to_dtype(DType::F32)?, + D::Minus1, + )? + .to_dtype(x.dtype())?; + let attn_output = attention_scores + .matmul(&value)? + .reshape((b_sz, self.num_heads, seq_len, head_dim))? + .transpose(1, 2)? + .reshape((b_sz, seq_len, self.num_heads * head_dim))?; + let attn_output = self.dense.forward(&attn_output)?; + Ok(attn_output) + } +} + +#[derive(Debug)] +struct FalconMlp { + dense_h_to_4h: Linear, + dense_4h_to_h: Linear, +} + +impl FalconMlp { + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let h = cfg.hidden_size; + let b = cfg.bias; + let dense_h_to_4h = linear(h, 4 * h, b, vb.pp("dense_h_to_4h"))?; + let dense_4h_to_h = linear(4 * h, h, b, vb.pp("dense_4h_to_h"))?; + Ok(Self { + dense_h_to_4h, + dense_4h_to_h, + }) + } + + fn forward(&self, x: &Tensor) -> Result<Tensor> { + let x = self.dense_h_to_4h.forward(x)?.gelu()?; + let x = self.dense_4h_to_h.forward(&x)?; + Ok(x) + } +} + +#[derive(Debug)] +struct FalconDecoderLayer { + inp_layernorm: LayerNorm, + self_attention: FalconAttention, + post_attention_layernorm: Option<LayerNorm>, + mlp: FalconMlp, + parallel_attn: bool, +} + +impl FalconDecoderLayer { + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let mlp = FalconMlp::load(vb.pp("mlp"), cfg)?; + let inp_layernorm = layer_norm( + cfg.hidden_size, + cfg.layer_norm_epsilon, + vb.pp("input_layernorm"), + )?; + let self_attention = FalconAttention::load(vb.pp("self_attention"), cfg)?; + let post_attention_layernorm = if cfg.parallel_attn { + None + } else { + let ln = layer_norm( + cfg.hidden_size, + cfg.layer_norm_epsilon, + vb.pp("post_attention_layernorm"), + )?; + Some(ln) + }; + Ok(Self { + inp_layernorm, + self_attention, + post_attention_layernorm, + mlp, + parallel_attn: cfg.parallel_attn, + }) + } + + fn forward(&mut self, x: &Tensor, mask: &Tensor, past_kv_len: usize) -> Result<Tensor> { + let residual = x.clone(); + let ln_attn = self.inp_layernorm.forward(x)?; + let attn_output = self.self_attention.forward(&ln_attn, mask, past_kv_len)?; + let (residual, ln_mlp) = match &self.post_attention_layernorm { + None => (residual, ln_attn), + Some(pal) => { + // This should include some dropout. + let residual = (&attn_output + &residual)?; + let ln_mlp = pal.forward(&residual)?; + (residual, ln_mlp) + } + }; + let mlp_output = self.mlp.forward(&ln_mlp)?; + + let mlp_output = if self.parallel_attn { + (mlp_output + attn_output)? 
+ } else { + mlp_output + }; + let output = (mlp_output + residual)?; + Ok(output) + } +} + +#[derive(Debug)] +pub struct Falcon { + word_embeddings: Embedding, + blocks: Vec<FalconDecoderLayer>, + ln_f: LayerNorm, + lm_head: Linear, + config: Config, +} + +fn make_causal_mask(t: usize) -> Result<Tensor> { + let mask: Vec<_> = (0..t) + .flat_map(|i| (0..t).map(move |j| u8::from(j > i))) + .collect(); + let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?; + Ok(mask) +} + +fn prepare_attn_mask(b_sz: usize, seq_len: usize) -> Result<Tensor> { + // let mask = Tensor::ones((b_sz, seq_len), DType::U32, &Device::Cpu)?; + let mask = make_causal_mask(seq_len)?; + let mask = mask.broadcast_as((b_sz, 1, seq_len, seq_len))?; + Ok(mask) +} + +impl Falcon { + pub fn config(&self) -> &Config { + &self.config + } + + pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> { + let word_embeddings = embedding( + cfg.vocab_size, + cfg.hidden_size, + vb.pp("transformer.word_embeddings"), + )?; + let blocks = (0..cfg.num_hidden_layers) + .map(|i| FalconDecoderLayer::load(vb.pp(&format!("transformer.h.{i}")), &cfg)) + .collect::<Result<Vec<_>>>()?; + let ln_f = layer_norm( + cfg.hidden_size, + cfg.layer_norm_epsilon, + vb.pp("transformer.ln_f"), + )?; + let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, vb.pp("lm_head"))?; + Ok(Self { + word_embeddings, + blocks, + ln_f, + lm_head, + config: cfg, + }) + } + + pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> { + let (b_sz, seq_len) = input_ids.dims2()?; + let mut hidden_state = self.word_embeddings.forward(input_ids)?; + let past_kv_len = match &self.blocks[0].self_attention.kv_cache { + Some((k, _)) => k.dim(1)?, + None => 0, + }; + let causal_mask = prepare_attn_mask(b_sz, seq_len)?.to_device(input_ids.device())?; + for block in self.blocks.iter_mut() { + hidden_state = block.forward(&hidden_state, &causal_mask, past_kv_len)?; + } + let hidden_state = self.ln_f.forward(&hidden_state)?; + let hidden_state = hidden_state.narrow(1, seq_len - 1, 1)?; + let logits = self.lm_head.forward(&hidden_state)?.squeeze(1)?; + Ok(logits) + } +} diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs new file mode 100644 index 00000000..eed4df5e --- /dev/null +++ b/candle-transformers/src/models/llama.rs @@ -0,0 +1,446 @@ +use candle::{DType, Device, IndexOp, Result, Tensor, D}; +use candle_nn::{Embedding, Module, VarBuilder}; +use serde::Deserialize; +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; + +pub const MAX_SEQ_LEN: usize = 4096; + +#[derive(Deserialize)] +pub struct LlamaConfig { + pub hidden_size: usize, + pub intermediate_size: usize, + pub vocab_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub num_key_value_heads: Option<usize>, + pub rms_norm_eps: f64, + #[serde(default = "default_rope")] + pub rope_theta: f32, +} + +fn default_rope() -> f32 { + 10_000.0 +} + +impl LlamaConfig { + pub fn into_config(self, use_flash_attn: bool) -> Config { + Config { + hidden_size: self.hidden_size, + intermediate_size: self.intermediate_size, + vocab_size: self.vocab_size, + num_hidden_layers: self.num_hidden_layers, + num_attention_heads: self.num_attention_heads, + num_key_value_heads: self.num_key_value_heads.unwrap_or(self.num_attention_heads), + rms_norm_eps: self.rms_norm_eps, + rope_theta: self.rope_theta, + use_flash_attn, + } + } +} + +pub struct Config { + pub hidden_size: usize, + pub intermediate_size: usize, + pub vocab_size: usize, + pub 
num_hidden_layers: usize, + pub num_attention_heads: usize, + pub num_key_value_heads: usize, + pub use_flash_attn: bool, + pub rms_norm_eps: f64, + pub rope_theta: f32, +} + +impl Config { + pub fn config_7b_v1(use_flash_attn: bool) -> Self { + Self { + hidden_size: 4096, + intermediate_size: 11008, + vocab_size: 32000, + num_hidden_layers: 32, + num_attention_heads: 32, + num_key_value_heads: 32, + use_flash_attn, + rms_norm_eps: 1e-6, + rope_theta: 10_000.0, + } + } + + pub fn config_7b_v2(use_flash_attn: bool) -> Self { + Self { + hidden_size: 4096, + intermediate_size: 11008, + vocab_size: 32000, + num_hidden_layers: 32, + num_attention_heads: 32, + num_key_value_heads: 32, + use_flash_attn, + rms_norm_eps: 1e-5, + rope_theta: 10_000.0, + } + } +} + +// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting +// model. +#[derive(Debug)] +pub struct Linear { + inner: candle_nn::Linear, + span: tracing::Span, +} + +impl Linear { + fn forward(&self, x: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(x) + } +} + +#[derive(Clone)] +pub struct Cache { + masks: Arc<Mutex<HashMap<usize, Tensor>>>, + pub use_kv_cache: bool, + #[allow(clippy::type_complexity)] + kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>, + cos: Tensor, + sin: Tensor, + device: Device, +} + +impl Cache { + pub fn new(use_kv_cache: bool, dtype: DType, config: &Config, device: &Device) -> Result<Self> { + // precompute freqs_cis + let n_elem = config.hidden_size / config.num_attention_heads; + let theta: Vec<_> = (0..n_elem) + .step_by(2) + .map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32)) + .collect(); + let theta = Tensor::new(theta.as_slice(), device)?; + let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)? + .to_dtype(DType::F32)? + .reshape((MAX_SEQ_LEN, 1))? 
+ .matmul(&theta.reshape((1, theta.elem_count()))?)?; + // This is different from the paper, see: + // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112 + let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?; + let cos = idx_theta.cos()?.to_dtype(dtype)?; + let sin = idx_theta.sin()?.to_dtype(dtype)?; + Ok(Self { + masks: Arc::new(Mutex::new(HashMap::new())), + use_kv_cache, + kvs: Arc::new(Mutex::new(vec![None; config.num_hidden_layers])), + device: device.clone(), + cos, + sin, + }) + } + + fn mask(&self, t: usize) -> Result<Tensor> { + let mut masks = self.masks.lock().unwrap(); + if let Some(mask) = masks.get(&t) { + Ok(mask.clone()) + } else { + let mask: Vec<_> = (0..t) + .flat_map(|i| (0..t).map(move |j| u8::from(j > i))) + .collect(); + let mask = Tensor::from_slice(&mask, (t, t), &self.device)?; + masks.insert(t, mask.clone()); + Ok(mask) + } + } +} + +fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> { + let span = tracing::span!(tracing::Level::TRACE, "linear"); + let inner = candle_nn::linear_no_bias(size1, size2, vb)?; + Ok(Linear { inner, span }) +} + +fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> { + let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?; + Ok(Embedding::new(embeddings, cfg.hidden_size)) +} + +struct RmsNorm { + inner: candle_nn::RmsNorm, + span: tracing::Span, +} + +impl RmsNorm { + fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); + let inner = candle_nn::rms_norm(size, eps, vb)?; + Ok(Self { inner, span }) + } + + fn forward(&self, x: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(x) + } +} + +struct CausalSelfAttention { + q_proj: Linear, + k_proj: Linear, + v_proj: Linear, + o_proj: Linear, + num_attention_heads: usize, + num_key_value_heads: usize, + head_dim: usize, + cache: Cache, + use_flash_attn: bool, + span: tracing::Span, + span_rot: tracing::Span, +} + +#[cfg(feature = "flash-attn")] +fn flash_attn( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + causal: bool, +) -> Result<Tensor> { + candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal) +} + +#[cfg(not(feature = "flash-attn"))] +fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> { + unimplemented!("compile with '--features flash-attn'") +} + +impl CausalSelfAttention { + fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> { + let _enter = self.span_rot.enter(); + let (b_sz, _, seq_len, hidden_size) = x.dims4()?; + let cos = self.cache.cos.narrow(0, index_pos, seq_len)?; + let sin = self.cache.sin.narrow(0, index_pos, seq_len)?; + let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?; + let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?; + let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?; + let x2 = x.narrow(D::Minus1, hidden_size / 2, hidden_size / 2)?; + let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?; + let rope = (x.broadcast_mul(&cos)? 
+ rotate_x.broadcast_mul(&sin)?)?; + Ok(rope) + } + + fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> { + let _enter = self.span.enter(); + let (b_sz, seq_len, hidden_size) = x.dims3()?; + let q = self.q_proj.forward(x)?; + let k = self.k_proj.forward(x)?; + let v = self.v_proj.forward(x)?; + + let q = q + .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))? + .transpose(1, 2)?; + let k = k + .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))? + .transpose(1, 2)?; + let mut v = v + .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))? + .transpose(1, 2)?; + + let q = self.apply_rotary_emb(&q, index_pos)?; + let mut k = self.apply_rotary_emb(&k, index_pos)?; + + if self.cache.use_kv_cache { + let mut cache = self.cache.kvs.lock().unwrap(); + if let Some((cache_k, cache_v)) = &cache[block_idx] { + k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?; + v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?; + let k_seq_len = k.dims()[1]; + if k_seq_len > MAX_SEQ_LEN { + k = k + .narrow(D::Minus1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)? + .contiguous()? + } + let v_seq_len = v.dims()[1]; + if v_seq_len > 2 * MAX_SEQ_LEN { + v = v + .narrow(D::Minus1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)? + .contiguous()? + } + } + cache[block_idx] = Some((k.clone(), v.clone())) + } + + let k = self.repeat_kv(k)?; + let v = self.repeat_kv(v)?; + + let y = if self.use_flash_attn { + // flash-attn expects (b_sz, seq_len, nheads, head_dim) + let q = q.transpose(1, 2)?; + let k = k.transpose(1, 2)?; + let v = v.transpose(1, 2)?; + let softmax_scale = 1f32 / (self.head_dim as f32).sqrt(); + flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?.transpose(1, 2)? + } else { + let in_dtype = q.dtype(); + let q = q.to_dtype(DType::F32)?; + let k = k.to_dtype(DType::F32)?; + let v = v.to_dtype(DType::F32)?; + let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; + let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?; + let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; + let att = candle_nn::ops::softmax(&att, D::Minus1)?; + // Convert to contiguous as matmul doesn't support strided vs for now. + att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)? + }; + let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?; + let y = self.o_proj.forward(&y)?; + Ok(y) + } + + fn repeat_kv(&self, x: Tensor) -> Result<Tensor> { + let n_rep = self.num_attention_heads / self.num_key_value_heads; + if n_rep == 1 { + Ok(x) + } else { + let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?; + let x = x + .unsqueeze(2)? + .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))? 
+ .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?; + Ok(x) + } + } + + fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "attn"); + let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); + let size_in = cfg.hidden_size; + let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads; + let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads; + let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?; + let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?; + let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?; + let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?; + Ok(Self { + q_proj, + k_proj, + v_proj, + o_proj, + num_attention_heads: cfg.num_attention_heads, + num_key_value_heads: cfg.num_key_value_heads, + head_dim: cfg.hidden_size / cfg.num_attention_heads, + cache: cache.clone(), + use_flash_attn: cfg.use_flash_attn, + span, + span_rot, + }) + } +} + +fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> { + let shape = mask.shape(); + let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; + let m = mask.where_cond(&on_true, on_false)?; + Ok(m) +} + +struct Mlp { + c_fc1: Linear, + c_fc2: Linear, + c_proj: Linear, + span: tracing::Span, +} + +impl Mlp { + fn forward(&self, x: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let x = (candle_nn::ops::silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?; + self.c_proj.forward(&x) + } + + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "mlp"); + let h_size = cfg.hidden_size; + let i_size = cfg.intermediate_size; + let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?; + let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?; + let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?; + Ok(Self { + c_fc1, + c_fc2, + c_proj, + span, + }) + } +} + +struct Block { + rms_1: RmsNorm, + attn: CausalSelfAttention, + rms_2: RmsNorm, + mlp: Mlp, + span: tracing::Span, +} + +impl Block { + fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> { + let _enter = self.span.enter(); + let residual = x; + let x = self.rms_1.forward(x)?; + let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?; + let residual = &x; + let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? 
+ residual)?; + Ok(x) + } + + fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "block"); + let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?; + let mlp = Mlp::load(vb.pp("mlp"), cfg)?; + let rms_1 = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; + let rms_2 = RmsNorm::load( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("post_attention_layernorm"), + )?; + Ok(Self { + rms_1, + attn, + rms_2, + mlp, + span, + }) + } +} + +pub struct Llama { + wte: Embedding, + blocks: Vec<Block>, + ln_f: RmsNorm, + lm_head: Linear, +} + +impl Llama { + pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> { + let (_b_sz, seq_len) = x.dims2()?; + let mut x = self.wte.forward(x)?; + for (block_idx, block) in self.blocks.iter().enumerate() { + x = block.forward(&x, index_pos, block_idx)?; + } + let x = self.ln_f.forward(&x)?; + let x = x.i((.., seq_len - 1, ..))?; + let logits = self.lm_head.forward(&x)?; + logits.to_dtype(DType::F32) + } + + pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> { + let wte = embedding(cfg, vb.pp("model.embed_tokens"))?; + let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; + let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?; + let blocks: Vec<_> = (0..cfg.num_hidden_layers) + .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap()) + .collect(); + + Ok(Self { + wte, + blocks, + ln_f, + lm_head, + }) + } +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 8b137891..d783a2c6 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -1 +1,13 @@ - +pub mod bert; +pub mod bigcode; +pub mod dinov2; +pub mod efficientnet; +pub mod falcon; +pub mod llama; +pub mod quantized_llama; +pub mod quantized_t5; +pub mod segment_anything; +pub mod stable_diffusion; +pub mod t5; +pub mod whisper; +pub mod wuerstchen; diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs new file mode 100644 index 00000000..2988b0fb --- /dev/null +++ b/candle-transformers/src/models/quantized_llama.rs @@ -0,0 +1,371 @@ +use std::collections::HashMap; + +use candle::quantized::QTensor; +use candle::quantized::{ggml_file, gguf_file}; +use candle::{DType, Device, IndexOp, Result, Tensor, D}; +use candle_nn::{Embedding, Module}; + +pub const MAX_SEQ_LEN: usize = 4096; + +struct RmsNorm { + inner: candle_nn::LayerNorm, + span: tracing::Span, +} + +impl RmsNorm { + fn new(scale: QTensor, eps: f32) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); + let scale = scale.dequantize(&Device::Cpu)?; + let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64); + Ok(Self { inner, span }) + } + + fn forward(&self, x: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(x) + } +} + +// QMatMul wrapper adding some tracing. 
+struct QMatMul { + inner: candle::quantized::QMatMul, + span: tracing::Span, +} + +impl QMatMul { + fn from_qtensor(qtensor: QTensor) -> Self { + let inner = candle::quantized::QMatMul::from_qtensor(qtensor); + let span = tracing::span!(tracing::Level::TRACE, "qmatmul"); + Self { inner, span } + } + + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(xs) + } +} + +struct LayerWeights { + attention_wq: QMatMul, + attention_wk: QMatMul, + attention_wv: QMatMul, + attention_wo: QMatMul, + attention_norm: RmsNorm, + feed_forward_w1: QMatMul, + feed_forward_w2: QMatMul, + feed_forward_w3: QMatMul, + ffn_norm: RmsNorm, + n_head: usize, + n_kv_head: usize, + head_dim: usize, + cos: Tensor, + sin: Tensor, + kv_cache: Option<(Tensor, Tensor)>, + span_attn: tracing::Span, + span_rot: tracing::Span, + span_mlp: tracing::Span, +} + +fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> { + let shape = mask.shape(); + let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; + let m = mask.where_cond(&on_true, on_false)?; + Ok(m) +} + +impl LayerWeights { + fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> { + let _enter = self.span_rot.enter(); + let (b_sz, n_head, seq_len, n_embd) = x.dims4()?; + let cos = self + .cos + .narrow(0, index_pos, seq_len)? + .reshape((seq_len, n_embd / 2, 1))?; + let sin = self + .sin + .narrow(0, index_pos, seq_len)? + .reshape((seq_len, n_embd / 2, 1))?; + let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?; + let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?; + // This mimics the llama.cpp behavior. + // https://github.com/ggerganov/llama.cpp/blob/1f0bccb27929e261744c979bc75114955da49e98/ggml.c#L12104-L12105 + // The x0 and x1 value are interleaved on the n_embd (= head_dim) dimension. + // The resulting y0 and y1 are also interleaved with: + // y0 = x0*cos - x1*sin + // y1 = x0*sin + x1*cos + let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?; + let x0 = x.narrow(D::Minus1, 0, 1)?; + let x1 = x.narrow(D::Minus1, 1, 1)?; + let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?; + let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?; + let rope = Tensor::cat(&[y0, y1], D::Minus1)?; + let rope = rope.flatten_from(D::Minus2)?; + Ok(rope) + } + + fn forward_attn(&mut self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> { + let _enter = self.span_attn.enter(); + let (b_sz, seq_len, n_embd) = x.dims3()?; + let q = self.attention_wq.forward(x)?; + let k = self.attention_wk.forward(x)?; + let v = self.attention_wv.forward(x)?; + + let q = q + .reshape((b_sz, seq_len, self.n_head, self.head_dim))? + .transpose(1, 2)?; + let k = k + .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))? + .transpose(1, 2)?; + let v = v + .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))? + .transpose(1, 2)?; + + let q = self.apply_rotary_emb(&q, index_pos)?; + let k = self.apply_rotary_emb(&k, index_pos)?; + + let (k, v) = match &self.kv_cache { + None => (k, v), + Some((k_cache, v_cache)) => { + if index_pos == 0 { + (k, v) + } else { + let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?; + let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?; + (k, v) + } + } + }; + self.kv_cache = Some((k.clone(), v.clone())); + + // Support for MQA, useful for 70B models. + let k = self.repeat_kv(k)?; + let v = self.repeat_kv(v)?; + + let att = (q.matmul(&k.t()?)? 
/ (self.head_dim as f64).sqrt())?; + let mask = mask.broadcast_as(att.shape())?; + let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; + let att = candle_nn::ops::softmax_last_dim(&att)?; + // Convert to contiguous as matmul doesn't support strided vs for now. + let y = att.matmul(&v.contiguous()?)?; + let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?; + let y = self.attention_wo.forward(&y)?; + Ok(y) + } + + fn repeat_kv(&self, x: Tensor) -> Result<Tensor> { + let n_rep = self.n_head / self.n_kv_head; + if n_rep == 1 { + Ok(x) + } else { + let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?; + let x = x + .unsqueeze(2)? + .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))? + .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?; + Ok(x) + } + } +} + +pub struct ModelWeights { + tok_embeddings: Embedding, + layers: Vec<LayerWeights>, + norm: RmsNorm, + output: QMatMul, + masks: HashMap<usize, Tensor>, + span: tracing::Span, + span_output: tracing::Span, +} + +fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tensor)> { + let theta: Vec<_> = (0..head_dim) + .step_by(2) + .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32)) + .collect(); + let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?; + let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)? + .to_dtype(DType::F32)? + .reshape((MAX_SEQ_LEN, 1))? + .matmul(&theta.reshape((1, theta.elem_count()))?)?; + let cos = idx_theta.cos()?; + let sin = idx_theta.sin()?; + Ok((cos, sin)) +} + +impl ModelWeights { + pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> { + let cpu = &Device::Cpu; + let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize; + let (cos, sin) = precomput_freqs_cis(head_dim, 10000.)?; + let tok_embeddings = ct.remove("tok_embeddings.weight")?; + let tok_embeddings = tok_embeddings.dequantize(cpu)?; + let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?; + let output = ct.remove("output.weight")?; + let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize); + for layer_idx in 0..ct.hparams.n_layer { + let prefix = format!("layers.{layer_idx}"); + let attention_wq = ct.remove(&format!("{prefix}.attention.wq.weight"))?; + let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?; + let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?; + let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?; + let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?; + let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?; + let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?; + let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?; + let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?; + let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); + let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); + let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp"); + layers.push(LayerWeights { + attention_wq: QMatMul::from_qtensor(attention_wq), + attention_wk: QMatMul::from_qtensor(attention_wk), + attention_wv: QMatMul::from_qtensor(attention_wv), + attention_wo: QMatMul::from_qtensor(attention_wo), + attention_norm: RmsNorm::new(attention_norm, 1e-5)?, + feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1), + feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2), + feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3), + ffn_norm: 
RmsNorm::new(ffn_norm, 1e-5)?, + n_head: ct.hparams.n_head as usize, + n_kv_head: ct.hparams.n_head as usize / gqa, + head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize, + cos: cos.clone(), + sin: sin.clone(), + kv_cache: None, + span_attn, + span_rot, + span_mlp, + }) + } + let span = tracing::span!(tracing::Level::TRACE, "model"); + let span_output = tracing::span!(tracing::Level::TRACE, "output"); + Ok(Self { + tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize), + layers, + norm, + output: QMatMul::from_qtensor(output), + masks: HashMap::new(), + span, + span_output, + }) + } + + pub fn from_gguf<R: std::io::Seek + std::io::Read>( + ct: gguf_file::Content, + reader: &mut R, + ) -> Result<Self> { + let cpu = &Device::Cpu; + let md_get = |s: &str| match ct.metadata.get(s) { + None => candle::bail!("cannot find {s} in metadata"), + Some(v) => Ok(v), + }; + + // Parameter extraction from metadata. + let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize; + let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize; + let block_count = md_get("llama.block_count")?.to_u32()? as usize; + let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize; + let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize; + // Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default. + let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()?; + + let rope_freq_base = md_get("llama.rope.freq_base") + .and_then(|m| m.to_f32()) + .unwrap_or(10000f32); + let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?; + + let tok_embeddings = ct.tensor(reader, "token_embd.weight")?; + let tok_embeddings = tok_embeddings.dequantize(cpu)?; + let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?, rms_norm_eps)?; + let output = ct.tensor(reader, "output.weight")?; + let mut layers = Vec::with_capacity(block_count); + for layer_idx in 0..block_count { + let prefix = format!("blk.{layer_idx}"); + let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"))?; + let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?; + let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?; + let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?; + let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?; + let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?; + let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?; + let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?; + let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?; + let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); + let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); + let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp"); + layers.push(LayerWeights { + attention_wq: QMatMul::from_qtensor(attention_wq), + attention_wk: QMatMul::from_qtensor(attention_wk), + attention_wv: QMatMul::from_qtensor(attention_wv), + attention_wo: QMatMul::from_qtensor(attention_wo), + attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?, + feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1), + feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2), + feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3), + ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?, + n_head: head_count, + n_kv_head: 
head_count_kv, + head_dim: embedding_length / head_count, + cos: cos.clone(), + sin: sin.clone(), + kv_cache: None, + span_attn, + span_rot, + span_mlp, + }) + } + let span = tracing::span!(tracing::Level::TRACE, "model"); + let span_output = tracing::span!(tracing::Level::TRACE, "output"); + Ok(Self { + tok_embeddings: Embedding::new(tok_embeddings, embedding_length), + layers, + norm, + output: QMatMul::from_qtensor(output), + masks: HashMap::new(), + span, + span_output, + }) + } + + fn mask(&mut self, t: usize) -> Result<Tensor> { + if let Some(mask) = self.masks.get(&t) { + Ok(mask.clone()) + } else { + let mask: Vec<_> = (0..t) + .flat_map(|i| (0..t).map(move |j| u8::from(j > i))) + .collect(); + let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?; + self.masks.insert(t, mask.clone()); + Ok(mask) + } + } + + pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> { + let (_b_sz, seq_len) = x.dims2()?; + let mask = self.mask(seq_len)?; + let _enter = self.span.enter(); + let mut layer_in = self.tok_embeddings.forward(x)?; + for layer in self.layers.iter_mut() { + let x = layer_in; + let residual = &x; + let x = layer.attention_norm.forward(&x)?; + let attn = layer.forward_attn(&x, &mask, index_pos)?; + let x = (attn + residual)?; + + // MLP + let _enter = layer.span_mlp.enter(); + let residual = &x; + let x = layer.ffn_norm.forward(&x)?; + let w1 = layer.feed_forward_w1.forward(&x)?; + let w3 = layer.feed_forward_w3.forward(&x)?; + let mlp = layer + .feed_forward_w2 + .forward(&(candle_nn::ops::silu(&w1)? * w3)?)?; + layer_in = (mlp + residual)?; + } + let x = self.norm.forward(&layer_in)?; + let x = x.i((.., seq_len - 1, ..))?; + let _enter = self.span_output.enter(); + self.output.forward(&x) + } +} diff --git a/candle-transformers/src/models/quantized_t5.rs b/candle-transformers/src/models/quantized_t5.rs new file mode 100644 index 00000000..a10c3b80 --- /dev/null +++ b/candle-transformers/src/models/quantized_t5.rs @@ -0,0 +1,884 @@ +// T5 Text Model, quantized version +// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py + +use candle::quantized::QTensor; +use candle::{DType, Device, Module, Result, Shape, Tensor, D}; +use candle_nn::Activation; +use serde::Deserialize; +use std::sync::Arc; + +// VarBuilder specialized for QTensors +pub struct VarBuilder { + data: Arc<std::collections::HashMap<String, Arc<QTensor>>>, + path: Vec<String>, + device: Device, +} + +impl VarBuilder { + pub fn from_gguf<P: AsRef<std::path::Path>>(p: P) -> Result<Self> { + let mut file = std::fs::File::open(p)?; + let content = candle::quantized::gguf_file::Content::read(&mut file)?; + let mut data = std::collections::HashMap::new(); + for tensor_name in content.tensor_infos.keys() { + let tensor = content.tensor(&mut file, tensor_name)?; + data.insert(tensor_name.to_string(), Arc::new(tensor)); + } + Ok(Self { + data: Arc::new(data), + path: Vec::new(), + device: Device::Cpu, + }) + } + + fn pp<S: ToString>(&self, s: S) -> Self { + let mut path = self.path.clone(); + path.push(s.to_string()); + Self { + data: self.data.clone(), + path, + device: self.device.clone(), + } + } + + fn path(&self, tensor_name: &str) -> String { + if self.path.is_empty() { + tensor_name.to_string() + } else { + [&self.path.join("."), tensor_name].join(".") + } + } + + fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Arc<QTensor>> { + let path = self.path(name); + match self.data.get(&path) { + None => { + candle::bail!("cannot find tensor 
{name}") + } + Some(qtensor) => { + let shape = s.into(); + if qtensor.shape() != &shape { + candle::bail!( + "shape mismatch for {name}, got {:?}, expected {shape:?}", + qtensor.shape() + ) + } + Ok(qtensor.clone()) + } + } + } +} + +#[derive(Debug)] +struct Embedding { + inner: candle_nn::Embedding, + span: tracing::Span, +} + +impl Embedding { + fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> { + let embeddings = vb.get((d1, d2), "weight")?.dequantize(&vb.device)?; + let inner = candle_nn::Embedding::new(embeddings, d2); + let span = tracing::span!(tracing::Level::TRACE, "embedding"); + Ok(Self { inner, span }) + } + + fn embeddings(&self) -> &Tensor { + self.inner.embeddings() + } +} + +impl Module for Embedding { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(xs) + } +} + +// QMatMul wrapper adding some tracing. +struct QMatMul { + inner: candle::quantized::QMatMul, + span: tracing::Span, +} + +impl QMatMul { + fn new(out_dim: usize, in_dim: usize, vb: VarBuilder) -> Result<Self> { + let ws = vb.get((in_dim, out_dim), "weight")?; + let inner = candle::quantized::QMatMul::from_arc(ws); + let span = tracing::span!(tracing::Level::TRACE, "qmatmul"); + Ok(Self { inner, span }) + } +} + +impl Module for QMatMul { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(xs) + } +} + +impl std::fmt::Debug for QMatMul { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "QMatMul") + } +} + +fn default_relative_attention_max_distance() -> usize { + 128 +} + +fn default_is_decoder() -> bool { + false +} + +fn default_use_cache() -> bool { + true +} + +fn default_tie_word_embeddings() -> bool { + true +} + +fn get_mask(size: usize, device: &Device) -> Result<Tensor> { + let mask: Vec<_> = (0..size) + .flat_map(|i| (0..size).map(move |j| u8::from(j > i))) + .collect(); + Tensor::from_slice(&mask, (size, size), device) +} + +fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> { + let shape = mask.shape(); + let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; + let m = mask.where_cond(&on_true, on_false)?; + Ok(m) +} + +#[derive(Debug, Clone, PartialEq, Deserialize)] +pub struct Config { + vocab_size: usize, + d_model: usize, + d_kv: usize, + d_ff: usize, + num_layers: usize, + num_decoder_layers: Option<usize>, + num_heads: usize, + relative_attention_num_buckets: usize, + #[serde(default = "default_relative_attention_max_distance")] + relative_attention_max_distance: usize, + dropout_rate: f64, + layer_norm_epsilon: f64, + initializer_factor: f64, + #[serde(default)] + feed_forward_proj: Activation, + #[serde(default = "default_tie_word_embeddings")] + tie_word_embeddings: bool, + #[serde(default = "default_is_decoder")] + is_decoder: bool, + is_encoder_decoder: bool, + #[serde(default = "default_use_cache")] + pub use_cache: bool, + pub pad_token_id: usize, + pub eos_token_id: usize, +} + +impl Default for Config { + fn default() -> Self { + Self { + vocab_size: 32128, + d_model: 512, + d_kv: 64, + d_ff: 2048, + num_layers: 6, + num_decoder_layers: None, + num_heads: 8, + relative_attention_num_buckets: 32, + relative_attention_max_distance: 128, + dropout_rate: 0.1, + layer_norm_epsilon: 1e-6, + initializer_factor: 1.0, + feed_forward_proj: Activation::Relu, + tie_word_embeddings: true, + is_decoder: false, + is_encoder_decoder: true, + use_cache: true, + pad_token_id: 0, + 
eos_token_id: 1, + } + } +} + +#[derive(Debug)] +struct T5LayerNorm { + weight: Tensor, + variance_epsilon: f64, + span: tracing::Span, +} + +impl T5LayerNorm { + fn load(h: usize, eps: f64, vb: VarBuilder) -> Result<Self> { + let weight = vb.get(h, "weight")?.dequantize(&vb.device)?; + Ok(Self { + weight, + variance_epsilon: eps, + span: tracing::span!(tracing::Level::TRACE, "layer-norm"), + }) + } +} + +impl Module for T5LayerNorm { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let dtype = xs.dtype(); + let xs_f32 = xs.to_dtype(DType::F32)?; + // variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?; + let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?; + let xs = xs.to_dtype(dtype)?; + let xs = xs.broadcast_mul(&self.weight)?; + Ok(xs) + } +} + +#[derive(Debug)] +struct T5DenseActDense { + wi: QMatMul, + wo: QMatMul, + act: Activation, + span: tracing::Span, +} + +impl T5DenseActDense { + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let wi = QMatMul::new(cfg.d_model, cfg.d_ff, vb.pp("wi"))?; + let wo = QMatMul::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?; + Ok(Self { + wi, + wo, + act: Activation::Relu, + span: tracing::span!(tracing::Level::TRACE, "dense-act-dense"), + }) + } +} + +impl Module for T5DenseActDense { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let xs = self.wi.forward(xs)?; + let xs = self.act.forward(&xs)?; + let xs = self.wo.forward(&xs)?; + Ok(xs) + } +} + +#[derive(Debug)] +struct T5DenseGatedActDense { + wi_0: QMatMul, + wi_1: QMatMul, + wo: QMatMul, + act: Activation, + span: tracing::Span, +} + +impl T5DenseGatedActDense { + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let wi_0 = QMatMul::new(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?; + let wi_1 = QMatMul::new(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?; + let wo = QMatMul::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?; + Ok(Self { + wi_0, + wi_1, + wo, + act: Activation::NewGelu, + span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"), + }) + } +} + +impl Module for T5DenseGatedActDense { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let hidden_gelu = self.act.forward(&self.wi_0.forward(xs)?)?; + let hidden_linear = self.wi_1.forward(xs)?; + let xs = hidden_gelu.broadcast_mul(&hidden_linear)?; + let xs = self.wo.forward(&xs)?; + Ok(xs) + } +} + +#[derive(Debug)] +struct T5LayerFF { + dense_act: Option<T5DenseActDense>, + gated_dense_act: Option<T5DenseGatedActDense>, + layer_norm: T5LayerNorm, + span: tracing::Span, +} + +impl T5LayerFF { + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let layer_norm = + T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?; + let (dense_act, gated_dense_act) = if cfg.feed_forward_proj == Activation::NewGelu { + ( + None, + Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?), + ) + } else { + ( + Some(T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?), + None, + ) + }; + Ok(Self { + dense_act, + gated_dense_act, + layer_norm, + span: tracing::span!(tracing::Level::TRACE, "layer-ff"), + }) + } +} + +impl Module for T5LayerFF { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let ys = self.layer_norm.forward(xs)?; + let ys = match &self.dense_act { + Some(dense_act) => dense_act.forward(&ys)?, + None => 
self.gated_dense_act.as_ref().unwrap().forward(&ys)?, + }; + let xs = (xs + ys)?; + Ok(xs) + } +} + +#[derive(Debug)] +struct T5Attention { + q: QMatMul, + k: QMatMul, + v: QMatMul, + o: QMatMul, + n_heads: usize, + d_kv: usize, + relative_attention_bias: Option<Embedding>, + relative_attention_num_buckets: usize, + relative_attention_max_distance: usize, + inner_dim: usize, + use_cache: bool, + kv_cache: Option<(Tensor, Tensor)>, + span: tracing::Span, + span_cache: tracing::Span, + span_mm: tracing::Span, + span_sm: tracing::Span, +} + +impl T5Attention { + fn load( + has_relative_attention_bias: bool, + decoder: bool, + vb: VarBuilder, + cfg: &Config, + ) -> Result<Self> { + let inner_dim = cfg.num_heads * cfg.d_kv; + let q = QMatMul::new(cfg.d_model, inner_dim, vb.pp("q"))?; + let k = QMatMul::new(cfg.d_model, inner_dim, vb.pp("k"))?; + let v = QMatMul::new(cfg.d_model, inner_dim, vb.pp("v"))?; + let o = QMatMul::new(inner_dim, cfg.d_model, vb.pp("o"))?; + let relative_attention_bias = if has_relative_attention_bias { + let emb = Embedding::new( + cfg.relative_attention_num_buckets, + cfg.num_heads, + vb.pp("relative_attention_bias"), + )?; + Some(emb) + } else { + None + }; + Ok(Self { + q, + k, + v, + o, + n_heads: cfg.num_heads, + d_kv: cfg.d_kv, + relative_attention_bias, + relative_attention_num_buckets: cfg.relative_attention_num_buckets, + relative_attention_max_distance: cfg.relative_attention_max_distance, + inner_dim, + use_cache: cfg.use_cache && decoder, + kv_cache: None, + span: tracing::span!(tracing::Level::TRACE, "attention"), + span_cache: tracing::span!(tracing::Level::TRACE, "attention-cache"), + span_mm: tracing::span!(tracing::Level::TRACE, "attention-mm"), + span_sm: tracing::span!(tracing::Level::TRACE, "attention-sm"), + }) + } + + fn forward( + &mut self, + xs: &Tensor, + position_bias: Option<&Tensor>, + key_value_states: Option<&Tensor>, + mask: Option<&Tensor>, + ) -> Result<(Tensor, Option<Tensor>)> { + // Performs Self-attention (if key_value_states is None) or attention + // over source sentence (provided by key_value_states). + let _enter = self.span.enter(); + let kv_input = match key_value_states { + None => xs, + Some(key_value_states) => key_value_states, + }; + let (b_sz, q_len) = (xs.dim(0)?, xs.dim(1)?); + let kv_len = kv_input.dim(1)?; + let q = self.q.forward(xs)?; + let k = self.k.forward(kv_input)?; + let v = self.v.forward(kv_input)?; + let q = q + .reshape((b_sz, q_len, self.n_heads, self.d_kv))? + .transpose(1, 2)? + .contiguous()?; + let mut k = k + .reshape((b_sz, kv_len, self.n_heads, self.d_kv))? + .transpose(1, 2)? + .contiguous()?; + let mut v = v + .reshape((b_sz, kv_len, self.n_heads, self.d_kv))? + .transpose(1, 2)? + .contiguous()?; + + if self.use_cache { + let _enter = self.span_cache.enter(); + if let Some((kv_cache_k, kv_cache_v)) = &self.kv_cache { + k = Tensor::cat(&[kv_cache_k, &k], 2)?.contiguous()?; + v = Tensor::cat(&[kv_cache_v, &v], 2)?.contiguous()?; + }; + self.kv_cache = Some((k.clone(), v.clone())); + }; + // TODO: Use flash_attn. + let scores = { + let _enter = self.span_mm.enter(); + q.matmul(&k.t()?)? + }; + let scores = match mask { + None => scores, + Some(mask) => masked_fill( + &scores, + &mask + .unsqueeze(0)? + .unsqueeze(0)? 
+ .repeat((b_sz, self.n_heads))?, + f32::NEG_INFINITY, + )?, + }; + + let (scores, position_bias) = match position_bias { + Some(position_bias) => ( + scores.broadcast_add(position_bias)?, + Some(position_bias.clone()), + ), + None => match &self.relative_attention_bias { + None => (scores, None), + Some(relative_attention_bias) => { + // This only handles the bidirectional case. + let kv_len = k.dim(2)?; + let (q_start, q_end) = match self.use_cache { + true => ((kv_len - q_len) as u32, kv_len as u32), + false => (0_u32, kv_len as u32), + }; + let num_buckets = self.relative_attention_num_buckets as u32 / 2; + let max_exact = num_buckets / 2; + let relative_position = (q_start..q_end) + .map(|i| { + (0..kv_len as u32) + .map(|j| { + if i < j { + if j - i < max_exact { + j - i + num_buckets + } else { + let b = f32::log( + (j - i) as f32 / max_exact as f32, + self.relative_attention_max_distance as f32 + / max_exact as f32, + ) * (num_buckets - max_exact) as f32; + u32::min( + max_exact + num_buckets + b as u32, + self.relative_attention_num_buckets as u32 - 1, + ) + } + } else if i - j < max_exact { + i - j + } else { + let b = f32::log( + (i - j) as f32 / max_exact as f32, + self.relative_attention_max_distance as f32 + / max_exact as f32, + ) * (num_buckets - max_exact) as f32; + max_exact + b as u32 + } + }) + .collect::<Vec<u32>>() + }) + .collect::<Vec<Vec<_>>>(); + let relative_buckets = Tensor::new(relative_position, q.device())?; + let position_bias = relative_attention_bias + .forward(&relative_buckets)? + .permute((2, 0, 1))? + .unsqueeze(0)?; + (scores.broadcast_add(&position_bias)?, Some(position_bias)) + // TODO: position_bias_masked? + } + }, + }; + + let attn_weights = { + let _enter = self.span_sm.enter(); + candle_nn::ops::softmax(&scores, D::Minus1)? + }; + let attn_output = attn_weights.matmul(&v)?; + let attn_output = attn_output + .transpose(1, 2)? 
+ .reshape((b_sz, q_len, self.inner_dim))?; + let attn_output = self.o.forward(&attn_output)?; + Ok((attn_output, position_bias)) + } + + fn clear_kv_cache(&mut self) { + self.kv_cache = None + } +} + +#[derive(Debug)] +struct T5LayerSelfAttention { + self_attention: T5Attention, + layer_norm: T5LayerNorm, + span: tracing::Span, +} + +impl T5LayerSelfAttention { + fn load(h: bool, d: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> { + let self_attention = T5Attention::load(h, d, vb.pp("SelfAttention"), cfg)?; + let layer_norm = + T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?; + Ok(Self { + self_attention, + layer_norm, + span: tracing::span!(tracing::Level::TRACE, "self-attn"), + }) + } + + fn forward( + &mut self, + xs: &Tensor, + position_bias: Option<&Tensor>, + mask: Option<&Tensor>, + ) -> Result<(Tensor, Option<Tensor>)> { + let _enter = self.span.enter(); + let normed_xs = self.layer_norm.forward(xs)?; + let (ys, position_bias) = + self.self_attention + .forward(&normed_xs, position_bias, None, mask)?; + let ys = (xs + ys)?; + Ok((ys, position_bias)) + } + + fn clear_kv_cache(&mut self) { + self.self_attention.clear_kv_cache() + } +} + +#[derive(Debug)] +struct T5LayerCrossAttention { + cross_attention: T5Attention, + layer_norm: T5LayerNorm, + span: tracing::Span, +} + +impl T5LayerCrossAttention { + fn load(decoder: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> { + let cross_attention = T5Attention::load(false, decoder, vb.pp("EncDecAttention"), cfg)?; + let layer_norm = + T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?; + Ok(Self { + cross_attention, + layer_norm, + span: tracing::span!(tracing::Level::TRACE, "cross-attn"), + }) + } + + fn forward( + &mut self, + hidden_states: &Tensor, + position_bias: Option<&Tensor>, + key_value_states: &Tensor, + ) -> Result<(Tensor, Option<Tensor>)> { + let _enter = self.span.enter(); + let normed_hidden_states = self.layer_norm.forward(hidden_states)?; + let (ys, position_bias) = self.cross_attention.forward( + &normed_hidden_states, + position_bias, + Some(key_value_states), + None, + )?; + let ys = (hidden_states + ys)?; + Ok((ys, position_bias)) + } + + fn clear_kv_cache(&mut self) { + self.cross_attention.clear_kv_cache() + } +} + +#[derive(Debug)] +struct T5Block { + self_attn: T5LayerSelfAttention, + cross_attn: Option<T5LayerCrossAttention>, + ff: T5LayerFF, + span: tracing::Span, +} + +impl T5Block { + fn load( + has_relative_attention_bias: bool, + decoder: bool, + vb: VarBuilder, + cfg: &Config, + ) -> Result<Self> { + let vb = vb.pp("layer"); + let self_attn = + T5LayerSelfAttention::load(has_relative_attention_bias, decoder, vb.pp("0"), cfg)?; + let cross_attn = if cfg.is_decoder { + Some(T5LayerCrossAttention::load(decoder, vb.pp("1"), cfg)?) + } else { + None + }; + let ff_i = if cross_attn.is_some() { 2 } else { 1 }; + let ff = T5LayerFF::load(vb.pp(ff_i), cfg)?; + Ok(Self { + self_attn, + cross_attn, + ff, + span: tracing::span!(tracing::Level::TRACE, "block"), + }) + } + + fn forward( + &mut self, + xs: &Tensor, + position_bias: Option<&Tensor>, + encoder_hidden_states: Option<&Tensor>, + ) -> Result<(Tensor, Option<Tensor>)> { + let _enter = self.span.enter(); + // TODO: Cache masks + let mask = match self.cross_attn.is_some() { + true => { + let mask_len = xs.dim(1)?; + // If the input seq length is 1, no need for a mask, this is also helpful to avoid shape + // issues when using the KV cache in the decoder. 
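// [Editorial note, not part of the patch] For reference, `get_mask(3, dev)` used just
// below yields
//
//     0 1 1
//     0 0 1
//     0 0 0
//
// i.e. a 1 wherever j > i; `masked_fill` later turns those positions into -inf so that
// query position i only attends to keys 0..=i. For a single decoding step the mask row
// would be all zeros anyway, which is why it is skipped when `mask_len <= 1`.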
+ if mask_len <= 1 { + None + } else { + Some(get_mask(mask_len, xs.device())?) + } + } + false => None, + }; + let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias, mask.as_ref())?; + // TODO: clamp for f16? + if let Some(cross_attn) = &mut self.cross_attn { + (xs, _) = cross_attn.forward(&xs, None, encoder_hidden_states.unwrap())?; + // TODO: clamp for f16? + } + let xs = self.ff.forward(&xs)?; + // TODO: clamp for f16? + Ok((xs, position_bias)) + } + + fn clear_kv_cache(&mut self) { + self.self_attn.clear_kv_cache(); + self.cross_attn.iter_mut().for_each(|c| c.clear_kv_cache()); + } +} + +#[derive(Debug)] +struct T5Stack { + block: Vec<T5Block>, + shared: Arc<Embedding>, + final_layer_norm: T5LayerNorm, + span: tracing::Span, +} + +impl T5Stack { + fn load(decoder: bool, vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> { + let block = (0..cfg.num_layers) + .map(|i| T5Block::load(i == 0, decoder, vb.pp(format!("block.{i}")), cfg)) + .collect::<Result<Vec<_>>>()?; + let final_layer_norm = T5LayerNorm::load( + cfg.d_model, + cfg.layer_norm_epsilon, + vb.pp("final_layer_norm"), + )?; + Ok(Self { + block, + shared: shared.clone(), + final_layer_norm, + span: tracing::span!(tracing::Level::TRACE, "stack"), + }) + } + + fn forward( + &mut self, + input_ids: &Tensor, + encoder_hidden_states: Option<&Tensor>, + ) -> Result<Tensor> { + let _enter = self.span.enter(); + let input_embeds = self.shared.as_ref().forward(input_ids)?; + let mut hidden_states = input_embeds; + let mut position_bias = None; + for block in self.block.iter_mut() { + (hidden_states, position_bias) = block.forward( + &hidden_states, + position_bias.as_ref(), + encoder_hidden_states, + )? + } + self.final_layer_norm.forward(&hidden_states) + } + + fn clear_kv_cache(&mut self) { + self.block.iter_mut().for_each(|b| b.clear_kv_cache()) + } +} + +#[derive(Debug)] +pub struct T5EncoderModel { + encoder: T5Stack, + device: Device, + span: tracing::Span, +} + +impl T5EncoderModel { + pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?; + let shared = Arc::new(shared); + let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?; + Ok(Self { + encoder, + device: vb.device.clone(), + span: tracing::span!(tracing::Level::TRACE, "encoder"), + }) + } + + pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.encoder.forward(input_ids, None) + } + + pub fn device(&self) -> &Device { + &self.device + } + + pub fn clear_kv_cache(&mut self) { + self.encoder.clear_kv_cache() + } +} + +#[derive(Debug)] +pub struct T5ForConditionalGeneration { + encoder: T5Stack, + decoder: T5Stack, + d_model: usize, + tie_word_embeddings: bool, + lm_head: Option<QMatMul>, + shared: Arc<Embedding>, + device: Device, + span_decode: tracing::Span, + span_decode_head: tracing::Span, +} + +impl T5ForConditionalGeneration { + pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + assert!(cfg.is_encoder_decoder); + let d_model = cfg.d_model; + let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?; + let shared = Arc::new(shared); + + let mut encoder_cfg = cfg.clone(); + encoder_cfg.is_decoder = false; + encoder_cfg.use_cache = false; + encoder_cfg.is_encoder_decoder = false; + let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, &encoder_cfg)?; + + let mut decoder_cfg = cfg.clone(); + decoder_cfg.is_decoder = true; + decoder_cfg.is_encoder_decoder 
= false; + decoder_cfg.num_layers = cfg.num_decoder_layers.unwrap_or(cfg.num_layers); + let decoder = T5Stack::load(true, vb.pp("decoder"), &shared, &decoder_cfg)?; + + let tie_word_embeddings = cfg.tie_word_embeddings; + let lm_head = if tie_word_embeddings { + None + } else { + Some(QMatMul::new(cfg.d_model, cfg.vocab_size, vb.pp("lm_head"))?) + }; + + Ok(Self { + encoder, + decoder, + d_model, + tie_word_embeddings, + lm_head, + shared, + device: vb.device.clone(), + span_decode: tracing::span!(tracing::Level::TRACE, "decode"), + span_decode_head: tracing::span!(tracing::Level::TRACE, "decode-head"), + }) + } + + pub fn encode(&mut self, input_ids: &Tensor) -> Result<Tensor> { + self.encoder.forward(input_ids, None) + } + + pub fn decode( + &mut self, + decoder_input_ids: &Tensor, + encoder_output: &Tensor, + ) -> Result<Tensor> { + let _enter = self.span_decode.enter(); + let decoder_output = self + .decoder + .forward(decoder_input_ids, Some(encoder_output))?; + + let scaling_factor = if self.tie_word_embeddings { + // Rescale output before projecting on vocab + // See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + (self.d_model as f64).sqrt() + } else { + 1.0 + }; + let sequence_output = ((decoder_output + .narrow(1, decoder_output.dim(1)? - 1, 1)? + .squeeze(1)?) + * scaling_factor)?; + let output = { + let _enter = self.span_decode_head.enter(); + match self.lm_head { + None => sequence_output.matmul(&self.shared.embeddings().t()?)?, + Some(ref lm_head) => lm_head.forward(&sequence_output)?, + } + }; + + // TODO: Rescale output before projecting on vocab? * (self.model_dim**-0.5) + Ok(output) + } + + pub fn forward(&mut self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> { + let encoder_output = self.encode(input_ids)?; + self.decode(decoder_input_ids, &encoder_output) + } + + pub fn device(&self) -> &Device { + &self.device + } + + pub fn clear_kv_cache(&mut self) { + self.encoder.clear_kv_cache(); + self.decoder.clear_kv_cache(); + } +} diff --git a/candle-transformers/src/models/segment_anything/image_encoder.rs b/candle-transformers/src/models/segment_anything/image_encoder.rs new file mode 100644 index 00000000..0b313830 --- /dev/null +++ b/candle-transformers/src/models/segment_anything/image_encoder.rs @@ -0,0 +1,483 @@ +use candle::{DType, IndexOp, Result, Tensor}; +use candle_nn::{layer_norm, LayerNorm, Module, VarBuilder}; + +#[derive(Debug)] +struct PatchEmbed { + proj: candle_nn::Conv2d, + span: tracing::Span, +} + +impl PatchEmbed { + fn new( + in_chans: usize, + embed_dim: usize, + k_size: usize, + stride: usize, + padding: usize, + vb: VarBuilder, + ) -> Result<Self> { + let cfg = candle_nn::Conv2dConfig { + stride, + padding, + ..Default::default() + }; + let proj = candle_nn::conv2d(in_chans, embed_dim, k_size, cfg, vb.pp("proj"))?; + let span = tracing::span!(tracing::Level::TRACE, "patch-embed"); + Ok(Self { proj, span }) + } +} + +impl Module for PatchEmbed { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + xs.apply(&self.proj)?.permute((0, 2, 3, 1)) + } +} + +// A custom op to make add_decomposed_rel_pos faster. Most of the time is spent on the final +// addition in the case where b = 12, q_h = q_w = 4096, k_h = k_w = 4096 +// (attn.reshape((b, q_h, q_w, k_h, k_w))? +// + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)? 
+// .reshape((b, q_h * q_w, k_h * k_w)) +// Ideally we would perform this operation in place but this is not supported in candle at the +// moment. We should also investigate using f16 rather than f32. +struct Add3(usize, usize, usize, usize, usize); +impl candle::CustomOp3 for Add3 { + fn name(&self) -> &'static str { + "add3" + } + + fn cpu_fwd( + &self, + s1: &candle::CpuStorage, + l1: &candle::Layout, + s2: &candle::CpuStorage, + l2: &candle::Layout, + s3: &candle::CpuStorage, + l3: &candle::Layout, + ) -> Result<(candle::CpuStorage, candle::Shape)> { + use rayon::prelude::*; + + let Add3(b, q_h, q_w, k_h, k_w) = *self; + let s1 = s1.as_slice::<f32>()?; + let s1 = match l1.contiguous_offsets() { + None => candle::bail!("input1 has to be contiguous"), + Some((o1, o2)) => &s1[o1..o2], + }; + let s2 = s2.as_slice::<f32>()?; + let s2 = match l2.contiguous_offsets() { + None => candle::bail!("input2 has to be contiguous"), + Some((o1, o2)) => &s2[o1..o2], + }; + let s3 = s3.as_slice::<f32>()?; + let s3 = match l3.contiguous_offsets() { + None => candle::bail!("input3 has to be contiguous"), + Some((o1, o2)) => &s3[o1..o2], + }; + let mut dst = vec![0f32; b * q_h * q_w * k_h * k_w]; + dst.par_chunks_exact_mut(k_h * k_w) + .enumerate() + .for_each(|(b_idx, dst)| { + let s1_idx = b_idx * k_h * k_w; + let s2_idx = b_idx * k_h; + let s3_idx = b_idx * k_w; + for h_idx in 0..k_h { + let s1_idx = s1_idx + h_idx * k_w; + let s2_idx = s2_idx + h_idx; + let dst_idx = h_idx * k_w; + for w_idx in 0..k_w { + let s1_idx = s1_idx + w_idx; + let s3_idx = s3_idx + w_idx; + let dst_idx = dst_idx + w_idx; + dst[dst_idx] = s1[s1_idx] + s2[s2_idx] + s3[s3_idx] + } + } + }); + let dst = candle::WithDType::to_cpu_storage_owned(dst); + Ok((dst, (b, q_h * q_w, k_h * k_w).into())) + } +} + +#[derive(Debug)] +struct Attention { + qkv: super::Linear, + proj: super::Linear, + num_heads: usize, + scale: f64, + rel_pos_hw: Option<(Tensor, Tensor)>, + span: tracing::Span, + span_matmul: tracing::Span, + span_rel_pos: tracing::Span, + span_softmax: tracing::Span, +} + +impl Attention { + fn new( + dim: usize, + num_heads: usize, + qkv_bias: bool, + use_rel_pos: bool, + input_size: (usize, usize), + vb: VarBuilder, + ) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "attention"); + let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul"); + let span_rel_pos = tracing::span!(tracing::Level::TRACE, "attn-rel-pos"); + let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm"); + let qkv = super::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?; + let proj = super::linear(vb.pp("proj"), dim, dim, true)?; + let head_dim = dim / num_heads; + let scale = 1. 
/ (head_dim as f64).sqrt(); + let rel_pos_hw = if use_rel_pos { + let h = vb.get((2 * input_size.0 - 1, head_dim), "rel_pos_h")?; + let w = vb.get((2 * input_size.1 - 1, head_dim), "rel_pos_w")?; + Some((h, w)) + } else { + None + }; + Ok(Self { + qkv, + proj, + num_heads, + scale, + rel_pos_hw, + span, + span_matmul, + span_rel_pos, + span_softmax, + }) + } + + fn add_decomposed_rel_pos( + &self, + attn: Tensor, + q: &Tensor, + (q_h, q_w): (usize, usize), + (k_h, k_w): (usize, usize), + ) -> Result<Tensor> { + match &self.rel_pos_hw { + Some((rel_pos_h, rel_pos_w)) => { + let r_h = get_rel_pos(q_h, k_h, rel_pos_h)?; + let r_w = get_rel_pos(q_w, k_w, rel_pos_w)?; + let (b, _, dim) = q.dims3()?; + let r_q = q.reshape((b, q_h, q_w, dim))?; + // rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + let rel_h = r_q.matmul(&r_h.broadcast_left(b)?.t()?.contiguous()?)?; + // rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + let rel_w = r_q + .transpose(1, 2)? // -> bwhc + .contiguous()? + .matmul(&r_w.broadcast_left(b)?.t()?.contiguous()?)? // bwhc,bwck -> bwhk + .transpose(1, 2)? + .contiguous()?; + if attn.device().is_cpu() { + let op = Add3(b, q_h, q_w, k_h, k_w); + attn.apply_op3_no_bwd(&rel_h, &rel_w, &op) + } else { + (attn.reshape((b, q_h, q_w, k_h, k_w))? + + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)? + .reshape((b, q_h * q_w, k_h * k_w)) + } + } + None => Ok(attn), + } + } +} + +fn get_rel_pos(q_size: usize, k_size: usize, rel_pos: &Tensor) -> Result<Tensor> { + let max_rel_dist = 2 * usize::max(q_size, k_size) - 1; + let dev = rel_pos.device(); + let rel_pos_resized = if rel_pos.dim(0)? != max_rel_dist { + todo!("interpolation") + } else { + rel_pos + }; + let q_coords = Tensor::arange(0u32, q_size as u32, dev)? + .reshape((q_size, 1))? + .to_dtype(DType::F32)?; + let k_coords = Tensor::arange(0u32, k_size as u32, dev)? + .reshape((1, k_size))? + .to_dtype(DType::F32)?; + let q_coords = (q_coords * f64::max(1f64, k_size as f64 / q_size as f64))?; + let k_coords = (k_coords * f64::max(1f64, q_size as f64 / k_size as f64))?; + let relative_coords = (q_coords.broadcast_sub(&k_coords)? + + (k_size as f64 - 1.) * f64::max(1f64, q_size as f64 / k_size as f64))?; + let (d1, d2) = relative_coords.dims2()?; + let relative_coords = relative_coords.to_dtype(DType::U32)?; + rel_pos_resized + .index_select(&relative_coords.reshape(d1 * d2)?, 0)? + .reshape((d1, d2, ())) +} + +impl Module for Attention { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let (b, h, w, c) = xs.dims4()?; + let qkv = self + .qkv + .forward(&xs.flatten_to(1)?)? + .reshape((b, h * w, 3, self.num_heads, c / self.num_heads))? + .permute((2, 0, 3, 1, 4))? + .reshape((3, b * self.num_heads, h * w, c / self.num_heads))?; + let q = qkv.i(0)?; + let k = qkv.i(1)?; + let v = qkv.i(2)?; + let attn = { + let _enter = self.span_matmul.enter(); + (&q * self.scale)?.matmul(&k.t()?)? + }; + let attn = { + let _enter = self.span_rel_pos.enter(); + self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))? + }; + let attn = { + let _enter = self.span_softmax.enter(); + candle_nn::ops::softmax_last_dim(&attn)? + }; + let attn = { + let _enter = self.span_matmul.enter(); + attn.matmul(&v)? + }; + let attn = attn + .reshape((b, self.num_heads, h, w, c / self.num_heads))? + .permute((0, 2, 3, 1, 4))? 
+ .reshape((b, h * w, c))?; + self.proj.forward(&attn)?.reshape((b, h, w, c)) + } +} + +#[derive(Debug)] +struct Block { + norm1: LayerNorm, + attn: Attention, + norm2: LayerNorm, + mlp: super::MlpBlock, + window_size: usize, + span: tracing::Span, +} + +impl Block { + fn new( + dim: usize, + num_heads: usize, + qkv_bias: bool, + use_rel_pos: bool, + window_size: usize, + input_size: (usize, usize), + vb: VarBuilder, + ) -> Result<Self> { + let norm1 = layer_norm(dim, 1e-6, vb.pp("norm1"))?; + let norm2 = layer_norm(dim, 1e-6, vb.pp("norm2"))?; + let input_size_attn = if window_size == 0 { + input_size + } else { + (window_size, window_size) + }; + let attn = Attention::new( + dim, + num_heads, + qkv_bias, + use_rel_pos, + input_size_attn, + vb.pp("attn"), + )?; + let mlp = super::MlpBlock::new(dim, dim * 4, candle_nn::Activation::Gelu, vb.pp("mlp"))?; + let span = tracing::span!(tracing::Level::TRACE, "ie-block"); + Ok(Self { + norm1, + attn, + norm2, + mlp, + window_size, + span, + }) + } +} + +fn window_partition(xs: Tensor, window_size: usize) -> Result<(Tensor, (usize, usize))> { + let (b, h, w, c) = xs.dims4()?; + let pad_h = (window_size - h % window_size) % window_size; + let pad_w = (window_size - w % window_size) % window_size; + let xs = if pad_h > 0 { + xs.pad_with_zeros(1, 0, pad_h)? + } else { + xs + }; + let xs = if pad_w > 0 { + xs.pad_with_zeros(2, 0, pad_w)? + } else { + xs + }; + let (h_p, w_p) = (h + pad_h, w + pad_w); + let windows = xs + .reshape(( + b, + h_p / window_size, + window_size, + w_p / window_size, + window_size, + c, + ))? + .transpose(2, 3)? + .contiguous()? + .flatten_to(2)?; + Ok((windows, (h_p, w_p))) +} + +fn window_unpartition( + windows: Tensor, + window_size: usize, + (h_p, w_p): (usize, usize), + (h, w): (usize, usize), +) -> Result<Tensor> { + let b = windows.dim(0)? / (h_p * w_p / window_size / window_size); + let xs = windows + .reshape(( + b, + h_p / window_size, + w_p / window_size, + window_size, + window_size, + windows.elem_count() / b / h_p / w_p, + ))? + .transpose(2, 3)? + .contiguous()? + .reshape((b, h_p, w_p, ()))?; + let xs = if h_p > h { xs.narrow(1, 0, h)? } else { xs }; + let xs = if w_p > w { xs.narrow(2, 0, w)? } else { xs }; + Ok(xs) +} + +impl Module for Block { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let shortcut = xs; + let xs = self.norm1.forward(xs)?; + let hw = (xs.dim(1)?, xs.dim(2)?); + let (xs, pad_hw) = if self.window_size > 0 { + window_partition(xs, self.window_size)? + } else { + (xs, (0, 0)) + }; + let xs = self.attn.forward(&xs)?; + let xs = if self.window_size > 0 { + window_unpartition(xs, self.window_size, pad_hw, hw)? + } else { + xs + }; + let xs = (xs + shortcut)?; + &xs + xs.apply(&self.norm2)?.apply(&self.mlp)? 
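+ // Editor note: second residual branch of the block, mirroring the attention
+ // residual above: the output is xs + mlp(norm2(xs)).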
+ } +} + +#[derive(Debug)] +pub struct ImageEncoderViT { + patch_embed: PatchEmbed, + blocks: Vec<Block>, + neck_conv1: candle_nn::Conv2d, + neck_ln1: super::LayerNorm2d, + neck_conv2: candle_nn::Conv2d, + neck_ln2: super::LayerNorm2d, + pos_embed: Option<Tensor>, + span: tracing::Span, +} + +impl ImageEncoderViT { + #[allow(clippy::too_many_arguments)] + pub fn new( + img_size: usize, + patch_size: usize, + in_chans: usize, + embed_dim: usize, + depth: usize, + num_heads: usize, + out_chans: usize, + qkv_bias: bool, + use_rel_pos: bool, + use_abs_pos: bool, + window_size: usize, + global_attn_indexes: &[usize], + vb: VarBuilder, + ) -> Result<Self> { + let patch_embed = PatchEmbed::new( + in_chans, + embed_dim, + patch_size, + patch_size, + 0, + vb.pp("patch_embed"), + )?; + let mut blocks = Vec::with_capacity(depth); + let vb_b = vb.pp("blocks"); + for i in 0..depth { + let window_size = if global_attn_indexes.contains(&i) { + 0 + } else { + window_size + }; + let block = Block::new( + embed_dim, + num_heads, + qkv_bias, + use_rel_pos, + window_size, + (img_size / patch_size, img_size / patch_size), + vb_b.pp(i), + )?; + blocks.push(block) + } + let neck_conv1 = candle_nn::conv2d_no_bias( + embed_dim, + out_chans, + 1, + Default::default(), + vb.pp("neck.0"), + )?; + let neck_ln1 = super::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.1"))?; + let cfg = candle_nn::Conv2dConfig { + padding: 1, + ..Default::default() + }; + let neck_conv2 = candle_nn::conv2d_no_bias(out_chans, out_chans, 3, cfg, vb.pp("neck.2"))?; + let neck_ln2 = super::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.3"))?; + let pos_embed = if use_abs_pos { + let p = vb.get( + (1, img_size / patch_size, img_size / patch_size, embed_dim), + "pos_embed", + )?; + Some(p) + } else { + None + }; + let span = tracing::span!(tracing::Level::TRACE, "image-encoder-vit"); + Ok(Self { + patch_embed, + blocks, + neck_conv1, + neck_ln1, + neck_conv2, + neck_ln2, + pos_embed, + span, + }) + } +} + +impl Module for ImageEncoderViT { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let xs = self.patch_embed.forward(xs)?; + let mut xs = match &self.pos_embed { + Some(pos_embed) => (xs + pos_embed)?, + None => xs, + }; + for block in self.blocks.iter() { + xs = block.forward(&xs)? + } + xs.permute((0, 3, 1, 2))? + .apply(&self.neck_conv1)? + .apply(&self.neck_ln1)? + .apply(&self.neck_conv2)? 
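+ // Editor note: the neck maps the ViT features to out_chans with a 1x1 conv,
+ // a LayerNorm2d, a 3x3 conv (padding 1), and the final LayerNorm2d applied below.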
+ .apply(&self.neck_ln2) + } +} diff --git a/candle-transformers/src/models/segment_anything/mask_decoder.rs b/candle-transformers/src/models/segment_anything/mask_decoder.rs new file mode 100644 index 00000000..2a91cd44 --- /dev/null +++ b/candle-transformers/src/models/segment_anything/mask_decoder.rs @@ -0,0 +1,239 @@ +use candle::{IndexOp, Result, Tensor}; +use candle_nn::{Module, VarBuilder}; + +use super::transformer::TwoWayTransformer; + +#[derive(Debug)] +struct MlpMaskDecoder { + layers: Vec<super::Linear>, + sigmoid_output: bool, + span: tracing::Span, +} + +impl MlpMaskDecoder { + fn new( + input_dim: usize, + hidden_dim: usize, + output_dim: usize, + num_layers: usize, + sigmoid_output: bool, + vb: VarBuilder, + ) -> Result<Self> { + let mut layers = Vec::with_capacity(num_layers); + let vb = vb.pp("layers"); + for i in 0..num_layers { + let in_dim = if i == 0 { input_dim } else { hidden_dim }; + let out_dim = if i + 1 == num_layers { + output_dim + } else { + hidden_dim + }; + let layer = super::linear(vb.pp(i), in_dim, out_dim, true)?; + layers.push(layer) + } + let span = tracing::span!(tracing::Level::TRACE, "mlp-mask-decoder"); + Ok(Self { + layers, + sigmoid_output, + span, + }) + } +} + +impl Module for MlpMaskDecoder { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let mut xs = xs.clone(); + for (i, layer) in self.layers.iter().enumerate() { + xs = layer.forward(&xs)?; + if i + 1 < self.layers.len() { + xs = xs.relu()? + } + } + if self.sigmoid_output { + candle_nn::ops::sigmoid(&xs) + } else { + Ok(xs) + } + } +} + +#[derive(Debug)] +pub struct MaskDecoder { + iou_token: candle_nn::Embedding, + mask_tokens: candle_nn::Embedding, + iou_prediction_head: MlpMaskDecoder, + output_upscaling_conv1: candle_nn::ConvTranspose2d, + output_upscaling_ln: super::LayerNorm2d, + output_upscaling_conv2: candle_nn::ConvTranspose2d, + num_mask_tokens: usize, + output_hypernetworks_mlps: Vec<MlpMaskDecoder>, + transformer: TwoWayTransformer, + span: tracing::Span, +} + +impl MaskDecoder { + pub fn new( + transformer_dim: usize, + num_multimask_outputs: usize, + iou_head_depth: usize, + iou_head_hidden_dim: usize, + vb: VarBuilder, + ) -> Result<Self> { + let num_mask_tokens = num_multimask_outputs + 1; + let iou_prediction_head = MlpMaskDecoder::new( + transformer_dim, + iou_head_hidden_dim, + num_mask_tokens, + iou_head_depth, + false, + vb.pp("iou_prediction_head"), + )?; + let iou_token = candle_nn::embedding(1, transformer_dim, vb.pp("iou_token"))?; + let mask_tokens = + candle_nn::embedding(num_mask_tokens, transformer_dim, vb.pp("mask_tokens"))?; + let cfg = candle_nn::ConvTranspose2dConfig { + stride: 2, + ..Default::default() + }; + let output_upscaling_conv1 = candle_nn::conv_transpose2d( + transformer_dim, + transformer_dim / 4, + 2, + cfg, + vb.pp("output_upscaling.0"), + )?; + let output_upscaling_ln = + super::LayerNorm2d::new(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?; + let output_upscaling_conv2 = candle_nn::conv_transpose2d( + transformer_dim / 4, + transformer_dim / 8, + 2, + cfg, + vb.pp("output_upscaling.3"), + )?; + let mut output_hypernetworks_mlps = Vec::with_capacity(num_mask_tokens); + let vb_o = vb.pp("output_hypernetworks_mlps"); + for i in 0..num_mask_tokens { + let mlp = MlpMaskDecoder::new( + transformer_dim, + transformer_dim, + transformer_dim / 8, + 3, + false, + vb_o.pp(i), + )?; + output_hypernetworks_mlps.push(mlp) + } + let transformer = TwoWayTransformer::new( + /* depth */ 2, + /* embedding_dim 
*/ transformer_dim, + /* num_heads */ 8, + /* mlp_dim */ 2048, + vb.pp("transformer"), + )?; + let span = tracing::span!(tracing::Level::TRACE, "mask-decoder"); + Ok(Self { + iou_token, + mask_tokens, + iou_prediction_head, + output_upscaling_conv1, + output_upscaling_ln, + output_upscaling_conv2, + num_mask_tokens, + output_hypernetworks_mlps, + transformer, + span, + }) + } + + pub fn forward( + &self, + image_embeddings: &Tensor, + image_pe: &Tensor, + sparse_prompt_embeddings: &Tensor, + dense_prompt_embeddings: &Tensor, + multimask_output: bool, + ) -> Result<(Tensor, Tensor)> { + let _enter = self.span.enter(); + let (masks, iou_pred) = self.predict_masks( + image_embeddings, + image_pe, + sparse_prompt_embeddings, + dense_prompt_embeddings, + )?; + let masks = if multimask_output { + masks.i((.., 1..))? + } else { + masks.i((.., 0..1))? + }; + let iou_pred = if multimask_output { + iou_pred.i((.., 1..))? + } else { + iou_pred.i((.., 0..1))? + }; + Ok((masks, iou_pred)) + } + + fn predict_masks( + &self, + image_embeddings: &Tensor, + image_pe: &Tensor, + sparse_prompt_embeddings: &Tensor, + dense_prompt_embeddings: &Tensor, + ) -> Result<(Tensor, Tensor)> { + // Concatenate ouput tokens. + let output_tokens = Tensor::cat( + &[self.iou_token.embeddings(), self.mask_tokens.embeddings()], + 0, + )?; + let (d1, d2) = output_tokens.dims2()?; + let output_tokens = + output_tokens + .unsqueeze(0)? + .expand((sparse_prompt_embeddings.dim(0)?, d1, d2))?; + let tokens = Tensor::cat(&[&output_tokens, sparse_prompt_embeddings], 1)?; + + // Expand per-image data in batch direction to be per mask + let src = repeat_interleave(image_embeddings, tokens.dim(0)?, 0)?; + let src = src.broadcast_add(dense_prompt_embeddings)?; + let pos_src = repeat_interleave(image_pe, tokens.dim(0)?, 0)?; + let (b, c, h, w) = src.dims4()?; + + // Run the transformer + let (hs, src) = self.transformer.forward(&src, &pos_src, &tokens)?; + let iou_token_out = hs.i((.., 0))?; + let mask_tokens_out = hs.i((.., 1..1 + self.num_mask_tokens))?; + + // Upscale mask embeddings and predict masks using the masks tokens. + let src = src.transpose(1, 2)?.reshape((b, c, h, w))?; + let upscaled_embedding = self + .output_upscaling_conv1 + .forward(&src)? + .apply(&self.output_upscaling_ln)? + .gelu()? + .apply(&self.output_upscaling_conv2)? + .gelu()?; + let mut hyper_in_list = Vec::with_capacity(self.num_mask_tokens); + for (i, mlp) in self.output_hypernetworks_mlps.iter().enumerate() { + let h = mlp.forward(&mask_tokens_out.i((.., i))?)?; + hyper_in_list.push(h) + } + let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?.contiguous()?; + let (b, c, h, w) = upscaled_embedding.dims4()?; + let masks = hyper_in.matmul(&upscaled_embedding.reshape((b, c, h * w))?)?; + let masks = masks.reshape((b, (), h, w))?; + + // Generate mask quality predictions. 
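+ // Editor note: iou_prediction_head is the small MLP built in MaskDecoder::new;
+ // it maps the IoU token embedding to num_mask_tokens scalar quality scores,
+ // one per predicted mask.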
+ let iou_pred = self.iou_prediction_head.forward(&iou_token_out)?; + Ok((masks, iou_pred)) + } +} + +// Equivalent to torch.repeat_interleave +fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> { + let img = img.unsqueeze(dim + 1)?; + let mut dims = img.dims().to_vec(); + dims[dim + 1] = repeats; + img.broadcast_as(dims)?.flatten(dim, dim + 1) +} diff --git a/candle-transformers/src/models/segment_anything/mod.rs b/candle-transformers/src/models/segment_anything/mod.rs new file mode 100644 index 00000000..c29db70a --- /dev/null +++ b/candle-transformers/src/models/segment_anything/mod.rs @@ -0,0 +1,100 @@ +use candle::{Result, Tensor}; +use candle_nn::{Module, VarBuilder}; + +pub mod image_encoder; +pub mod mask_decoder; +pub mod prompt_encoder; +pub mod sam; +pub mod tiny_vit; +pub mod transformer; + +pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> { + let inner = if bias { + candle_nn::linear(in_dim, out_dim, vb)? + } else { + candle_nn::linear_no_bias(in_dim, out_dim, vb)? + }; + let span = tracing::span!(tracing::Level::TRACE, "linear"); + Ok(Linear { inner, span }) +} + +#[derive(Debug)] +pub struct LayerNorm2d { + weight: Tensor, + bias: Tensor, + num_channels: usize, + eps: f64, +} + +impl LayerNorm2d { + pub fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result<Self> { + let weight = vb.get(num_channels, "weight")?; + let bias = vb.get(num_channels, "bias")?; + Ok(Self { + weight, + bias, + num_channels, + eps, + }) + } +} + +impl Module for LayerNorm2d { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let u = xs.mean_keepdim(1)?; + let xs = xs.broadcast_sub(&u)?; + let s = xs.sqr()?.mean_keepdim(1)?; + let xs = xs.broadcast_div(&(s + self.eps)?.sqrt()?)?; + xs.broadcast_mul(&self.weight.reshape((1, self.num_channels, 1, 1))?)? + .broadcast_add(&self.bias.reshape((1, self.num_channels, 1, 1))?) + } +} + +#[derive(Debug)] +pub struct MlpBlock { + lin1: Linear, + lin2: Linear, + activation: candle_nn::Activation, + span: tracing::Span, +} + +impl MlpBlock { + pub fn new( + embedding_dim: usize, + mlp_dim: usize, + activation: candle_nn::Activation, + vb: VarBuilder, + ) -> Result<Self> { + let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?; + let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?; + let span = tracing::span!(tracing::Level::TRACE, "mlp-block"); + Ok(Self { + lin1, + lin2, + activation, + span, + }) + } +} + +impl Module for MlpBlock { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + xs.apply(&self.lin1)? + .apply(&self.activation)? 
+ .apply(&self.lin2) + } +} + +#[derive(Debug)] +pub struct Linear { + inner: candle_nn::Linear, + span: tracing::Span, +} + +impl Module for Linear { + fn forward(&self, x: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(x) + } +} diff --git a/candle-transformers/src/models/segment_anything/prompt_encoder.rs b/candle-transformers/src/models/segment_anything/prompt_encoder.rs new file mode 100644 index 00000000..9d0074b1 --- /dev/null +++ b/candle-transformers/src/models/segment_anything/prompt_encoder.rs @@ -0,0 +1,239 @@ +use candle::{DType, IndexOp, Result, Tensor, D}; +use candle_nn::VarBuilder; + +#[derive(Debug)] +struct PostionEmbeddingRandom { + positional_encoding_gaussian_matrix: Tensor, +} + +impl PostionEmbeddingRandom { + fn new(num_pos_feats: usize, vb: VarBuilder) -> Result<Self> { + let positional_encoding_gaussian_matrix = + vb.get((2, num_pos_feats), "positional_encoding_gaussian_matrix")?; + Ok(Self { + positional_encoding_gaussian_matrix, + }) + } + + fn pe_encoding(&self, coords: &Tensor) -> Result<Tensor> { + let coords = coords.affine(2., -1.)?; + let coords = coords.broadcast_matmul(&self.positional_encoding_gaussian_matrix)?; + let coords = (coords * (2. * std::f64::consts::PI))?; + Tensor::cat(&[coords.sin()?, coords.cos()?], D::Minus1) + } + + fn forward(&self, h: usize, w: usize) -> Result<Tensor> { + let device = self.positional_encoding_gaussian_matrix.device(); + let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?; + let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?; + let x_embed = (x_embed / w as f64)? + .reshape((1, ()))? + .broadcast_as((h, w))?; + let y_embed = (y_embed / h as f64)? + .reshape(((), 1))? + .broadcast_as((h, w))?; + let coords = Tensor::stack(&[&x_embed, &y_embed], D::Minus1)?; + self.pe_encoding(&coords)?.permute((2, 0, 1)) + } + + fn forward_with_coords( + &self, + coords_input: &Tensor, + image_size: (usize, usize), + ) -> Result<Tensor> { + let coords0 = (coords_input.narrow(D::Minus1, 0, 1)? / image_size.1 as f64)?; + let coords1 = (coords_input.narrow(D::Minus1, 1, 1)? 
/ image_size.0 as f64)?; + let c = coords_input.dim(D::Minus1)?; + let coords_rest = coords_input.narrow(D::Minus1, 2, c - 2)?; + let coords = Tensor::cat(&[&coords0, &coords1, &coords_rest], D::Minus1)?; + self.pe_encoding(&coords) + } +} + +#[derive(Debug)] +pub struct PromptEncoder { + pe_layer: PostionEmbeddingRandom, + point_embeddings: Vec<candle_nn::Embedding>, + not_a_point_embed: candle_nn::Embedding, + mask_downscaling_conv1: candle_nn::Conv2d, + mask_downscaling_ln1: super::LayerNorm2d, + mask_downscaling_conv2: candle_nn::Conv2d, + mask_downscaling_ln2: super::LayerNorm2d, + mask_downscaling_conv3: candle_nn::Conv2d, + no_mask_embed: candle_nn::Embedding, + image_embedding_size: (usize, usize), + input_image_size: (usize, usize), + embed_dim: usize, + span: tracing::Span, +} + +impl PromptEncoder { + pub fn new( + embed_dim: usize, + image_embedding_size: (usize, usize), + input_image_size: (usize, usize), + mask_in_chans: usize, + vb: VarBuilder, + ) -> Result<Self> { + let num_points_embeddings = 4; + let pe_layer = PostionEmbeddingRandom::new(embed_dim / 2, vb.pp("pe_layer"))?; + let not_a_point_embed = candle_nn::embedding(1, embed_dim, vb.pp("not_a_point_embed"))?; + let no_mask_embed = candle_nn::embedding(1, embed_dim, vb.pp("no_mask_embed"))?; + let cfg = candle_nn::Conv2dConfig { + stride: 2, + ..Default::default() + }; + let mask_downscaling_conv1 = + candle_nn::conv2d(1, mask_in_chans / 4, 2, cfg, vb.pp("mask_downscaling.0"))?; + let mask_downscaling_conv2 = candle_nn::conv2d( + mask_in_chans / 4, + mask_in_chans, + 2, + cfg, + vb.pp("mask_downscaling.3"), + )?; + let mask_downscaling_conv3 = candle_nn::conv2d( + mask_in_chans, + embed_dim, + 1, + Default::default(), + vb.pp("mask_downscaling.6"), + )?; + let mask_downscaling_ln1 = + super::LayerNorm2d::new(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?; + let mask_downscaling_ln2 = + super::LayerNorm2d::new(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?; + let mut point_embeddings = Vec::with_capacity(num_points_embeddings); + let vb_e = vb.pp("point_embeddings"); + for i in 0..num_points_embeddings { + let emb = candle_nn::embedding(1, embed_dim, vb_e.pp(i))?; + point_embeddings.push(emb) + } + let span = tracing::span!(tracing::Level::TRACE, "prompt-encoder"); + Ok(Self { + pe_layer, + point_embeddings, + not_a_point_embed, + mask_downscaling_conv1, + mask_downscaling_ln1, + mask_downscaling_conv2, + mask_downscaling_ln2, + mask_downscaling_conv3, + no_mask_embed, + image_embedding_size, + input_image_size, + embed_dim, + span, + }) + } + + pub fn get_dense_pe(&self) -> Result<Tensor> { + self.pe_layer + .forward(self.image_embedding_size.0, self.image_embedding_size.1)? + .unsqueeze(0) + } + + fn embed_masks(&self, masks: &Tensor) -> Result<Tensor> { + masks + .apply(&self.mask_downscaling_conv1)? + .apply(&self.mask_downscaling_ln1)? + .gelu()? + .apply(&self.mask_downscaling_conv2)? + .apply(&self.mask_downscaling_ln2)? + .gelu()? + .apply(&self.mask_downscaling_conv3) + } + + fn embed_points(&self, points: &Tensor, labels: &Tensor, pad: bool) -> Result<Tensor> { + let points = (points + 0.5)?; + let dev = points.device(); + let (points, labels) = if pad { + let padding_point = Tensor::zeros((points.dim(0)?, 1, 2), DType::F32, dev)?; + let padding_label = (Tensor::ones((labels.dim(0)?, 1), DType::F32, dev)? 
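+ // Editor note: padding points are given label -1; embed_points maps any
+ // negative label to not_a_point_embed further down.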
* (-1f64))?; + let points = Tensor::cat(&[&points, &padding_point], 1)?; + let labels = Tensor::cat(&[labels, &padding_label], 1)?; + (points, labels) + } else { + (points, labels.clone()) + }; + let point_embedding = self + .pe_layer + .forward_with_coords(&points, self.input_image_size)?; + let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?; + let zeros = point_embedding.zeros_like()?; + let point_embedding = labels.lt(0f32)?.where_cond( + &self + .not_a_point_embed + .embeddings() + .broadcast_as(zeros.shape())?, + &point_embedding, + )?; + let labels0 = labels.eq(0f32)?.where_cond( + &self.point_embeddings[0] + .embeddings() + .broadcast_as(zeros.shape())?, + &zeros, + )?; + let point_embedding = (point_embedding + labels0)?; + let labels1 = labels.eq(1f32)?.where_cond( + &self.point_embeddings[1] + .embeddings() + .broadcast_as(zeros.shape())?, + &zeros, + )?; + let point_embedding = (point_embedding + labels1)?; + Ok(point_embedding) + } + + fn embed_boxes(&self, boxes: &Tensor) -> Result<Tensor> { + let boxes = (boxes + 0.5)?; + let coords = boxes.reshape(((), 2, 2))?; + let corner_embedding = self + .pe_layer + .forward_with_coords(&coords, self.input_image_size)?; + let ce1 = corner_embedding.i((.., 0))?; + let ce2 = corner_embedding.i((.., 1))?; + let ce1 = (ce1 + self.point_embeddings[2].embeddings())?; + let ce2 = (ce2 + self.point_embeddings[3].embeddings())?; + Tensor::cat(&[&ce1, &ce2], 1) + } + + pub fn forward( + &self, + points: Option<(&Tensor, &Tensor)>, + boxes: Option<&Tensor>, + masks: Option<&Tensor>, + ) -> Result<(Tensor, Tensor)> { + let _enter = self.span.enter(); + let se_points = match points { + Some((coords, labels)) => Some(self.embed_points(coords, labels, boxes.is_none())?), + None => None, + }; + let se_boxes = match boxes { + Some(boxes) => Some(self.embed_boxes(boxes)?), + None => None, + }; + let sparse_embeddings = match (se_points, se_boxes) { + (Some(se_points), Some(se_boxes)) => Tensor::cat(&[se_points, se_boxes], 1)?, + (Some(se_points), None) => se_points, + (None, Some(se_boxes)) => se_boxes, + (None, None) => { + Tensor::zeros((1, 0, self.embed_dim), DType::F32, &candle::Device::Cpu)? + } + }; + + let dense_embeddings = match masks { + None => { + let emb = self.no_mask_embed.embeddings(); + emb.reshape((1, (), 1, 1))?.expand(( + 1, + emb.elem_count(), + self.image_embedding_size.0, + self.image_embedding_size.1, + ))? 
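+ // Editor note: with no mask prompt, the learned no_mask embedding is broadcast
+ // over the whole (1, embed_dim, image_embedding_size.0, image_embedding_size.1)
+ // grid so dense_embeddings always has the spatial shape of the image embedding.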
+ } + Some(masks) => self.embed_masks(masks)?, + }; + Ok((sparse_embeddings, dense_embeddings)) + } +} diff --git a/candle-transformers/src/models/segment_anything/sam.rs b/candle-transformers/src/models/segment_anything/sam.rs new file mode 100644 index 00000000..07e9a759 --- /dev/null +++ b/candle-transformers/src/models/segment_anything/sam.rs @@ -0,0 +1,433 @@ +use candle::{DType, IndexOp, Result, Tensor}; +use candle_nn::{Module, VarBuilder}; + +use super::image_encoder::ImageEncoderViT; +use super::mask_decoder::MaskDecoder; +use super::prompt_encoder::PromptEncoder; +use super::tiny_vit::{tiny_vit_5m, TinyViT}; + +const PROMPT_EMBED_DIM: usize = 256; +pub const IMAGE_SIZE: usize = 1024; +const VIT_PATCH_SIZE: usize = 16; +const PRED_IOU_THRESH: f32 = 0.88; +const STABILITY_SCORE_OFFSET: f32 = 1.0; +const STABILITY_SCORE_THRESHOLD: f32 = 0.95; +const MODEL_MASK_THRESHOLD: f32 = 0.0; +const CROP_NMS_THRESH: f32 = 0.7; + +#[derive(Debug)] +enum ImageEncoder { + Original(ImageEncoderViT), + TinyViT(TinyViT), +} + +impl Module for ImageEncoder { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + match self { + Self::Original(vit) => vit.forward(xs), + Self::TinyViT(vit) => vit.forward(xs), + } + } +} + +#[derive(Debug)] +pub struct Sam { + image_encoder: ImageEncoder, + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + pixel_mean: Tensor, + pixel_std: Tensor, +} + +impl Sam { + pub fn new( + encoder_embed_dim: usize, + encoder_depth: usize, + encoder_num_heads: usize, + encoder_global_attn_indexes: &[usize], + vb: VarBuilder, + ) -> Result<Self> { + let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE; + + let image_encoder = ImageEncoderViT::new( + IMAGE_SIZE, + VIT_PATCH_SIZE, + 3, + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + PROMPT_EMBED_DIM, + /* qkv_bias */ true, + /* use_rel_pos */ true, + /* use_abs_pos */ true, + /* window_size */ 14, + /* global_attn_indexes */ encoder_global_attn_indexes, + vb.pp("image_encoder"), + )?; + let prompt_encoder = PromptEncoder::new( + PROMPT_EMBED_DIM, + (image_embedding_size, image_embedding_size), + (IMAGE_SIZE, IMAGE_SIZE), + 16, + vb.pp("prompt_encoder"), + )?; + let mask_decoder = MaskDecoder::new( + PROMPT_EMBED_DIM, + /* num_multitask_outputs */ 3, + /* iou_head_depth */ 3, + /* iou_head_hidden_dim */ 256, + vb.pp("mask_decoder"), + )?; + let pixel_mean = + Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?; + let pixel_std = + Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?; + Ok(Self { + image_encoder: ImageEncoder::Original(image_encoder), + prompt_encoder, + mask_decoder, + pixel_std, + pixel_mean, + }) + } + + pub fn new_tiny(vb: VarBuilder) -> Result<Self> { + let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE; + + let image_encoder = tiny_vit_5m(vb.pp("image_encoder"))?; + let prompt_encoder = PromptEncoder::new( + PROMPT_EMBED_DIM, + (image_embedding_size, image_embedding_size), + (IMAGE_SIZE, IMAGE_SIZE), + 16, + vb.pp("prompt_encoder"), + )?; + let mask_decoder = MaskDecoder::new( + PROMPT_EMBED_DIM, + /* num_multitask_outputs */ 3, + /* iou_head_depth */ 3, + /* iou_head_hidden_dim */ 256, + vb.pp("mask_decoder"), + )?; + let pixel_mean = + Tensor::new(&[123.675f32, 116.28, 103.53], vb.device())?.reshape((3, 1, 1))?; + let pixel_std = + Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?; + Ok(Self { + image_encoder: ImageEncoder::TinyViT(image_encoder), + prompt_encoder, + mask_decoder, + pixel_std, + pixel_mean, + }) 
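+ // Editor note: pixel_mean / pixel_std above match the usual ImageNet channel
+ // statistics scaled to the 0..255 range; preprocess() subtracts and divides by
+ // them before padding the image up to IMAGE_SIZE.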
+ } + + pub fn embeddings(&self, img: &Tensor) -> Result<Tensor> { + let img = self.preprocess(img)?.unsqueeze(0)?; + self.image_encoder.forward(&img) + } + + pub fn forward( + &self, + img: &Tensor, + point: Option<(f64, f64)>, + multimask_output: bool, + ) -> Result<(Tensor, Tensor)> { + let (_c, original_h, original_w) = img.dims3()?; + let img = self.preprocess(img)?.unsqueeze(0)?; + let img_embeddings = self.image_encoder.forward(&img)?; + let (low_res_mask, iou) = self.forward_for_embeddings( + &img_embeddings, + original_h, + original_w, + point, + multimask_output, + )?; + let mask = low_res_mask + .upsample_nearest2d(IMAGE_SIZE, IMAGE_SIZE)? + .get(0)? + .i((.., ..original_h, ..original_w))?; + Ok((mask, iou)) + } + + pub fn forward_for_embeddings( + &self, + img_embeddings: &Tensor, + original_h: usize, + original_w: usize, + point: Option<(f64, f64)>, + multimask_output: bool, + ) -> Result<(Tensor, Tensor)> { + let image_pe = self.prompt_encoder.get_dense_pe()?; + let points = match point { + None => None, + Some((x, y)) => { + let points = Tensor::new( + &[[[x as f32 * original_w as f32, y as f32 * original_h as f32]]], + img_embeddings.device(), + )?; + let labels = Tensor::ones((1, 1), DType::F32, img_embeddings.device())?; + Some((points, labels)) + } + }; + let points = points.as_ref().map(|(x, y)| (x, y)); + let (sparse_prompt_embeddings, dense_prompt_embeddings) = + self.prompt_encoder.forward(points, None, None)?; + self.mask_decoder.forward( + img_embeddings, + &image_pe, + &sparse_prompt_embeddings, + &dense_prompt_embeddings, + multimask_output, + ) + } + + pub fn unpreprocess(&self, img: &Tensor) -> Result<Tensor> { + let img = img + .broadcast_mul(&self.pixel_std)? + .broadcast_add(&self.pixel_mean)?; + img.maximum(&img.zeros_like()?)? + .minimum(&(img.ones_like()? * 255.)?) + } + + pub fn preprocess(&self, img: &Tensor) -> Result<Tensor> { + let (_c, h, w) = img.dims3()?; + let img = img + .to_dtype(DType::F32)? + .broadcast_sub(&self.pixel_mean)? + .broadcast_div(&self.pixel_std)?; + if h > IMAGE_SIZE || w > IMAGE_SIZE { + candle::bail!("image is too large ({w}, {h}), maximum size {IMAGE_SIZE}") + } + let img = img.pad_with_zeros(1, 0, IMAGE_SIZE - h)?; + img.pad_with_zeros(2, 0, IMAGE_SIZE - w) + } + + fn process_crop( + &self, + img: &Tensor, + cb: CropBox, + point_grids: &[(f64, f64)], + ) -> Result<Vec<crate::object_detection::Bbox<Tensor>>> { + // Crop the image and calculate embeddings. + let img = img.i((.., cb.y0..cb.y1, cb.x0..cb.x1))?; + let img = self.preprocess(&img)?.unsqueeze(0)?; + let img_embeddings = self.image_encoder.forward(&img)?; + + let crop_w = cb.x1 - cb.x0; + let crop_h = cb.y1 - cb.y0; + + // Generate masks for this crop. + let image_pe = self.prompt_encoder.get_dense_pe()?; + let points = point_grids + .iter() + .map(|&(x, y)| vec![x as f32 * crop_w as f32, y as f32 * crop_h as f32]) + .collect::<Vec<_>>(); + + let mut bboxes = Vec::new(); + for points in points.chunks(64) { + // Run the model on this batch. 
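+ // Editor note: prompts are processed in chunks of 64 points; each point acts as
+ // a separate prompt, and the multimask outputs are filtered below by predicted
+ // IoU and by the stability score.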
+ let points_len = points.len(); + let in_points = Tensor::new(points.to_vec(), img.device())?.unsqueeze(1)?; + let in_labels = Tensor::ones((points_len, 1), DType::F32, img.device())?; + let (sparse_prompt_embeddings, dense_prompt_embeddings) = + self.prompt_encoder + .forward(Some((&in_points, &in_labels)), None, None)?; + + let (low_res_mask, iou_predictions) = self.mask_decoder.forward( + &img_embeddings, + &image_pe, + &sparse_prompt_embeddings, + &dense_prompt_embeddings, + /* multimask_output */ true, + )?; + let low_res_mask = low_res_mask.flatten(0, 1)?; + let iou_predictions = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?; + let dev = low_res_mask.device(); + + for (i, iou) in iou_predictions.iter().enumerate() { + // Filter by predicted IoU. + if *iou < PRED_IOU_THRESH { + continue; + } + let low_res_mask = low_res_mask.get(i)?; + + // Calculate stability score. + let bound = Tensor::new(MODEL_MASK_THRESHOLD + STABILITY_SCORE_OFFSET, dev)? + .broadcast_as(low_res_mask.shape())?; + let intersections = low_res_mask + .ge(&bound)? + .to_dtype(DType::F32)? + .sum_all()? + .to_vec0::<f32>()?; + let bound = Tensor::new(MODEL_MASK_THRESHOLD - STABILITY_SCORE_OFFSET, dev)? + .broadcast_as(low_res_mask.shape())?; + let unions = low_res_mask + .ge(&bound)? + .to_dtype(DType::F32)? + .sum_all()? + .to_vec0::<f32>()?; + let stability_score = intersections / unions; + if stability_score < STABILITY_SCORE_THRESHOLD { + continue; + } + + // Threshold masks and calculate boxes. + let low_res_mask = low_res_mask + .ge(&Tensor::new(0f32, dev)?.broadcast_as(low_res_mask.shape())?)? + .to_dtype(DType::U32)?; + let low_res_mask_per_x = low_res_mask.sum(0)?.to_vec1::<u32>()?; + let low_res_mask_per_y = low_res_mask.sum(1)?.to_vec1::<u32>()?; + let min_max_x = min_max_indexes(&low_res_mask_per_x); + let min_max_y = min_max_indexes(&low_res_mask_per_y); + if let Some(((x0, x1), (y0, y1))) = min_max_x.zip(min_max_y) { + let bbox = crate::object_detection::Bbox { + xmin: x0 as f32, + ymin: y0 as f32, + xmax: x1 as f32, + ymax: y1 as f32, + confidence: *iou, + data: low_res_mask, + }; + bboxes.push(bbox); + } + // TODO: + // Filter boxes that touch crop boundaries + // Compress to RLE. + } + } + + let mut bboxes = vec![bboxes]; + // Remove duplicates within this crop. + crate::object_detection::non_maximum_suppression(&mut bboxes, CROP_NMS_THRESH); + + // TODO: Return to the original image frame. 
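+ // Editor note: non_maximum_suppression works in place on the per-crop vector;
+ // the surviving boxes are still expressed in crop coordinates at this point,
+ // hence the TODO above.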
+ Ok(bboxes.remove(0)) + } + + pub fn generate_masks( + &self, + img: &Tensor, + points_per_side: usize, + crop_n_layer: usize, + crop_overlap_ratio: f64, + crop_n_points_downscale_factor: usize, + ) -> Result<Vec<crate::object_detection::Bbox<Tensor>>> { + let (_c, h, w) = img.dims3()?; + let point_grids = build_all_layer_point_grids( + points_per_side, + crop_n_layer, + crop_n_points_downscale_factor, + ); + let crop_boxes = generate_crop_boxes((h, w), crop_n_layer, crop_overlap_ratio); + let mut bboxes = Vec::new(); + for crop_box in crop_boxes.into_iter() { + let layer_idx = crop_box.layer_idx; + let b = self.process_crop(img, crop_box, &point_grids[layer_idx])?; + bboxes.extend(b) + } + // TODO: remove duplicates + Ok(bboxes) + } +} + +// Return the first and last indexes i for which values[i] > 0 +fn min_max_indexes(values: &[u32]) -> Option<(usize, usize)> { + let (mut min_i, mut max_i) = (usize::MAX, usize::MIN); + for (i, &s) in values.iter().enumerate() { + if s == 0 { + continue; + } + min_i = usize::min(i, min_i); + max_i = usize::max(i, max_i); + } + if max_i < min_i { + None + } else { + Some((min_i, max_i)) + } +} + +#[derive(Debug)] +struct CropBox { + x0: usize, + y0: usize, + x1: usize, + y1: usize, + layer_idx: usize, +} + +impl CropBox { + fn new(x0: usize, y0: usize, x1: usize, y1: usize, layer_idx: usize) -> Self { + Self { + x0, + y0, + x1, + y1, + layer_idx, + } + } +} + +fn generate_crop_boxes( + (im_h, im_w): (usize, usize), + n_layers: usize, + overlap_ratio: f64, +) -> Vec<CropBox> { + fn crop_len(orig_len: usize, n_crops: usize, overlap: usize) -> usize { + f64::ceil((overlap * (n_crops - 1) + orig_len) as f64 / n_crops as f64) as usize + } + + let short_side = usize::min(im_h, im_w); + + let mut crop_boxes = Vec::new(); + + // Original image. + crop_boxes.push(CropBox::new(0, 0, im_w, im_h, 0)); + + for layer_idx in 1..=n_layers { + let n_crops_per_side = 1 << layer_idx; + let overlap = (overlap_ratio * short_side as f64 * 2. / n_crops_per_side as f64) as usize; + let crop_w = crop_len(im_w, n_crops_per_side, overlap); + let crop_h = crop_len(im_w, n_crops_per_side, overlap); + + for i_x in 0..n_crops_per_side { + let x0 = (crop_w - overlap) * i_x; + for i_y in 0..n_crops_per_side { + let y0 = (crop_h - overlap) * i_y; + let x1 = usize::min(im_w, x0 + crop_w); + let y1 = usize::min(im_h, y0 + crop_h); + crop_boxes.push(CropBox::new(x0, y0, x1, y1, layer_idx)); + } + } + } + + crop_boxes +} + +// Generates a 2D grid of points evenly spaced in [0,1]x[0,1]. 
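+// Editor note: a small worked example, assuming n_per_side = 2: the offset is
+// 1 / 4 = 0.25 and the grid is (0.25, 0.25), (0.25, 0.75), (0.75, 0.25), (0.75, 0.75),
+// i.e. the centers of a 2x2 tiling of the unit square.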
+fn build_point_grid(n_per_side: usize) -> Vec<(f64, f64)> { + let offset = 1f64 / (2 * n_per_side) as f64; + let mut points = Vec::with_capacity(n_per_side * n_per_side); + for i_x in 0..n_per_side { + let x = offset + i_x as f64 / n_per_side as f64; + for i_y in 0..n_per_side { + let y = offset + i_y as f64 / n_per_side as f64; + points.push((x, y)) + } + } + points +} + +fn build_all_layer_point_grids( + n_per_side: usize, + n_layers: usize, + scale_per_layer: usize, +) -> Vec<Vec<(f64, f64)>> { + let mut points_by_layer = Vec::with_capacity(n_layers + 1); + for i in 0..=n_layers { + let n_points = n_per_side / scale_per_layer.pow(i as u32); + points_by_layer.push(build_point_grid(n_points)) + } + points_by_layer +} diff --git a/candle-transformers/src/models/segment_anything/tiny_vit.rs b/candle-transformers/src/models/segment_anything/tiny_vit.rs new file mode 100644 index 00000000..cd2936ab --- /dev/null +++ b/candle-transformers/src/models/segment_anything/tiny_vit.rs @@ -0,0 +1,633 @@ +// Adapted from: +// https://github.com/ChaoningZhang/MobileSAM/blob/master/mobile_sam/modeling/tiny_vit_sam.py +use candle::{IndexOp, Result, Tensor, D}; +use candle_nn::{Conv2dConfig, Module, VarBuilder}; + +const MBCONV_EXPAND_RATIO: usize = 4; +const MLP_RATIO: usize = 4; +const LOCAL_CONV_SIZE: usize = 3; +const IMG_SIZE: usize = 1024; +const IN_CHANNELS: usize = 3; + +#[derive(Debug)] +struct Conv2dBN { + c: candle_nn::Conv2d, + bn: candle_nn::BatchNorm, + span: tracing::Span, +} + +impl Conv2dBN { + fn new(in_: usize, out: usize, ks: usize, cfg: Conv2dConfig, vb: VarBuilder) -> Result<Self> { + let c = candle_nn::conv2d_no_bias(in_, out, ks, cfg, vb.pp("c"))?; + let bn = candle_nn::batch_norm(out, 1e-5, vb.pp("bn"))?; + let span = tracing::span!(tracing::Level::TRACE, "conv2d-bn"); + Ok(Self { c, bn, span }) + } +} + +impl Module for Conv2dBN { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + xs.apply(&self.c)?.apply(&self.bn) + } +} + +#[derive(Debug)] +struct PatchEmbed { + conv1: Conv2dBN, + conv2: Conv2dBN, + span: tracing::Span, +} + +impl PatchEmbed { + fn new(in_chans: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> { + let cfg = candle_nn::Conv2dConfig { + stride: 2, + padding: 1, + ..Default::default() + }; + let conv1 = Conv2dBN::new(in_chans, embed_dim / 2, 3, cfg, vb.pp("seq.0"))?; + let conv2 = Conv2dBN::new(embed_dim / 2, embed_dim, 3, cfg, vb.pp("seq.2"))?; + let span = tracing::span!(tracing::Level::TRACE, "patch-embed"); + Ok(Self { conv1, conv2, span }) + } +} + +impl Module for PatchEmbed { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + xs.apply(&self.conv1)?.gelu()?.apply(&self.conv2) + } +} + +#[derive(Debug)] +struct MBConv { + conv1: Conv2dBN, + conv2: Conv2dBN, + conv3: Conv2dBN, + span: tracing::Span, +} + +impl MBConv { + fn new(in_: usize, out: usize, expand_ratio: usize, vb: VarBuilder) -> Result<Self> { + let hidden = in_ * expand_ratio; + let cfg2 = candle_nn::Conv2dConfig { + padding: 1, + groups: hidden, + ..Default::default() + }; + let conv1 = Conv2dBN::new(in_, hidden, 1, Default::default(), vb.pp("conv1"))?; + let conv2 = Conv2dBN::new(hidden, hidden, 3, cfg2, vb.pp("conv2"))?; + let conv3 = Conv2dBN::new(hidden, out, 1, Default::default(), vb.pp("conv3"))?; + let span = tracing::span!(tracing::Level::TRACE, "mb-conv"); + Ok(Self { + conv1, + conv2, + conv3, + span, + }) + } +} + +impl Module for MBConv { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let 
_enter = self.span.enter(); + let shortcut = xs; + let xs = xs + .apply(&self.conv1)? + .gelu()? + .apply(&self.conv2)? + .gelu()? + .apply(&self.conv3)?; + (xs + shortcut)?.gelu() + } +} + +#[derive(Debug)] +struct PatchMerging { + conv1: Conv2dBN, + conv2: Conv2dBN, + conv3: Conv2dBN, + input_resolution: (usize, usize), + span: tracing::Span, +} + +impl PatchMerging { + fn new( + input_resolution: (usize, usize), + dim: usize, + out: usize, + vb: VarBuilder, + ) -> Result<Self> { + let stride = if [320, 448, 576].contains(&out) { 1 } else { 2 }; + let cfg2 = candle_nn::Conv2dConfig { + padding: 1, + stride, + groups: out, + ..Default::default() + }; + let conv1 = Conv2dBN::new(dim, out, 1, Default::default(), vb.pp("conv1"))?; + let conv2 = Conv2dBN::new(out, out, 3, cfg2, vb.pp("conv2"))?; + let conv3 = Conv2dBN::new(out, out, 1, Default::default(), vb.pp("conv3"))?; + let span = tracing::span!(tracing::Level::TRACE, "patch-merging"); + Ok(Self { + conv1, + conv2, + conv3, + input_resolution, + span, + }) + } +} + +impl Module for PatchMerging { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let xs = if xs.rank() == 3 { + let (h, w) = self.input_resolution; + let b = xs.dim(0)?; + xs.reshape((b, h, w, ()))?.permute((0, 3, 1, 2))? + } else { + xs.clone() + }; + xs.apply(&self.conv1)? + .gelu()? + .apply(&self.conv2)? + .gelu()? + .apply(&self.conv3)? + .flatten_from(2)? + .transpose(1, 2) + } +} + +#[derive(Debug)] +struct ConvLayer { + blocks: Vec<MBConv>, + downsample: Option<PatchMerging>, + span: tracing::Span, +} + +impl ConvLayer { + fn new( + dim: usize, + out: usize, + input_resolution: (usize, usize), + depth: usize, + downsample: bool, + conv_expand_ratio: usize, + vb: VarBuilder, + ) -> Result<Self> { + let vb_b = vb.pp("blocks"); + let mut blocks = Vec::with_capacity(depth); + for index in 0..depth { + let block = MBConv::new(dim, dim, conv_expand_ratio, vb_b.pp(index))?; + blocks.push(block) + } + let downsample = if downsample { + let downsample = PatchMerging::new(input_resolution, dim, out, vb.pp("downsample"))?; + Some(downsample) + } else { + None + }; + let span = tracing::span!(tracing::Level::TRACE, "conv-layer"); + Ok(Self { + blocks, + downsample, + span, + }) + } +} + +impl Module for ConvLayer { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let mut xs = xs.clone(); + for block in self.blocks.iter() { + xs = block.forward(&xs)? + } + match &self.downsample { + None => Ok(xs), + Some(downsample) => downsample.forward(&xs), + } + } +} + +#[derive(Debug)] +struct Mlp { + norm: candle_nn::LayerNorm, + fc1: super::Linear, + fc2: super::Linear, + span: tracing::Span, +} + +impl Mlp { + fn new(in_: usize, hidden: usize, vb: VarBuilder) -> Result<Self> { + let norm = candle_nn::layer_norm(in_, 1e-5, vb.pp("norm"))?; + let fc1 = super::linear(vb.pp("fc1"), in_, hidden, true)?; + let fc2 = super::linear(vb.pp("fc2"), hidden, in_, true)?; + let span = tracing::span!(tracing::Level::TRACE, "mlp"); + Ok(Self { + norm, + fc1, + fc2, + span, + }) + } +} + +impl Module for Mlp { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + xs.apply(&self.norm)? + .apply(&self.fc1)? + .gelu()? 
+ .apply(&self.fc2) + } +} + +#[derive(Debug)] +struct Attention { + norm: candle_nn::LayerNorm, + qkv: super::Linear, + proj: super::Linear, + ab: Tensor, + key_dim: usize, + num_heads: usize, + d: usize, + dh: usize, + scale: f64, + span: tracing::Span, + span_matmul: tracing::Span, + span_softmax: tracing::Span, +} + +impl Attention { + fn new( + dim: usize, + key_dim: usize, + num_heads: usize, + attn_ratio: usize, + resolution: (usize, usize), + vb: VarBuilder, + ) -> Result<Self> { + let d = attn_ratio * key_dim; + let dh = d * num_heads; + let nh_kd = key_dim * num_heads; + let h = dh + nh_kd * 2; + let norm = candle_nn::layer_norm(dim, 1e-5, vb.pp("norm"))?; + let qkv = super::linear(vb.pp("qkv"), dim, h, true)?; + let proj = super::linear(vb.pp("proj"), dh, dim, true)?; + + let points = (0..resolution.0) + .flat_map(|x| (0..resolution.1).map(move |y| (x as i64, y as i64))) + .collect::<Vec<_>>(); + let mut idxs = Vec::with_capacity(points.len() * points.len()); + let mut attention_offsets = std::collections::HashMap::new(); + for &(x1, y1) in points.iter() { + for &(x2, y2) in points.iter() { + let offset = ((x2 - x1).abs(), (y2 - y1).abs()); + let l = attention_offsets.len(); + let idx = attention_offsets.entry(offset).or_insert(l); + idxs.push(*idx as u32) + } + } + let attention_biases = vb.get((num_heads, attention_offsets.len()), "attention_biases")?; + let idxs = Tensor::new(idxs, attention_biases.device())?; + let ab = + attention_biases + .index_select(&idxs, 1)? + .reshape(((), points.len(), points.len()))?; + let span = tracing::span!(tracing::Level::TRACE, "attention"); + let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul"); + let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm"); + Ok(Self { + norm, + qkv, + proj, + ab, + key_dim, + num_heads, + d, + dh, + scale: 1f64 / (key_dim as f64).sqrt(), + span, + span_matmul, + span_softmax, + }) + } +} + +impl Module for Attention { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let (b, n, _) = xs.dims3()?; + let xs = xs.apply(&self.norm)?; + let qkv = xs.apply(&self.qkv)?.reshape((b, n, self.num_heads, ()))?; + let q = qkv + .narrow(D::Minus1, 0, self.key_dim)? + .permute((0, 2, 1, 3))? + .contiguous()?; + let k = qkv + .narrow(D::Minus1, self.key_dim, self.key_dim)? + .permute((0, 2, 1, 3))? + .contiguous()?; + let v = qkv + .narrow(D::Minus1, 2 * self.key_dim, self.d)? + .permute((0, 2, 1, 3))? + .contiguous()?; + let attn = { + let _enter = self.span_matmul.enter(); + (q.matmul(&k.t()?)? * self.scale)? + }; + let attn = attn.broadcast_add(&self.ab)?; + let attn = { + let _enter = self.span_softmax.enter(); + candle_nn::ops::softmax_last_dim(&attn)? + }; + let attn = { + let _enter = self.span_matmul.enter(); + attn.matmul(&v)? + }; + attn.transpose(1, 2)? + .reshape((b, n, self.dh))? 
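+ // Editor note: the heads are merged back into a (b, n, dh) tensor before the
+ // final projection from dh back to the block dimension.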
+ .apply(&self.proj) + } +} + +#[derive(Debug)] +struct TinyViTBlock { + attn: Attention, + local_conv: Conv2dBN, + mlp: Mlp, + window_size: usize, + input_resolution: (usize, usize), + span: tracing::Span, +} + +impl TinyViTBlock { + fn new( + dim: usize, + input_resolution: (usize, usize), + num_heads: usize, + window_size: usize, + vb: VarBuilder, + ) -> Result<Self> { + let head_dim = dim / num_heads; + let attn = Attention::new( + dim, + head_dim, + num_heads, + 1, + (window_size, window_size), + vb.pp("attn"), + )?; + let mlp = Mlp::new(dim, dim * MLP_RATIO, vb.pp("mlp"))?; + let cfg = candle_nn::Conv2dConfig { + padding: LOCAL_CONV_SIZE / 2, + groups: dim, + ..Default::default() + }; + let local_conv = Conv2dBN::new(dim, dim, LOCAL_CONV_SIZE, cfg, vb.pp("local_conv"))?; + let span = tracing::span!(tracing::Level::TRACE, "attention"); + Ok(Self { + attn, + local_conv, + mlp, + window_size, + input_resolution, + span, + }) + } +} + +impl Module for TinyViTBlock { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let (h, w) = self.input_resolution; + let (b, l, c) = xs.dims3()?; + let res_x = xs; + let xs = if h == self.window_size && w == self.window_size { + self.attn.forward(xs)? + } else { + let xs = xs.reshape((b, h, w, c))?; + let pad_b = (self.window_size - h % self.window_size) % self.window_size; + let pad_r = (self.window_size - w % self.window_size) % self.window_size; + + let xs = if pad_b > 0 { + xs.pad_with_zeros(1, 0, pad_b)? + } else { + xs + }; + let xs = if pad_r > 0 { + xs.pad_with_zeros(2, 0, pad_r)? + } else { + xs + }; + let (p_h, p_w) = (h + pad_b, w + pad_r); + let n_h = p_h / self.window_size; + let n_w = p_w / self.window_size; + let xs = xs + .reshape((b, n_h, self.window_size, n_w, self.window_size, c))? + .transpose(2, 3)? + .reshape((b * n_h * n_w, self.window_size * self.window_size, c))?; + let xs = self.attn.forward(&xs)?; + let xs = xs + .reshape((b, n_h, n_w, self.window_size, self.window_size, c))? + .transpose(2, 3)? + .reshape((b, p_h, p_w, c))?; + let xs = if pad_r > 0 { + xs.i((.., .., ..w))?.contiguous()? + } else { + xs + }; + let xs = if pad_b > 0 { + xs.i((.., ..h, ..))?.contiguous()? + } else { + xs + }; + xs.reshape((b, l, c))? + }; + let xs = (xs + res_x)?; + let xs = xs + .transpose(1, 2)? + .reshape((b, c, h, w))? + .apply(&self.local_conv)? + .reshape((b, c, l))? + .transpose(1, 2)?; + &xs + self.mlp.forward(&xs)? 
+ } +} + +#[derive(Debug)] +struct BasicLayer { + blocks: Vec<TinyViTBlock>, + downsample: Option<PatchMerging>, + span: tracing::Span, +} + +impl BasicLayer { + #[allow(clippy::too_many_arguments)] + fn new( + dim: usize, + input_resolution: (usize, usize), + depth: usize, + num_heads: usize, + window_size: usize, + downsample: bool, + out: usize, + vb: VarBuilder, + ) -> Result<Self> { + let vb_b = vb.pp("blocks"); + let mut blocks = Vec::with_capacity(depth); + for index in 0..depth { + let block = TinyViTBlock::new( + dim, + input_resolution, + num_heads, + window_size, + vb_b.pp(index), + )?; + blocks.push(block) + } + let downsample = if downsample { + let downsample = PatchMerging::new(input_resolution, dim, out, vb.pp("downsample"))?; + Some(downsample) + } else { + None + }; + let span = tracing::span!(tracing::Level::TRACE, "basic-layer"); + Ok(Self { + blocks, + downsample, + span, + }) + } +} + +impl Module for BasicLayer { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let mut xs = xs.clone(); + for block in self.blocks.iter() { + xs = block.forward(&xs)? + } + match &self.downsample { + None => Ok(xs), + Some(downsample) => downsample.forward(&xs), + } + } +} + +#[derive(Debug)] +pub struct TinyViT { + patch_embed: PatchEmbed, + layer0: ConvLayer, + layers: Vec<BasicLayer>, + // norm_head: candle_nn::LayerNorm, + // head: candle_nn::Linear, + neck_conv1: candle_nn::Conv2d, + neck_ln1: super::LayerNorm2d, + neck_conv2: candle_nn::Conv2d, + neck_ln2: super::LayerNorm2d, + span: tracing::Span, + span_neck: tracing::Span, +} + +impl TinyViT { + pub fn new( + embed_dims: &[usize], + depths: &[usize], + num_heads: &[usize], + window_sizes: &[usize], + _num_classes: usize, + vb: VarBuilder, + ) -> Result<Self> { + let patch_embed = PatchEmbed::new(IN_CHANNELS, embed_dims[0], vb.pp("patch_embed"))?; + let patches_resolution = IMG_SIZE / 4; + + let vb_l = vb.pp("layers"); + let layer0 = ConvLayer::new( + /* dim */ embed_dims[0], + /* out */ embed_dims[1], + /* input_resolution */ (patches_resolution, patches_resolution), + /* depth */ depths[0], + /* downsample */ true, + /* conv_expand_ratio */ MBCONV_EXPAND_RATIO, + vb_l.pp(0), + )?; + + let num_layers = embed_dims.len(); + let mut layers = Vec::with_capacity(num_layers - 1); + for i_layer in 1..num_layers { + let patches_resolution = patches_resolution / (1 << usize::min(i_layer, 2)); + let layer = BasicLayer::new( + /* dim */ embed_dims[i_layer], + /* input_resolution */ (patches_resolution, patches_resolution), + /* depth */ depths[i_layer], + /* num_heads */ num_heads[i_layer], + /* window_size */ window_sizes[i_layer], + /* downsample */ i_layer < num_layers - 1, + /* out */ embed_dims[usize::min(i_layer + 1, num_layers - 1)], + vb_l.pp(i_layer), + )?; + layers.push(layer) + } + + let last_embed_dim = embed_dims[embed_dims.len() - 1]; + // let norm_head = candle_nn::layer_norm(last_embed_dim, 1e-5, vb.pp("norm_head"))?; + // let head = candle_nn::linear(last_embed_dim, num_classes, vb.pp("head"))?; + let neck_conv1 = + candle_nn::conv2d_no_bias(last_embed_dim, 256, 1, Default::default(), vb.pp("neck.0"))?; + let neck_ln1 = super::LayerNorm2d::new(256, 1e-6, vb.pp("neck.1"))?; + let cfg = candle_nn::Conv2dConfig { + padding: 1, + ..Default::default() + }; + let neck_conv2 = candle_nn::conv2d_no_bias(256, 256, 3, cfg, vb.pp("neck.2"))?; + let neck_ln2 = super::LayerNorm2d::new(256, 1e-6, vb.pp("neck.3"))?; + + let span = tracing::span!(tracing::Level::TRACE, "tiny-vit"); + let span_neck = 
tracing::span!(tracing::Level::TRACE, "neck"); + Ok(Self { + patch_embed, + layer0, + layers, + neck_conv1, + neck_ln1, + neck_conv2, + neck_ln2, + span, + span_neck, + }) + } +} + +impl Module for TinyViT { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let xs = self.patch_embed.forward(xs)?; + let mut xs = self.layer0.forward(&xs)?; + for layer in self.layers.iter() { + xs = layer.forward(&xs)? + } + let (b, _, c) = xs.dims3()?; + let _enter = self.span_neck.enter(); + xs.reshape((b, 64, 64, c))? + .permute((0, 3, 1, 2))? + .apply(&self.neck_conv1)? + .apply(&self.neck_ln1)? + .apply(&self.neck_conv2)? + .apply(&self.neck_ln2) + } +} + +pub fn tiny_vit_5m(vb: VarBuilder) -> Result<TinyViT> { + TinyViT::new( + /* embed_dims */ &[64, 128, 160, 320], + /* depths */ &[2, 2, 6, 2], + /* num_heads */ &[2, 4, 5, 10], + /* window_sizes */ &[7, 7, 14, 7], + /* num_classes */ 1000, + vb, + ) +} diff --git a/candle-transformers/src/models/segment_anything/transformer.rs b/candle-transformers/src/models/segment_anything/transformer.rs new file mode 100644 index 00000000..80efb38c --- /dev/null +++ b/candle-transformers/src/models/segment_anything/transformer.rs @@ -0,0 +1,221 @@ +use candle::{Result, Tensor}; +use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; + +#[derive(Debug)] +struct Attention { + q_proj: Linear, + k_proj: Linear, + v_proj: Linear, + out_proj: Linear, + num_heads: usize, +} + +impl Attention { + fn new( + embedding_dim: usize, + num_heads: usize, + downsample_rate: usize, + vb: VarBuilder, + ) -> Result<Self> { + let internal_dim = embedding_dim / downsample_rate; + let q_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("q_proj"))?; + let k_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("k_proj"))?; + let v_proj = candle_nn::linear(embedding_dim, internal_dim, vb.pp("v_proj"))?; + let out_proj = candle_nn::linear(internal_dim, embedding_dim, vb.pp("out_proj"))?; + Ok(Self { + q_proj, + k_proj, + v_proj, + out_proj, + num_heads, + }) + } + + fn separate_heads(&self, x: &Tensor) -> Result<Tensor> { + let (b, n, c) = x.dims3()?; + x.reshape((b, n, self.num_heads, c / self.num_heads))? + .transpose(1, 2)? + .contiguous() + } + + fn recombine_heads(&self, x: &Tensor) -> Result<Tensor> { + let (b, n_heads, n_tokens, c_per_head) = x.dims4()?; + x.transpose(1, 2)? + .reshape((b, n_tokens, n_heads * c_per_head)) + } + + fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> { + let q = self.q_proj.forward(&q.contiguous()?)?; + let k = self.k_proj.forward(&k.contiguous()?)?; + let v = self.v_proj.forward(&v.contiguous()?)?; + + let q = self.separate_heads(&q)?; + let k = self.separate_heads(&k)?; + let v = self.separate_heads(&v)?; + + let (_, _, _, c_per_head) = q.dims4()?; + let attn = (q.matmul(&k.t()?)? 
/ (c_per_head as f64).sqrt())?; + let attn = candle_nn::ops::softmax_last_dim(&attn)?; + + let out = attn.matmul(&v)?; + self.recombine_heads(&out)?.apply(&self.out_proj) + } +} + +#[derive(Debug)] +struct TwoWayAttentionBlock { + self_attn: Attention, + norm1: LayerNorm, + cross_attn_token_to_image: Attention, + norm2: LayerNorm, + mlp: super::MlpBlock, + norm3: LayerNorm, + norm4: LayerNorm, + cross_attn_image_to_token: Attention, + skip_first_layer_pe: bool, +} + +impl TwoWayAttentionBlock { + fn new( + embedding_dim: usize, + num_heads: usize, + mlp_dim: usize, + skip_first_layer_pe: bool, + vb: VarBuilder, + ) -> Result<Self> { + let norm1 = layer_norm(embedding_dim, 1e-5, vb.pp("norm1"))?; + let norm2 = layer_norm(embedding_dim, 1e-5, vb.pp("norm2"))?; + let norm3 = layer_norm(embedding_dim, 1e-5, vb.pp("norm3"))?; + let norm4 = layer_norm(embedding_dim, 1e-5, vb.pp("norm4"))?; + let self_attn = Attention::new(embedding_dim, num_heads, 1, vb.pp("self_attn"))?; + let cross_attn_token_to_image = Attention::new( + embedding_dim, + num_heads, + 2, + vb.pp("cross_attn_token_to_image"), + )?; + let cross_attn_image_to_token = Attention::new( + embedding_dim, + num_heads, + 2, + vb.pp("cross_attn_image_to_token"), + )?; + let mlp = super::MlpBlock::new( + embedding_dim, + mlp_dim, + candle_nn::Activation::Relu, + vb.pp("mlp"), + )?; + Ok(Self { + self_attn, + norm1, + cross_attn_image_to_token, + norm2, + mlp, + norm3, + norm4, + cross_attn_token_to_image, + skip_first_layer_pe, + }) + } + + fn forward( + &self, + queries: &Tensor, + keys: &Tensor, + query_pe: &Tensor, + key_pe: &Tensor, + ) -> Result<(Tensor, Tensor)> { + // Self attention block + let queries = if self.skip_first_layer_pe { + self.self_attn.forward(queries, queries, queries)? + } else { + let q = (queries + query_pe)?; + let attn_out = self.self_attn.forward(&q, &q, queries)?; + (queries + attn_out)? 
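+ // Editor note: residual self-attention; the positional encoding is added to the
+ // queries/keys only, while the values are the raw query tokens.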
+ }; + let queries = self.norm1.forward(&queries)?; + + // Cross attention block, tokens attending to image embedding + let q = (&queries + query_pe)?; + let k = (keys + key_pe)?; + let attn_out = self.cross_attn_token_to_image.forward(&q, &k, keys)?; + let queries = (&queries + attn_out)?; + let queries = self.norm2.forward(&queries)?; + + // MLP block + let mlp_out = self.mlp.forward(&queries); + let queries = (queries + mlp_out)?; + let queries = self.norm3.forward(&queries)?; + + // Cross attention block, image embedding attending to tokens + let q = (&queries + query_pe)?; + let k = (keys + key_pe)?; + let attn_out = self.cross_attn_image_to_token.forward(&k, &q, &queries)?; + let keys = (keys + attn_out)?; + let keys = self.norm4.forward(&keys)?; + + Ok((queries, keys)) + } +} + +#[derive(Debug)] +pub struct TwoWayTransformer { + layers: Vec<TwoWayAttentionBlock>, + final_attn_token_to_image: Attention, + norm_final_attn: LayerNorm, +} + +impl TwoWayTransformer { + pub fn new( + depth: usize, + embedding_dim: usize, + num_heads: usize, + mlp_dim: usize, + vb: VarBuilder, + ) -> Result<Self> { + let vb_l = vb.pp("layers"); + let mut layers = Vec::with_capacity(depth); + for i in 0..depth { + let layer = + TwoWayAttentionBlock::new(embedding_dim, num_heads, mlp_dim, i == 0, vb_l.pp(i))?; + layers.push(layer) + } + let final_attn_token_to_image = Attention::new( + embedding_dim, + num_heads, + 2, + vb.pp("final_attn_token_to_image"), + )?; + let norm_final_attn = layer_norm(embedding_dim, 1e-5, vb.pp("norm_final_attn"))?; + Ok(Self { + layers, + final_attn_token_to_image, + norm_final_attn, + }) + } + + pub fn forward( + &self, + image_embedding: &Tensor, + image_pe: &Tensor, + point_embedding: &Tensor, + ) -> Result<(Tensor, Tensor)> { + let image_embedding = image_embedding.flatten_from(2)?.permute((0, 2, 1))?; + let image_pe = image_pe.flatten_from(2)?.permute((0, 2, 1))?; + + let mut queries = point_embedding.clone(); + let mut keys = image_embedding; + + for layer in self.layers.iter() { + (queries, keys) = layer.forward(&queries, &keys, point_embedding, &image_pe)? + } + + let q = (&queries + point_embedding)?; + let k = (&keys + image_pe)?; + let attn_out = self.final_attn_token_to_image.forward(&q, &k, &keys)?; + let queries = (queries + attn_out)?.apply(&self.norm_final_attn)?; + + Ok((queries, keys)) + } +} diff --git a/candle-transformers/src/models/stable_diffusion/attention.rs b/candle-transformers/src/models/stable_diffusion/attention.rs new file mode 100644 index 00000000..b3ea91f9 --- /dev/null +++ b/candle-transformers/src/models/stable_diffusion/attention.rs @@ -0,0 +1,547 @@ +//! Attention Based Building Blocks +use candle::{DType, IndexOp, Result, Tensor, D}; +use candle_nn as nn; +use candle_nn::Module; + +#[derive(Debug)] +struct GeGlu { + proj: nn::Linear, + span: tracing::Span, +} + +impl GeGlu { + fn new(vs: nn::VarBuilder, dim_in: usize, dim_out: usize) -> Result<Self> { + let proj = nn::linear(dim_in, dim_out * 2, vs.pp("proj"))?; + let span = tracing::span!(tracing::Level::TRACE, "geglu"); + Ok(Self { proj, span }) + } +} + +impl Module for GeGlu { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?; + &hidden_states_and_gate[0] * hidden_states_and_gate[1].gelu()? + } +} + +/// A feed-forward layer. 
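+// Editor note: this appears to be the GEGLU variant of the transformer feed-forward,
+// mirroring the diffusers implementation referenced in the comment below: GeGlu
+// projects to 2 * inner_dim and gates one half with the GELU of the other, and the
+// second linear maps inner_dim back to dim_out.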
+#[derive(Debug)] +struct FeedForward { + project_in: GeGlu, + linear: nn::Linear, + span: tracing::Span, +} + +impl FeedForward { + // The glu parameter in the python code is unused? + // https://github.com/huggingface/diffusers/blob/d3d22ce5a894becb951eec03e663951b28d45135/src/diffusers/models/attention.py#L347 + /// Creates a new feed-forward layer based on some given input dimension, some + /// output dimension, and a multiplier to be used for the intermediary layer. + fn new(vs: nn::VarBuilder, dim: usize, dim_out: Option<usize>, mult: usize) -> Result<Self> { + let inner_dim = dim * mult; + let dim_out = dim_out.unwrap_or(dim); + let vs = vs.pp("net"); + let project_in = GeGlu::new(vs.pp("0"), dim, inner_dim)?; + let linear = nn::linear(inner_dim, dim_out, vs.pp("2"))?; + let span = tracing::span!(tracing::Level::TRACE, "ff"); + Ok(Self { + project_in, + linear, + span, + }) + } +} + +impl Module for FeedForward { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let xs = self.project_in.forward(xs)?; + self.linear.forward(&xs) + } +} + +#[cfg(feature = "flash-attn")] +fn flash_attn( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + causal: bool, +) -> Result<Tensor> { + candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal) +} + +#[cfg(not(feature = "flash-attn"))] +fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> { + unimplemented!("compile with '--features flash-attn'") +} + +#[derive(Debug)] +pub struct CrossAttention { + to_q: nn::Linear, + to_k: nn::Linear, + to_v: nn::Linear, + to_out: nn::Linear, + heads: usize, + scale: f64, + slice_size: Option<usize>, + span: tracing::Span, + span_attn: tracing::Span, + span_softmax: tracing::Span, + use_flash_attn: bool, +} + +impl CrossAttention { + // Defaults should be heads = 8, dim_head = 64, context_dim = None + pub fn new( + vs: nn::VarBuilder, + query_dim: usize, + context_dim: Option<usize>, + heads: usize, + dim_head: usize, + slice_size: Option<usize>, + use_flash_attn: bool, + ) -> Result<Self> { + let inner_dim = dim_head * heads; + let context_dim = context_dim.unwrap_or(query_dim); + let scale = 1.0 / f64::sqrt(dim_head as f64); + let to_q = nn::linear_no_bias(query_dim, inner_dim, vs.pp("to_q"))?; + let to_k = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_k"))?; + let to_v = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_v"))?; + let to_out = nn::linear(inner_dim, query_dim, vs.pp("to_out.0"))?; + let span = tracing::span!(tracing::Level::TRACE, "xa"); + let span_attn = tracing::span!(tracing::Level::TRACE, "xa-attn"); + let span_softmax = tracing::span!(tracing::Level::TRACE, "xa-softmax"); + Ok(Self { + to_q, + to_k, + to_v, + to_out, + heads, + scale, + slice_size, + span, + span_attn, + span_softmax, + use_flash_attn, + }) + } + + fn reshape_heads_to_batch_dim(&self, xs: &Tensor) -> Result<Tensor> { + let (batch_size, seq_len, dim) = xs.dims3()?; + xs.reshape((batch_size, seq_len, self.heads, dim / self.heads))? + .transpose(1, 2)? + .reshape((batch_size * self.heads, seq_len, dim / self.heads)) + } + + fn reshape_batch_dim_to_heads(&self, xs: &Tensor) -> Result<Tensor> { + let (batch_size, seq_len, dim) = xs.dims3()?; + xs.reshape((batch_size / self.heads, self.heads, seq_len, dim))? + .transpose(1, 2)? 
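+ // (b * heads, seq, head_dim) -> (b, heads, seq, head_dim) -> (b, seq, heads, head_dim);
+ // the reshape below then merges the heads back into a single feature dimension.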
+ .reshape((batch_size / self.heads, seq_len, dim * self.heads)) + } + + fn sliced_attention( + &self, + query: &Tensor, + key: &Tensor, + value: &Tensor, + slice_size: usize, + ) -> Result<Tensor> { + let batch_size_attention = query.dim(0)?; + let mut hidden_states = Vec::with_capacity(batch_size_attention / slice_size); + let in_dtype = query.dtype(); + let query = query.to_dtype(DType::F32)?; + let key = key.to_dtype(DType::F32)?; + let value = value.to_dtype(DType::F32)?; + + for i in 0..batch_size_attention / slice_size { + let start_idx = i * slice_size; + let end_idx = (i + 1) * slice_size; + + let xs = query + .i(start_idx..end_idx)? + .matmul(&(key.i(start_idx..end_idx)?.t()? * self.scale)?)?; + let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(&value.i(start_idx..end_idx)?)?; + hidden_states.push(xs) + } + let hidden_states = Tensor::stack(&hidden_states, 0)?.to_dtype(in_dtype)?; + self.reshape_batch_dim_to_heads(&hidden_states) + } + + fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> { + let _enter = self.span_attn.enter(); + let xs = if self.use_flash_attn { + let init_dtype = query.dtype(); + let q = query + .to_dtype(candle::DType::F16)? + .unsqueeze(0)? + .transpose(1, 2)?; + let k = key + .to_dtype(candle::DType::F16)? + .unsqueeze(0)? + .transpose(1, 2)?; + let v = value + .to_dtype(candle::DType::F16)? + .unsqueeze(0)? + .transpose(1, 2)?; + flash_attn(&q, &k, &v, self.scale as f32, false)? + .transpose(1, 2)? + .squeeze(0)? + .to_dtype(init_dtype)? + } else { + let in_dtype = query.dtype(); + let query = query.to_dtype(DType::F32)?; + let key = key.to_dtype(DType::F32)?; + let value = value.to_dtype(DType::F32)?; + let xs = query.matmul(&(key.t()? * self.scale)?)?; + let xs = { + let _enter = self.span_softmax.enter(); + nn::ops::softmax_last_dim(&xs)? + }; + xs.matmul(&value)?.to_dtype(in_dtype)? + }; + self.reshape_batch_dim_to_heads(&xs) + } + + pub fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> { + let _enter = self.span.enter(); + let query = self.to_q.forward(xs)?; + let context = context.unwrap_or(xs).contiguous()?; + let key = self.to_k.forward(&context)?; + let value = self.to_v.forward(&context)?; + let query = self.reshape_heads_to_batch_dim(&query)?; + let key = self.reshape_heads_to_batch_dim(&key)?; + let value = self.reshape_heads_to_batch_dim(&value)?; + let dim0 = query.dim(0)?; + let slice_size = self.slice_size.and_then(|slice_size| { + if dim0 < slice_size { + None + } else { + Some(slice_size) + } + }); + let xs = match slice_size { + None => self.attention(&query, &key, &value)?, + Some(slice_size) => self.sliced_attention(&query, &key, &value, slice_size)?, + }; + self.to_out.forward(&xs) + } +} + +/// A basic Transformer block. 
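+/// Pre-norm residual structure: self-attention, cross-attention over the encoder
+/// states (when a context is provided), and a GEGLU feed-forward are each applied
+/// to a layer-normed input and added back to the running residual.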
+#[derive(Debug)] +struct BasicTransformerBlock { + attn1: CrossAttention, + ff: FeedForward, + attn2: CrossAttention, + norm1: nn::LayerNorm, + norm2: nn::LayerNorm, + norm3: nn::LayerNorm, + span: tracing::Span, +} + +impl BasicTransformerBlock { + fn new( + vs: nn::VarBuilder, + dim: usize, + n_heads: usize, + d_head: usize, + context_dim: Option<usize>, + sliced_attention_size: Option<usize>, + use_flash_attn: bool, + ) -> Result<Self> { + let attn1 = CrossAttention::new( + vs.pp("attn1"), + dim, + None, + n_heads, + d_head, + sliced_attention_size, + use_flash_attn, + )?; + let ff = FeedForward::new(vs.pp("ff"), dim, None, 4)?; + let attn2 = CrossAttention::new( + vs.pp("attn2"), + dim, + context_dim, + n_heads, + d_head, + sliced_attention_size, + use_flash_attn, + )?; + let norm1 = nn::layer_norm(dim, 1e-5, vs.pp("norm1"))?; + let norm2 = nn::layer_norm(dim, 1e-5, vs.pp("norm2"))?; + let norm3 = nn::layer_norm(dim, 1e-5, vs.pp("norm3"))?; + let span = tracing::span!(tracing::Level::TRACE, "basic-transformer"); + Ok(Self { + attn1, + ff, + attn2, + norm1, + norm2, + norm3, + span, + }) + } + + fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> { + let _enter = self.span.enter(); + let xs = (self.attn1.forward(&self.norm1.forward(xs)?, None)? + xs)?; + let xs = (self.attn2.forward(&self.norm2.forward(&xs)?, context)? + xs)?; + self.ff.forward(&self.norm3.forward(&xs)?)? + xs + } +} + +#[derive(Debug, Clone, Copy)] +pub struct SpatialTransformerConfig { + pub depth: usize, + pub num_groups: usize, + pub context_dim: Option<usize>, + pub sliced_attention_size: Option<usize>, + pub use_linear_projection: bool, +} + +impl Default for SpatialTransformerConfig { + fn default() -> Self { + Self { + depth: 1, + num_groups: 32, + context_dim: None, + sliced_attention_size: None, + use_linear_projection: false, + } + } +} + +#[derive(Debug)] +enum Proj { + Conv2d(nn::Conv2d), + Linear(nn::Linear), +} + +// Aka Transformer2DModel +#[derive(Debug)] +pub struct SpatialTransformer { + norm: nn::GroupNorm, + proj_in: Proj, + transformer_blocks: Vec<BasicTransformerBlock>, + proj_out: Proj, + span: tracing::Span, + pub config: SpatialTransformerConfig, +} + +impl SpatialTransformer { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + n_heads: usize, + d_head: usize, + use_flash_attn: bool, + config: SpatialTransformerConfig, + ) -> Result<Self> { + let inner_dim = n_heads * d_head; + let norm = nn::group_norm(config.num_groups, in_channels, 1e-6, vs.pp("norm"))?; + let proj_in = if config.use_linear_projection { + Proj::Linear(nn::linear(in_channels, inner_dim, vs.pp("proj_in"))?) + } else { + Proj::Conv2d(nn::conv2d( + in_channels, + inner_dim, + 1, + Default::default(), + vs.pp("proj_in"), + )?) + }; + let mut transformer_blocks = vec![]; + let vs_tb = vs.pp("transformer_blocks"); + for index in 0..config.depth { + let tb = BasicTransformerBlock::new( + vs_tb.pp(&index.to_string()), + inner_dim, + n_heads, + d_head, + config.context_dim, + config.sliced_attention_size, + use_flash_attn, + )?; + transformer_blocks.push(tb) + } + let proj_out = if config.use_linear_projection { + Proj::Linear(nn::linear(in_channels, inner_dim, vs.pp("proj_out"))?) + } else { + Proj::Conv2d(nn::conv2d( + inner_dim, + in_channels, + 1, + Default::default(), + vs.pp("proj_out"), + )?) 
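+ // With `use_linear_projection` (the SD 2.x / SDXL configs) the in/out projections are
+ // linear layers over flattened tokens instead of 1x1 convolutions; in the standard
+ // stable-diffusion configs `inner_dim` matches `in_channels`, so either variant keeps
+ // the channel count unchanged.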
+ }; + let span = tracing::span!(tracing::Level::TRACE, "spatial-transformer"); + Ok(Self { + norm, + proj_in, + transformer_blocks, + proj_out, + span, + config, + }) + } + + pub fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> { + let _enter = self.span.enter(); + let (batch, _channel, height, weight) = xs.dims4()?; + let residual = xs; + let xs = self.norm.forward(xs)?; + let (inner_dim, xs) = match &self.proj_in { + Proj::Conv2d(p) => { + let xs = p.forward(&xs)?; + let inner_dim = xs.dim(1)?; + let xs = xs + .transpose(1, 2)? + .t()? + .reshape((batch, height * weight, inner_dim))?; + (inner_dim, xs) + } + Proj::Linear(p) => { + let inner_dim = xs.dim(1)?; + let xs = xs + .transpose(1, 2)? + .t()? + .reshape((batch, height * weight, inner_dim))?; + (inner_dim, p.forward(&xs)?) + } + }; + let mut xs = xs; + for block in self.transformer_blocks.iter() { + xs = block.forward(&xs, context)? + } + let xs = match &self.proj_out { + Proj::Conv2d(p) => p.forward( + &xs.reshape((batch, height, weight, inner_dim))? + .t()? + .transpose(1, 2)?, + )?, + Proj::Linear(p) => p + .forward(&xs)? + .reshape((batch, height, weight, inner_dim))? + .t()? + .transpose(1, 2)?, + }; + xs + residual + } +} + +/// Configuration for an attention block. +#[derive(Debug, Clone, Copy)] +pub struct AttentionBlockConfig { + pub num_head_channels: Option<usize>, + pub num_groups: usize, + pub rescale_output_factor: f64, + pub eps: f64, +} + +impl Default for AttentionBlockConfig { + fn default() -> Self { + Self { + num_head_channels: None, + num_groups: 32, + rescale_output_factor: 1., + eps: 1e-5, + } + } +} + +#[derive(Debug)] +pub struct AttentionBlock { + group_norm: nn::GroupNorm, + query: nn::Linear, + key: nn::Linear, + value: nn::Linear, + proj_attn: nn::Linear, + channels: usize, + num_heads: usize, + span: tracing::Span, + config: AttentionBlockConfig, +} + +impl AttentionBlock { + pub fn new(vs: nn::VarBuilder, channels: usize, config: AttentionBlockConfig) -> Result<Self> { + let num_head_channels = config.num_head_channels.unwrap_or(channels); + let num_heads = channels / num_head_channels; + let group_norm = + nn::group_norm(config.num_groups, channels, config.eps, vs.pp("group_norm"))?; + let (q_path, k_path, v_path, out_path) = if vs.contains_tensor("to_q.weight") { + ("to_q", "to_k", "to_v", "to_out.0") + } else { + ("query", "key", "value", "proj_attn") + }; + let query = nn::linear(channels, channels, vs.pp(q_path))?; + let key = nn::linear(channels, channels, vs.pp(k_path))?; + let value = nn::linear(channels, channels, vs.pp(v_path))?; + let proj_attn = nn::linear(channels, channels, vs.pp(out_path))?; + let span = tracing::span!(tracing::Level::TRACE, "attn-block"); + Ok(Self { + group_norm, + query, + key, + value, + proj_attn, + channels, + num_heads, + span, + config, + }) + } + + fn transpose_for_scores(&self, xs: Tensor) -> Result<Tensor> { + let (batch, t, h_times_d) = xs.dims3()?; + xs.reshape((batch, t, self.num_heads, h_times_d / self.num_heads))? + .transpose(1, 2) + } +} + +impl Module for AttentionBlock { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let in_dtype = xs.dtype(); + let residual = xs; + let (batch, channel, height, width) = xs.dims4()?; + let xs = self + .group_norm + .forward(xs)? + .reshape((batch, channel, height * width))? 
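+ // Treat every spatial location as a token: (b, c, h, w) -> (b, c, h*w),
+ // then transpose to (b, h*w, c) for the linear query/key/value projections.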
+ .transpose(1, 2)?; + + let query_proj = self.query.forward(&xs)?; + let key_proj = self.key.forward(&xs)?; + let value_proj = self.value.forward(&xs)?; + + let query_states = self + .transpose_for_scores(query_proj)? + .to_dtype(DType::F32)?; + let key_states = self.transpose_for_scores(key_proj)?.to_dtype(DType::F32)?; + let value_states = self + .transpose_for_scores(value_proj)? + .to_dtype(DType::F32)?; + + let scale = f64::powf((self.channels as f64) / (self.num_heads as f64), -0.25); + let attention_scores = + // TODO: Check that this needs two multiplication by `scale`. + (query_states * scale)?.matmul(&(key_states.t()? * scale)?)?; + let attention_probs = nn::ops::softmax(&attention_scores, D::Minus1)?; + + let xs = attention_probs.matmul(&value_states.contiguous()?)?; + let xs = xs.to_dtype(in_dtype)?; + let xs = xs.transpose(1, 2)?.contiguous()?; + let xs = xs.flatten_from(D::Minus2)?; + let xs = self + .proj_attn + .forward(&xs)? + .t()? + .reshape((batch, channel, height, width))?; + (xs + residual)? / self.config.rescale_output_factor + } +} diff --git a/candle-transformers/src/models/stable_diffusion/clip.rs b/candle-transformers/src/models/stable_diffusion/clip.rs new file mode 100644 index 00000000..e7a20270 --- /dev/null +++ b/candle-transformers/src/models/stable_diffusion/clip.rs @@ -0,0 +1,389 @@ +//! Contrastive Language-Image Pre-Training +//! +//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on +//! pairs of images with related texts. +//! +//! https://github.com/openai/CLIP +use candle::{DType, Device, Result, Tensor, D}; +use candle_nn as nn; +use candle_nn::Module; + +#[derive(Debug, Clone, Copy)] +pub enum Activation { + QuickGelu, + Gelu, + GeluErf, +} + +impl Module for Activation { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + match self { + Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?, + Activation::Gelu => xs.gelu(), + Activation::GeluErf => xs.gelu_erf(), + } + } +} + +#[derive(Debug, Clone)] +pub struct Config { + vocab_size: usize, + embed_dim: usize, // aka config.hidden_size + activation: Activation, // aka config.hidden_act + intermediate_size: usize, + pub max_position_embeddings: usize, + // The character to use for padding, use EOS when not set. 
+ pub pad_with: Option<String>, + num_hidden_layers: usize, + num_attention_heads: usize, + #[allow(dead_code)] + projection_dim: usize, +} + +impl Config { + // The config details can be found in the "text_config" section of this json file: + // https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json + pub fn v1_5() -> Self { + Self { + vocab_size: 49408, + embed_dim: 768, + intermediate_size: 3072, + max_position_embeddings: 77, + pad_with: None, + num_hidden_layers: 12, + num_attention_heads: 12, + projection_dim: 768, + activation: Activation::QuickGelu, + } + } + + // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/text_encoder/config.json + pub fn v2_1() -> Self { + Self { + vocab_size: 49408, + embed_dim: 1024, + intermediate_size: 4096, + max_position_embeddings: 77, + pad_with: Some("!".to_string()), + num_hidden_layers: 23, + num_attention_heads: 16, + projection_dim: 512, + activation: Activation::Gelu, + } + } + + // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/text_encoder/config.json + pub fn sdxl() -> Self { + Self { + vocab_size: 49408, + embed_dim: 768, + intermediate_size: 3072, + max_position_embeddings: 77, + pad_with: Some("!".to_string()), + num_hidden_layers: 12, + num_attention_heads: 12, + projection_dim: 768, + activation: Activation::QuickGelu, + } + } + + // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/text_encoder_2/config.json + pub fn sdxl2() -> Self { + Self { + vocab_size: 49408, + embed_dim: 1280, + intermediate_size: 5120, + max_position_embeddings: 77, + pad_with: Some("!".to_string()), + num_hidden_layers: 32, + num_attention_heads: 20, + projection_dim: 1280, + activation: Activation::Gelu, + } + } + + // https://huggingface.co/warp-ai/wuerstchen/blob/main/text_encoder/config.json + pub fn wuerstchen() -> Self { + Self { + vocab_size: 49408, + embed_dim: 1024, + intermediate_size: 4096, + max_position_embeddings: 77, + pad_with: None, + num_hidden_layers: 24, + num_attention_heads: 16, + projection_dim: 1024, + activation: Activation::GeluErf, + } + } + + // https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/text_encoder/config.json + pub fn wuerstchen_prior() -> Self { + Self { + vocab_size: 49408, + embed_dim: 1280, + intermediate_size: 5120, + max_position_embeddings: 77, + pad_with: None, + num_hidden_layers: 32, + num_attention_heads: 20, + projection_dim: 512, + activation: Activation::GeluErf, + } + } +} + +// CLIP Text Model +// https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py +#[derive(Debug)] +struct ClipTextEmbeddings { + token_embedding: candle_nn::Embedding, + position_embedding: candle_nn::Embedding, + position_ids: Tensor, +} + +impl ClipTextEmbeddings { + fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> { + let token_embedding = + candle_nn::embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?; + let position_embedding = candle_nn::embedding( + c.max_position_embeddings, + c.embed_dim, + vs.pp("position_embedding"), + )?; + let position_ids = + Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?; + Ok(ClipTextEmbeddings { + token_embedding, + position_embedding, + position_ids, + }) + } +} + +impl Module for ClipTextEmbeddings { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let token_embedding = self.token_embedding.forward(xs)?; + let position_embedding = 
self.position_embedding.forward(&self.position_ids)?; + token_embedding.broadcast_add(&position_embedding) + } +} + +#[derive(Debug)] +struct ClipAttention { + k_proj: candle_nn::Linear, + v_proj: candle_nn::Linear, + q_proj: candle_nn::Linear, + out_proj: candle_nn::Linear, + head_dim: usize, + scale: f64, + num_attention_heads: usize, +} + +impl ClipAttention { + fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> { + let embed_dim = c.embed_dim; + let num_attention_heads = c.num_attention_heads; + let k_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("k_proj"))?; + let v_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("v_proj"))?; + let q_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("q_proj"))?; + let out_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?; + let head_dim = embed_dim / num_attention_heads; + let scale = (head_dim as f64).powf(-0.5); + Ok(ClipAttention { + k_proj, + v_proj, + q_proj, + out_proj, + head_dim, + scale, + num_attention_heads, + }) + } + + fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result<Tensor> { + xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))? + .transpose(1, 2)? + .contiguous() + } + + fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> { + let in_dtype = xs.dtype(); + let (bsz, seq_len, embed_dim) = xs.dims3()?; + let query_states = (self.q_proj.forward(xs)? * self.scale)?; + let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim); + let query_states = self + .shape(&query_states, seq_len, bsz)? + .reshape(proj_shape)? + .to_dtype(DType::F32)?; + let key_states = self + .shape(&self.k_proj.forward(xs)?, seq_len, bsz)? + .reshape(proj_shape)? + .to_dtype(DType::F32)?; + let value_states = self + .shape(&self.v_proj.forward(xs)?, seq_len, bsz)? + .reshape(proj_shape)? + .to_dtype(DType::F32)?; + let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?; + + let src_len = key_states.dim(1)?; + let attn_weights = attn_weights + .reshape((bsz, self.num_attention_heads, seq_len, src_len))? + .broadcast_add(causal_attention_mask)?; + let attn_weights = + attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?; + let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?; + + let attn_output = attn_weights.matmul(&value_states)?.to_dtype(in_dtype)?; + let attn_output = attn_output + .reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))? + .transpose(1, 2)? + .reshape((bsz, seq_len, embed_dim))?; + self.out_proj.forward(&attn_output) + } +} + +#[derive(Debug)] +struct ClipMlp { + fc1: candle_nn::Linear, + fc2: candle_nn::Linear, + activation: Activation, +} + +impl ClipMlp { + fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> { + let fc1 = candle_nn::linear(c.embed_dim, c.intermediate_size, vs.pp("fc1"))?; + let fc2 = candle_nn::linear(c.intermediate_size, c.embed_dim, vs.pp("fc2"))?; + Ok(ClipMlp { + fc1, + fc2, + activation: c.activation, + }) + } +} + +impl ClipMlp { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let xs = self.fc1.forward(xs)?; + self.fc2.forward(&self.activation.forward(&xs)?) 
+ } +} + +#[derive(Debug)] +struct ClipEncoderLayer { + self_attn: ClipAttention, + layer_norm1: candle_nn::LayerNorm, + mlp: ClipMlp, + layer_norm2: candle_nn::LayerNorm, +} + +impl ClipEncoderLayer { + fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> { + let self_attn = ClipAttention::new(vs.pp("self_attn"), c)?; + let layer_norm1 = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("layer_norm1"))?; + let mlp = ClipMlp::new(vs.pp("mlp"), c)?; + let layer_norm2 = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("layer_norm2"))?; + Ok(ClipEncoderLayer { + self_attn, + layer_norm1, + mlp, + layer_norm2, + }) + } + + fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> { + let residual = xs; + let xs = self.layer_norm1.forward(xs)?; + let xs = self.self_attn.forward(&xs, causal_attention_mask)?; + let xs = (xs + residual)?; + + let residual = &xs; + let xs = self.layer_norm2.forward(&xs)?; + let xs = self.mlp.forward(&xs)?; + xs + residual + } +} + +#[derive(Debug)] +struct ClipEncoder { + layers: Vec<ClipEncoderLayer>, +} + +impl ClipEncoder { + fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> { + let vs = vs.pp("layers"); + let mut layers: Vec<ClipEncoderLayer> = Vec::new(); + for index in 0..c.num_hidden_layers { + let layer = ClipEncoderLayer::new(vs.pp(&index.to_string()), c)?; + layers.push(layer) + } + Ok(ClipEncoder { layers }) + } + + fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> { + let mut xs = xs.clone(); + for layer in self.layers.iter() { + xs = layer.forward(&xs, causal_attention_mask)?; + } + Ok(xs) + } +} + +/// A CLIP transformer based model. +#[derive(Debug)] +pub struct ClipTextTransformer { + embeddings: ClipTextEmbeddings, + encoder: ClipEncoder, + final_layer_norm: candle_nn::LayerNorm, +} + +impl ClipTextTransformer { + pub fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> { + let vs = vs.pp("text_model"); + let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?; + let encoder = ClipEncoder::new(vs.pp("encoder"), c)?; + let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("final_layer_norm"))?; + Ok(ClipTextTransformer { + embeddings, + encoder, + final_layer_norm, + }) + } + + // https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678 + fn build_causal_attention_mask( + bsz: usize, + seq_len: usize, + mask_after: usize, + device: &Device, + ) -> Result<Tensor> { + let mask: Vec<_> = (0..seq_len) + .flat_map(|i| { + (0..seq_len).map(move |j| { + if j > i || j > mask_after { + f32::MIN + } else { + 0. 
+ } + }) + }) + .collect(); + let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?; + mask.broadcast_as((bsz, seq_len, seq_len)) + } + + pub fn forward_with_mask(&self, xs: &Tensor, mask_after: usize) -> Result<Tensor> { + let (bsz, seq_len) = xs.dims2()?; + let xs = self.embeddings.forward(xs)?; + let causal_attention_mask = + Self::build_causal_attention_mask(bsz, seq_len, mask_after, xs.device())?; + let xs = self.encoder.forward(&xs, &causal_attention_mask)?; + self.final_layer_norm.forward(&xs) + } +} + +impl Module for ClipTextTransformer { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + self.forward_with_mask(xs, usize::MAX) + } +} diff --git a/candle-transformers/src/models/stable_diffusion/ddim.rs b/candle-transformers/src/models/stable_diffusion/ddim.rs new file mode 100644 index 00000000..916b7349 --- /dev/null +++ b/candle-transformers/src/models/stable_diffusion/ddim.rs @@ -0,0 +1,180 @@ +//! # Denoising Diffusion Implicit Models +//! +//! The Denoising Diffusion Implicit Models (DDIM) is a simple scheduler +//! similar to Denoising Diffusion Probabilistic Models (DDPM). The DDPM +//! generative process is the reverse of a Markovian process, DDIM generalizes +//! this to non-Markovian guidance. +//! +//! Denoising Diffusion Implicit Models, J. Song et al, 2020. +//! https://arxiv.org/abs/2010.02502 +use super::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType}; +use candle::{Result, Tensor}; + +/// The configuration for the DDIM scheduler. +#[derive(Debug, Clone, Copy)] +pub struct DDIMSchedulerConfig { + /// The value of beta at the beginning of training. + pub beta_start: f64, + /// The value of beta at the end of training. + pub beta_end: f64, + /// How beta evolved during training. + pub beta_schedule: BetaSchedule, + /// The amount of noise to be added at each step. + pub eta: f64, + /// Adjust the indexes of the inference schedule by this value. + pub steps_offset: usize, + /// prediction type of the scheduler function, one of `epsilon` (predicting + /// the noise of the diffusion process), `sample` (directly predicting the noisy sample`) + /// or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) + pub prediction_type: PredictionType, + /// number of diffusion steps used to train the model + pub train_timesteps: usize, +} + +impl Default for DDIMSchedulerConfig { + fn default() -> Self { + Self { + beta_start: 0.00085f64, + beta_end: 0.012f64, + beta_schedule: BetaSchedule::ScaledLinear, + eta: 0., + steps_offset: 1, + prediction_type: PredictionType::Epsilon, + train_timesteps: 1000, + } + } +} + +/// The DDIM scheduler. +#[derive(Debug, Clone)] +pub struct DDIMScheduler { + timesteps: Vec<usize>, + alphas_cumprod: Vec<f64>, + step_ratio: usize, + init_noise_sigma: f64, + pub config: DDIMSchedulerConfig, +} + +// clip_sample: False, set_alpha_to_one: False +impl DDIMScheduler { + /// Creates a new DDIM scheduler given the number of steps to be + /// used for inference as well as the number of steps that was used + /// during training. + pub fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result<Self> { + let step_ratio = config.train_timesteps / inference_steps; + let timesteps: Vec<usize> = (0..(inference_steps)) + .map(|s| s * step_ratio + config.steps_offset) + .rev() + .collect(); + let betas = match config.beta_schedule { + BetaSchedule::ScaledLinear => super::utils::linspace( + config.beta_start.sqrt(), + config.beta_end.sqrt(), + config.train_timesteps, + )? 
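+ // "scaled linear": interpolate linearly between sqrt(beta_start) and
+ // sqrt(beta_end), then square the result.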
+ .sqr()?, + BetaSchedule::Linear => { + super::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)? + } + BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?, + }; + let betas = betas.to_vec1::<f64>()?; + let mut alphas_cumprod = Vec::with_capacity(betas.len()); + for &beta in betas.iter() { + let alpha = 1.0 - beta; + alphas_cumprod.push(alpha * *alphas_cumprod.last().unwrap_or(&1f64)) + } + Ok(Self { + alphas_cumprod, + timesteps, + step_ratio, + init_noise_sigma: 1., + config, + }) + } + + pub fn timesteps(&self) -> &[usize] { + self.timesteps.as_slice() + } + + /// Ensures interchangeability with schedulers that need to scale the denoising model input + /// depending on the current timestep. + pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> { + Ok(sample) + } + + /// Performs a backward step during inference. + pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> { + let timestep = if timestep >= self.alphas_cumprod.len() { + timestep - 1 + } else { + timestep + }; + // https://github.com/huggingface/diffusers/blob/6e099e2c8ce4c4f5c7318e970a8c093dc5c7046e/src/diffusers/schedulers/scheduling_ddim.py#L195 + let prev_timestep = if timestep > self.step_ratio { + timestep - self.step_ratio + } else { + 0 + }; + + let alpha_prod_t = self.alphas_cumprod[timestep]; + let alpha_prod_t_prev = self.alphas_cumprod[prev_timestep]; + let beta_prod_t = 1. - alpha_prod_t; + let beta_prod_t_prev = 1. - alpha_prod_t_prev; + + let (pred_original_sample, pred_epsilon) = match self.config.prediction_type { + PredictionType::Epsilon => { + let pred_original_sample = ((sample - (model_output * beta_prod_t.sqrt())?)? + * (1. / alpha_prod_t.sqrt()))?; + (pred_original_sample, model_output.clone()) + } + PredictionType::VPrediction => { + let pred_original_sample = + ((sample * alpha_prod_t.sqrt())? - (model_output * beta_prod_t.sqrt())?)?; + let pred_epsilon = + ((model_output * alpha_prod_t.sqrt())? + (sample * beta_prod_t.sqrt())?)?; + (pred_original_sample, pred_epsilon) + } + PredictionType::Sample => { + let pred_original_sample = model_output.clone(); + let pred_epsilon = ((sample - &pred_original_sample * alpha_prod_t.sqrt())? + * (1. / beta_prod_t.sqrt()))?; + (pred_original_sample, pred_epsilon) + } + }; + + let variance = (beta_prod_t_prev / beta_prod_t) * (1. - alpha_prod_t / alpha_prod_t_prev); + let std_dev_t = self.config.eta * variance.sqrt(); + + let pred_sample_direction = + (pred_epsilon * (1. - alpha_prod_t_prev - std_dev_t * std_dev_t).sqrt())?; + let prev_sample = + ((pred_original_sample * alpha_prod_t_prev.sqrt())? + pred_sample_direction)?; + if self.config.eta > 0. { + &prev_sample + + Tensor::randn( + 0f32, + std_dev_t as f32, + prev_sample.shape(), + prev_sample.device(), + )? + } else { + Ok(prev_sample) + } + } + + pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> { + let timestep = if timestep >= self.alphas_cumprod.len() { + timestep - 1 + } else { + timestep + }; + let sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt(); + let sqrt_one_minus_alpha_prod = (1.0 - self.alphas_cumprod[timestep]).sqrt(); + (original * sqrt_alpha_prod)? + (noise * sqrt_one_minus_alpha_prod)? 
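+ // Forward diffusion step: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.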
+ } + + pub fn init_noise_sigma(&self) -> f64 { + self.init_noise_sigma + } +} diff --git a/candle-transformers/src/models/stable_diffusion/ddpm.rs b/candle-transformers/src/models/stable_diffusion/ddpm.rs new file mode 100644 index 00000000..d393f39a --- /dev/null +++ b/candle-transformers/src/models/stable_diffusion/ddpm.rs @@ -0,0 +1,205 @@ +use super::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType}; +use candle::{Result, Tensor}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum DDPMVarianceType { + FixedSmall, + FixedSmallLog, + FixedLarge, + FixedLargeLog, + Learned, +} + +impl Default for DDPMVarianceType { + fn default() -> Self { + Self::FixedSmall + } +} + +#[derive(Debug, Clone)] +pub struct DDPMSchedulerConfig { + /// The value of beta at the beginning of training. + pub beta_start: f64, + /// The value of beta at the end of training. + pub beta_end: f64, + /// How beta evolved during training. + pub beta_schedule: BetaSchedule, + /// Option to predicted sample between -1 and 1 for numerical stability. + pub clip_sample: bool, + /// Option to clip the variance used when adding noise to the denoised sample. + pub variance_type: DDPMVarianceType, + /// prediction type of the scheduler function + pub prediction_type: PredictionType, + /// number of diffusion steps used to train the model. + pub train_timesteps: usize, +} + +impl Default for DDPMSchedulerConfig { + fn default() -> Self { + Self { + beta_start: 0.00085, + beta_end: 0.012, + beta_schedule: BetaSchedule::ScaledLinear, + clip_sample: false, + variance_type: DDPMVarianceType::FixedSmall, + prediction_type: PredictionType::Epsilon, + train_timesteps: 1000, + } + } +} + +pub struct DDPMScheduler { + alphas_cumprod: Vec<f64>, + init_noise_sigma: f64, + timesteps: Vec<usize>, + step_ratio: usize, + pub config: DDPMSchedulerConfig, +} + +impl DDPMScheduler { + pub fn new(inference_steps: usize, config: DDPMSchedulerConfig) -> Result<Self> { + let betas = match config.beta_schedule { + BetaSchedule::ScaledLinear => super::utils::linspace( + config.beta_start.sqrt(), + config.beta_end.sqrt(), + config.train_timesteps, + )? + .sqr()?, + BetaSchedule::Linear => { + super::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)? + } + BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?, + }; + + let betas = betas.to_vec1::<f64>()?; + let mut alphas_cumprod = Vec::with_capacity(betas.len()); + for &beta in betas.iter() { + let alpha = 1.0 - beta; + alphas_cumprod.push(alpha * *alphas_cumprod.last().unwrap_or(&1f64)) + } + + // min(train_timesteps, inference_steps) + // https://github.com/huggingface/diffusers/blob/8331da46837be40f96fbd24de6a6fb2da28acd11/src/diffusers/schedulers/scheduling_ddpm.py#L187 + let inference_steps = inference_steps.min(config.train_timesteps); + // arange the number of the scheduler's timesteps + let step_ratio = config.train_timesteps / inference_steps; + let timesteps: Vec<usize> = (0..inference_steps).map(|s| s * step_ratio).rev().collect(); + + Ok(Self { + alphas_cumprod, + init_noise_sigma: 1.0, + timesteps, + step_ratio, + config, + }) + } + + fn get_variance(&self, timestep: usize) -> f64 { + let prev_t = timestep as isize - self.step_ratio as isize; + let alpha_prod_t = self.alphas_cumprod[timestep]; + let alpha_prod_t_prev = if prev_t >= 0 { + self.alphas_cumprod[prev_t as usize] + } else { + 1.0 + }; + let current_beta_t = 1. 
- alpha_prod_t / alpha_prod_t_prev; + + // For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) + // and sample from it to get previous sample + // x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample + let variance = (1. - alpha_prod_t_prev) / (1. - alpha_prod_t) * current_beta_t; + + // retrieve variance + match self.config.variance_type { + DDPMVarianceType::FixedSmall => variance.max(1e-20), + // for rl-diffuser https://arxiv.org/abs/2205.09991 + DDPMVarianceType::FixedSmallLog => { + let variance = variance.max(1e-20).ln(); + (variance * 0.5).exp() + } + DDPMVarianceType::FixedLarge => current_beta_t, + DDPMVarianceType::FixedLargeLog => current_beta_t.ln(), + DDPMVarianceType::Learned => variance, + } + } + + pub fn timesteps(&self) -> &[usize] { + self.timesteps.as_slice() + } + + /// Ensures interchangeability with schedulers that need to scale the denoising model input + /// depending on the current timestep. + pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor { + sample + } + + pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> { + let prev_t = timestep as isize - self.step_ratio as isize; + + // https://github.com/huggingface/diffusers/blob/df2b548e893ccb8a888467c2508756680df22821/src/diffusers/schedulers/scheduling_ddpm.py#L272 + // 1. compute alphas, betas + let alpha_prod_t = self.alphas_cumprod[timestep]; + let alpha_prod_t_prev = if prev_t >= 0 { + self.alphas_cumprod[prev_t as usize] + } else { + 1.0 + }; + let beta_prod_t = 1. - alpha_prod_t; + let beta_prod_t_prev = 1. - alpha_prod_t_prev; + let current_alpha_t = alpha_prod_t / alpha_prod_t_prev; + let current_beta_t = 1. - current_alpha_t; + + // 2. compute predicted original sample from predicted noise also called "predicted x_0" of formula (15) + let mut pred_original_sample = match self.config.prediction_type { + PredictionType::Epsilon => { + ((sample - model_output * beta_prod_t.sqrt())? / alpha_prod_t.sqrt())? + } + PredictionType::Sample => model_output.clone(), + PredictionType::VPrediction => { + ((sample * alpha_prod_t.sqrt())? - model_output * beta_prod_t.sqrt())? + } + }; + + // 3. clip predicted x_0 + if self.config.clip_sample { + pred_original_sample = pred_original_sample.clamp(-1f32, 1f32)?; + } + + // 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + // See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + let pred_original_sample_coeff = (alpha_prod_t_prev.sqrt() * current_beta_t) / beta_prod_t; + let current_sample_coeff = current_alpha_t.sqrt() * beta_prod_t_prev / beta_prod_t; + + // 5. Compute predicted previous sample µ_t + // See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + let pred_prev_sample = ((&pred_original_sample * pred_original_sample_coeff)? + + sample * current_sample_coeff)?; + + // https://github.com/huggingface/diffusers/blob/df2b548e893ccb8a888467c2508756680df22821/src/diffusers/schedulers/scheduling_ddpm.py#L305 + // 6. 
Add noise + let mut variance = model_output.zeros_like()?; + if timestep > 0 { + let variance_noise = model_output.randn_like(0., 1.)?; + if self.config.variance_type == DDPMVarianceType::FixedSmallLog { + variance = (variance_noise * self.get_variance(timestep))?; + } else { + variance = (variance_noise * self.get_variance(timestep).sqrt())?; + } + } + &pred_prev_sample + variance + } + + pub fn add_noise( + &self, + original_samples: &Tensor, + noise: Tensor, + timestep: usize, + ) -> Result<Tensor> { + (original_samples * self.alphas_cumprod[timestep].sqrt())? + + noise * (1. - self.alphas_cumprod[timestep]).sqrt() + } + + pub fn init_noise_sigma(&self) -> f64 { + self.init_noise_sigma + } +} diff --git a/candle-transformers/src/models/stable_diffusion/embeddings.rs b/candle-transformers/src/models/stable_diffusion/embeddings.rs new file mode 100644 index 00000000..0de5f9a7 --- /dev/null +++ b/candle-transformers/src/models/stable_diffusion/embeddings.rs @@ -0,0 +1,65 @@ +use candle::{Result, Tensor, D}; +use candle_nn as nn; +use candle_nn::Module; + +#[derive(Debug)] +pub struct TimestepEmbedding { + linear_1: nn::Linear, + linear_2: nn::Linear, +} + +impl TimestepEmbedding { + // act_fn: "silu" + pub fn new(vs: nn::VarBuilder, channel: usize, time_embed_dim: usize) -> Result<Self> { + let linear_1 = nn::linear(channel, time_embed_dim, vs.pp("linear_1"))?; + let linear_2 = nn::linear(time_embed_dim, time_embed_dim, vs.pp("linear_2"))?; + Ok(Self { linear_1, linear_2 }) + } +} + +impl Module for TimestepEmbedding { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let xs = nn::ops::silu(&self.linear_1.forward(xs)?)?; + self.linear_2.forward(&xs) + } +} + +#[derive(Debug)] +pub struct Timesteps { + num_channels: usize, + flip_sin_to_cos: bool, + downscale_freq_shift: f64, +} + +impl Timesteps { + pub fn new(num_channels: usize, flip_sin_to_cos: bool, downscale_freq_shift: f64) -> Self { + Self { + num_channels, + flip_sin_to_cos, + downscale_freq_shift, + } + } +} + +impl Module for Timesteps { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let half_dim = (self.num_channels / 2) as u32; + let exponent = (Tensor::arange(0, half_dim, xs.device())?.to_dtype(candle::DType::F32)? + * -f64::ln(10000.))?; + let exponent = (exponent / (half_dim as f64 - self.downscale_freq_shift))?; + let emb = exponent.exp()?.to_dtype(xs.dtype())?; + // emb = timesteps[:, None].float() * emb[None, :] + let emb = xs.unsqueeze(D::Minus1)?.broadcast_mul(&emb.unsqueeze(0)?)?; + let (cos, sin) = (emb.cos()?, emb.sin()?); + let emb = if self.flip_sin_to_cos { + Tensor::cat(&[&cos, &sin], D::Minus1)? + } else { + Tensor::cat(&[&sin, &cos], D::Minus1)? 
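+ // Sinusoidal embedding: frequencies are 10000^(-i / (half_dim - downscale_freq_shift));
+ // the sin and cos halves are concatenated, cos first when `flip_sin_to_cos` is set.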
+ }; + if self.num_channels % 2 == 1 { + emb.pad_with_zeros(D::Minus2, 0, 1) + } else { + Ok(emb) + } + } +} diff --git a/candle-transformers/src/models/stable_diffusion/mod.rs b/candle-transformers/src/models/stable_diffusion/mod.rs new file mode 100644 index 00000000..c6f1b904 --- /dev/null +++ b/candle-transformers/src/models/stable_diffusion/mod.rs @@ -0,0 +1,303 @@ +pub mod attention; +pub mod clip; +pub mod ddim; +pub mod ddpm; +pub mod embeddings; +pub mod resnet; +pub mod schedulers; +pub mod unet_2d; +pub mod unet_2d_blocks; +pub mod utils; +pub mod vae; + +use candle::{DType, Device, Result}; +use candle_nn as nn; + +#[derive(Clone, Debug)] +pub struct StableDiffusionConfig { + pub width: usize, + pub height: usize, + pub clip: clip::Config, + pub clip2: Option<clip::Config>, + autoencoder: vae::AutoEncoderKLConfig, + unet: unet_2d::UNet2DConditionModelConfig, + scheduler: ddim::DDIMSchedulerConfig, +} + +impl StableDiffusionConfig { + pub fn v1_5( + sliced_attention_size: Option<usize>, + height: Option<usize>, + width: Option<usize>, + ) -> Self { + let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig { + out_channels, + use_cross_attn, + attention_head_dim, + }; + // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/unet/config.json + let unet = unet_2d::UNet2DConditionModelConfig { + blocks: vec![ + bc(320, Some(1), 8), + bc(640, Some(1), 8), + bc(1280, Some(1), 8), + bc(1280, None, 8), + ], + center_input_sample: false, + cross_attention_dim: 768, + downsample_padding: 1, + flip_sin_to_cos: true, + freq_shift: 0., + layers_per_block: 2, + mid_block_scale_factor: 1., + norm_eps: 1e-5, + norm_num_groups: 32, + sliced_attention_size, + use_linear_projection: false, + }; + let autoencoder = vae::AutoEncoderKLConfig { + block_out_channels: vec![128, 256, 512, 512], + layers_per_block: 2, + latent_channels: 4, + norm_num_groups: 32, + }; + let height = if let Some(height) = height { + assert_eq!(height % 8, 0, "height has to be divisible by 8"); + height + } else { + 512 + }; + + let width = if let Some(width) = width { + assert_eq!(width % 8, 0, "width has to be divisible by 8"); + width + } else { + 512 + }; + + Self { + width, + height, + clip: clip::Config::v1_5(), + clip2: None, + autoencoder, + scheduler: Default::default(), + unet, + } + } + + fn v2_1_( + sliced_attention_size: Option<usize>, + height: Option<usize>, + width: Option<usize>, + prediction_type: schedulers::PredictionType, + ) -> Self { + let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig { + out_channels, + use_cross_attn, + attention_head_dim, + }; + // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/unet/config.json + let unet = unet_2d::UNet2DConditionModelConfig { + blocks: vec![ + bc(320, Some(1), 5), + bc(640, Some(1), 10), + bc(1280, Some(1), 20), + bc(1280, None, 20), + ], + center_input_sample: false, + cross_attention_dim: 1024, + downsample_padding: 1, + flip_sin_to_cos: true, + freq_shift: 0., + layers_per_block: 2, + mid_block_scale_factor: 1., + norm_eps: 1e-5, + norm_num_groups: 32, + sliced_attention_size, + use_linear_projection: true, + }; + // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/vae/config.json + let autoencoder = vae::AutoEncoderKLConfig { + block_out_channels: vec![128, 256, 512, 512], + layers_per_block: 2, + latent_channels: 4, + norm_num_groups: 32, + }; + let scheduler = ddim::DDIMSchedulerConfig { + prediction_type, + ..Default::default() + }; + + let height = if 
let Some(height) = height { + assert_eq!(height % 8, 0, "height has to be divisible by 8"); + height + } else { + 768 + }; + + let width = if let Some(width) = width { + assert_eq!(width % 8, 0, "width has to be divisible by 8"); + width + } else { + 768 + }; + + Self { + width, + height, + clip: clip::Config::v2_1(), + clip2: None, + autoencoder, + scheduler, + unet, + } + } + + pub fn v2_1( + sliced_attention_size: Option<usize>, + height: Option<usize>, + width: Option<usize>, + ) -> Self { + // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/scheduler/scheduler_config.json + Self::v2_1_( + sliced_attention_size, + height, + width, + schedulers::PredictionType::VPrediction, + ) + } + + fn sdxl_( + sliced_attention_size: Option<usize>, + height: Option<usize>, + width: Option<usize>, + prediction_type: schedulers::PredictionType, + ) -> Self { + let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig { + out_channels, + use_cross_attn, + attention_head_dim, + }; + // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/unet/config.json + let unet = unet_2d::UNet2DConditionModelConfig { + blocks: vec![ + bc(320, None, 5), + bc(640, Some(2), 10), + bc(1280, Some(10), 20), + ], + center_input_sample: false, + cross_attention_dim: 2048, + downsample_padding: 1, + flip_sin_to_cos: true, + freq_shift: 0., + layers_per_block: 2, + mid_block_scale_factor: 1., + norm_eps: 1e-5, + norm_num_groups: 32, + sliced_attention_size, + use_linear_projection: true, + }; + // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/vae/config.json + let autoencoder = vae::AutoEncoderKLConfig { + block_out_channels: vec![128, 256, 512, 512], + layers_per_block: 2, + latent_channels: 4, + norm_num_groups: 32, + }; + let scheduler = ddim::DDIMSchedulerConfig { + prediction_type, + ..Default::default() + }; + + let height = if let Some(height) = height { + assert_eq!(height % 8, 0, "height has to be divisible by 8"); + height + } else { + 1024 + }; + + let width = if let Some(width) = width { + assert_eq!(width % 8, 0, "width has to be divisible by 8"); + width + } else { + 1024 + }; + + Self { + width, + height, + clip: clip::Config::sdxl(), + clip2: Some(clip::Config::sdxl2()), + autoencoder, + scheduler, + unet, + } + } + + pub fn sdxl( + sliced_attention_size: Option<usize>, + height: Option<usize>, + width: Option<usize>, + ) -> Self { + Self::sdxl_( + sliced_attention_size, + height, + width, + // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/scheduler/scheduler_config.json + schedulers::PredictionType::Epsilon, + ) + } + + pub fn build_vae<P: AsRef<std::path::Path>>( + &self, + vae_weights: P, + device: &Device, + dtype: DType, + ) -> Result<vae::AutoEncoderKL> { + let weights = unsafe { candle::safetensors::MmapedFile::new(vae_weights)? }; + let weights = weights.deserialize()?; + let vs_ae = nn::VarBuilder::from_safetensors(vec![weights], dtype, device); + // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json + let autoencoder = vae::AutoEncoderKL::new(vs_ae, 3, 3, self.autoencoder.clone())?; + Ok(autoencoder) + } + + pub fn build_unet<P: AsRef<std::path::Path>>( + &self, + unet_weights: P, + device: &Device, + in_channels: usize, + use_flash_attn: bool, + dtype: DType, + ) -> Result<unet_2d::UNet2DConditionModel> { + let weights = unsafe { candle::safetensors::MmapedFile::new(unet_weights)? 
}; + let weights = weights.deserialize()?; + let vs_unet = nn::VarBuilder::from_safetensors(vec![weights], dtype, device); + let unet = unet_2d::UNet2DConditionModel::new( + vs_unet, + in_channels, + 4, + use_flash_attn, + self.unet.clone(), + )?; + Ok(unet) + } + + pub fn build_scheduler(&self, n_steps: usize) -> Result<ddim::DDIMScheduler> { + ddim::DDIMScheduler::new(n_steps, self.scheduler) + } +} + +pub fn build_clip_transformer<P: AsRef<std::path::Path>>( + clip: &clip::Config, + clip_weights: P, + device: &Device, + dtype: DType, +) -> Result<clip::ClipTextTransformer> { + let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? }; + let weights = weights.deserialize()?; + let vs = nn::VarBuilder::from_safetensors(vec![weights], dtype, device); + let text_model = clip::ClipTextTransformer::new(vs, clip)?; + Ok(text_model) +} diff --git a/candle-transformers/src/models/stable_diffusion/resnet.rs b/candle-transformers/src/models/stable_diffusion/resnet.rs new file mode 100644 index 00000000..0d818115 --- /dev/null +++ b/candle-transformers/src/models/stable_diffusion/resnet.rs @@ -0,0 +1,138 @@ +//! ResNet Building Blocks +//! +//! Some Residual Network blocks used in UNet models. +//! +//! Deep Residual Learning for Image Recognition, K. He et al., 2015. +//! https://arxiv.org/abs/1512.03385 +use super::utils::{conv2d, Conv2d}; +use candle::{Result, Tensor, D}; +use candle_nn as nn; +use candle_nn::Module; + +/// Configuration for a ResNet block. +#[derive(Debug, Clone, Copy)] +pub struct ResnetBlock2DConfig { + /// The number of output channels, defaults to the number of input channels. + pub out_channels: Option<usize>, + pub temb_channels: Option<usize>, + /// The number of groups to use in group normalization. + pub groups: usize, + pub groups_out: Option<usize>, + /// The epsilon to be used in the group normalization operations. + pub eps: f64, + /// Whether to use a 2D convolution in the skip connection. When using None, + /// such a convolution is used if the number of input channels is different from + /// the number of output channels. + pub use_in_shortcut: Option<bool>, + // non_linearity: silu + /// The final output is scaled by dividing by this value.
+ pub output_scale_factor: f64, +} + +impl Default for ResnetBlock2DConfig { + fn default() -> Self { + Self { + out_channels: None, + temb_channels: Some(512), + groups: 32, + groups_out: None, + eps: 1e-6, + use_in_shortcut: None, + output_scale_factor: 1., + } + } +} + +#[derive(Debug)] +pub struct ResnetBlock2D { + norm1: nn::GroupNorm, + conv1: Conv2d, + norm2: nn::GroupNorm, + conv2: Conv2d, + time_emb_proj: Option<nn::Linear>, + conv_shortcut: Option<Conv2d>, + span: tracing::Span, + config: ResnetBlock2DConfig, +} + +impl ResnetBlock2D { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + config: ResnetBlock2DConfig, + ) -> Result<Self> { + let out_channels = config.out_channels.unwrap_or(in_channels); + let conv_cfg = nn::Conv2dConfig { + stride: 1, + padding: 1, + groups: 1, + dilation: 1, + }; + let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?; + let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?; + let groups_out = config.groups_out.unwrap_or(config.groups); + let norm2 = nn::group_norm(groups_out, out_channels, config.eps, vs.pp("norm2"))?; + let conv2 = conv2d(out_channels, out_channels, 3, conv_cfg, vs.pp("conv2"))?; + let use_in_shortcut = config + .use_in_shortcut + .unwrap_or(in_channels != out_channels); + let conv_shortcut = if use_in_shortcut { + let conv_cfg = nn::Conv2dConfig { + stride: 1, + padding: 0, + groups: 1, + dilation: 1, + }; + Some(conv2d( + in_channels, + out_channels, + 1, + conv_cfg, + vs.pp("conv_shortcut"), + )?) + } else { + None + }; + let time_emb_proj = match config.temb_channels { + None => None, + Some(temb_channels) => Some(nn::linear( + temb_channels, + out_channels, + vs.pp("time_emb_proj"), + )?), + }; + let span = tracing::span!(tracing::Level::TRACE, "resnet2d"); + Ok(Self { + norm1, + conv1, + norm2, + conv2, + time_emb_proj, + span, + config, + conv_shortcut, + }) + } + + pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> { + let _enter = self.span.enter(); + let shortcut_xs = match &self.conv_shortcut { + Some(conv_shortcut) => conv_shortcut.forward(xs)?, + None => xs.clone(), + }; + let xs = self.norm1.forward(xs)?; + let xs = self.conv1.forward(&nn::ops::silu(&xs)?)?; + let xs = match (temb, &self.time_emb_proj) { + (Some(temb), Some(time_emb_proj)) => time_emb_proj + .forward(&nn::ops::silu(temb)?)? + .unsqueeze(D::Minus1)? + .unsqueeze(D::Minus1)? + .broadcast_add(&xs)?, + _ => xs, + }; + let xs = self + .conv2 + .forward(&nn::ops::silu(&self.norm2.forward(&xs)?)?)?; + (shortcut_xs + xs)? / self.config.output_scale_factor + } +} diff --git a/candle-transformers/src/models/stable_diffusion/schedulers.rs b/candle-transformers/src/models/stable_diffusion/schedulers.rs new file mode 100644 index 00000000..3f6a1d72 --- /dev/null +++ b/candle-transformers/src/models/stable_diffusion/schedulers.rs @@ -0,0 +1,45 @@ +#![allow(dead_code)] +//! # Diffusion pipelines and models +//! +//! Noise schedulers can be used to set the trade-off between +//! inference speed and quality. + +use candle::{Result, Tensor}; + +/// This represents how beta ranges from its minimum value to the maximum +/// during training. +#[derive(Debug, Clone, Copy)] +pub enum BetaSchedule { + /// Linear interpolation. + Linear, + /// Linear interpolation of the square root of beta. 
+ ScaledLinear, + /// Glide cosine schedule + SquaredcosCapV2, +} + +#[derive(Debug, Clone, Copy)] +pub enum PredictionType { + Epsilon, + VPrediction, + Sample, +} + +/// Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of +/// `(1-beta)` over time from `t = [0,1]`. +/// +/// Contains a function `alpha_bar` that takes an argument `t` and transforms it to the cumulative product of `(1-beta)` +/// up to that part of the diffusion process. +pub(crate) fn betas_for_alpha_bar(num_diffusion_timesteps: usize, max_beta: f64) -> Result<Tensor> { + // alpha_bar expects a time value normalized to [0, 1]. + let alpha_bar = |time_step: f64| { + f64::cos((time_step + 0.008) / 1.008 * std::f64::consts::FRAC_PI_2).powi(2) + }; + let mut betas = Vec::with_capacity(num_diffusion_timesteps); + for i in 0..num_diffusion_timesteps { + // Use floating-point division; integer division would collapse t1 and t2 to zero. + let t1 = i as f64 / num_diffusion_timesteps as f64; + let t2 = (i + 1) as f64 / num_diffusion_timesteps as f64; + betas.push((1.0 - alpha_bar(t2) / alpha_bar(t1)).min(max_beta)); + } + let betas_len = betas.len(); + Tensor::from_vec(betas, betas_len, &candle::Device::Cpu) +} diff --git a/candle-transformers/src/models/stable_diffusion/unet_2d.rs b/candle-transformers/src/models/stable_diffusion/unet_2d.rs new file mode 100644 index 00000000..a3ed136e --- /dev/null +++ b/candle-transformers/src/models/stable_diffusion/unet_2d.rs @@ -0,0 +1,401 @@ +//! 2D UNet Denoising Models +//! +//! The 2D Unet models take as input a noisy sample and the current diffusion +//! timestep and return a denoised version of the input. +use super::embeddings::{TimestepEmbedding, Timesteps}; +use super::unet_2d_blocks::*; +use super::utils::{conv2d, Conv2d}; +use candle::{Result, Tensor}; +use candle_nn as nn; +use candle_nn::Module; + +#[derive(Debug, Clone, Copy)] +pub struct BlockConfig { + pub out_channels: usize, + /// When `None` no cross-attn is used, when `Some(d)` then cross-attn is used and `d` is the + /// number of transformer blocks to be used.
+ pub use_cross_attn: Option<usize>, + pub attention_head_dim: usize, +} + +#[derive(Debug, Clone)] +pub struct UNet2DConditionModelConfig { + pub center_input_sample: bool, + pub flip_sin_to_cos: bool, + pub freq_shift: f64, + pub blocks: Vec<BlockConfig>, + pub layers_per_block: usize, + pub downsample_padding: usize, + pub mid_block_scale_factor: f64, + pub norm_num_groups: usize, + pub norm_eps: f64, + pub cross_attention_dim: usize, + pub sliced_attention_size: Option<usize>, + pub use_linear_projection: bool, +} + +impl Default for UNet2DConditionModelConfig { + fn default() -> Self { + Self { + center_input_sample: false, + flip_sin_to_cos: true, + freq_shift: 0., + blocks: vec![ + BlockConfig { + out_channels: 320, + use_cross_attn: Some(1), + attention_head_dim: 8, + }, + BlockConfig { + out_channels: 640, + use_cross_attn: Some(1), + attention_head_dim: 8, + }, + BlockConfig { + out_channels: 1280, + use_cross_attn: Some(1), + attention_head_dim: 8, + }, + BlockConfig { + out_channels: 1280, + use_cross_attn: None, + attention_head_dim: 8, + }, + ], + layers_per_block: 2, + downsample_padding: 1, + mid_block_scale_factor: 1., + norm_num_groups: 32, + norm_eps: 1e-5, + cross_attention_dim: 1280, + sliced_attention_size: None, + use_linear_projection: false, + } + } +} + +#[derive(Debug)] +pub(crate) enum UNetDownBlock { + Basic(DownBlock2D), + CrossAttn(CrossAttnDownBlock2D), +} + +#[derive(Debug)] +enum UNetUpBlock { + Basic(UpBlock2D), + CrossAttn(CrossAttnUpBlock2D), +} + +#[derive(Debug)] +pub struct UNet2DConditionModel { + conv_in: Conv2d, + time_proj: Timesteps, + time_embedding: TimestepEmbedding, + down_blocks: Vec<UNetDownBlock>, + mid_block: UNetMidBlock2DCrossAttn, + up_blocks: Vec<UNetUpBlock>, + conv_norm_out: nn::GroupNorm, + conv_out: Conv2d, + span: tracing::Span, + config: UNet2DConditionModelConfig, +} + +impl UNet2DConditionModel { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + out_channels: usize, + use_flash_attn: bool, + config: UNet2DConditionModelConfig, + ) -> Result<Self> { + let n_blocks = config.blocks.len(); + let b_channels = config.blocks[0].out_channels; + let bl_channels = config.blocks.last().unwrap().out_channels; + let bl_attention_head_dim = config.blocks.last().unwrap().attention_head_dim; + let time_embed_dim = b_channels * 4; + let conv_cfg = nn::Conv2dConfig { + padding: 1, + ..Default::default() + }; + let conv_in = conv2d(in_channels, b_channels, 3, conv_cfg, vs.pp("conv_in"))?; + + let time_proj = Timesteps::new(b_channels, config.flip_sin_to_cos, config.freq_shift); + let time_embedding = + TimestepEmbedding::new(vs.pp("time_embedding"), b_channels, time_embed_dim)?; + + let vs_db = vs.pp("down_blocks"); + let down_blocks = (0..n_blocks) + .map(|i| { + let BlockConfig { + out_channels, + use_cross_attn, + attention_head_dim, + } = config.blocks[i]; + + // Enable automatic attention slicing if the config sliced_attention_size is set to 0. 
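+ // Slicing chunks the (batch * heads) dimension of the attention matmul,
+ // trading a little speed for a lower peak memory footprint; a configured
+ // size of 0 defaults to half the attention head dimension.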
+ let sliced_attention_size = match config.sliced_attention_size { + Some(0) => Some(attention_head_dim / 2), + _ => config.sliced_attention_size, + }; + + let in_channels = if i > 0 { + config.blocks[i - 1].out_channels + } else { + b_channels + }; + let db_cfg = DownBlock2DConfig { + num_layers: config.layers_per_block, + resnet_eps: config.norm_eps, + resnet_groups: config.norm_num_groups, + add_downsample: i < n_blocks - 1, + downsample_padding: config.downsample_padding, + ..Default::default() + }; + if let Some(transformer_layers_per_block) = use_cross_attn { + let config = CrossAttnDownBlock2DConfig { + downblock: db_cfg, + attn_num_head_channels: attention_head_dim, + cross_attention_dim: config.cross_attention_dim, + sliced_attention_size, + use_linear_projection: config.use_linear_projection, + transformer_layers_per_block, + }; + let block = CrossAttnDownBlock2D::new( + vs_db.pp(&i.to_string()), + in_channels, + out_channels, + Some(time_embed_dim), + use_flash_attn, + config, + )?; + Ok(UNetDownBlock::CrossAttn(block)) + } else { + let block = DownBlock2D::new( + vs_db.pp(&i.to_string()), + in_channels, + out_channels, + Some(time_embed_dim), + db_cfg, + )?; + Ok(UNetDownBlock::Basic(block)) + } + }) + .collect::<Result<Vec<_>>>()?; + + // https://github.com/huggingface/diffusers/blob/a76f2ad538e73b34d5fe7be08c8eb8ab38c7e90c/src/diffusers/models/unet_2d_condition.py#L462 + let mid_transformer_layers_per_block = match config.blocks.last() { + None => 1, + Some(block) => block.use_cross_attn.unwrap_or(1), + }; + let mid_cfg = UNetMidBlock2DCrossAttnConfig { + resnet_eps: config.norm_eps, + output_scale_factor: config.mid_block_scale_factor, + cross_attn_dim: config.cross_attention_dim, + attn_num_head_channels: bl_attention_head_dim, + resnet_groups: Some(config.norm_num_groups), + use_linear_projection: config.use_linear_projection, + transformer_layers_per_block: mid_transformer_layers_per_block, + ..Default::default() + }; + + let mid_block = UNetMidBlock2DCrossAttn::new( + vs.pp("mid_block"), + bl_channels, + Some(time_embed_dim), + use_flash_attn, + mid_cfg, + )?; + + let vs_ub = vs.pp("up_blocks"); + let up_blocks = (0..n_blocks) + .map(|i| { + let BlockConfig { + out_channels, + use_cross_attn, + attention_head_dim, + } = config.blocks[n_blocks - 1 - i]; + + // Enable automatic attention slicing if the config sliced_attention_size is set to 0. 
+ let sliced_attention_size = match config.sliced_attention_size { + Some(0) => Some(attention_head_dim / 2), + _ => config.sliced_attention_size, + }; + + let prev_out_channels = if i > 0 { + config.blocks[n_blocks - i].out_channels + } else { + bl_channels + }; + let in_channels = { + let index = if i == n_blocks - 1 { + 0 + } else { + n_blocks - i - 2 + }; + config.blocks[index].out_channels + }; + let ub_cfg = UpBlock2DConfig { + num_layers: config.layers_per_block + 1, + resnet_eps: config.norm_eps, + resnet_groups: config.norm_num_groups, + add_upsample: i < n_blocks - 1, + ..Default::default() + }; + if let Some(transformer_layers_per_block) = use_cross_attn { + let config = CrossAttnUpBlock2DConfig { + upblock: ub_cfg, + attn_num_head_channels: attention_head_dim, + cross_attention_dim: config.cross_attention_dim, + sliced_attention_size, + use_linear_projection: config.use_linear_projection, + transformer_layers_per_block, + }; + let block = CrossAttnUpBlock2D::new( + vs_ub.pp(&i.to_string()), + in_channels, + prev_out_channels, + out_channels, + Some(time_embed_dim), + use_flash_attn, + config, + )?; + Ok(UNetUpBlock::CrossAttn(block)) + } else { + let block = UpBlock2D::new( + vs_ub.pp(&i.to_string()), + in_channels, + prev_out_channels, + out_channels, + Some(time_embed_dim), + ub_cfg, + )?; + Ok(UNetUpBlock::Basic(block)) + } + }) + .collect::<Result<Vec<_>>>()?; + + let conv_norm_out = nn::group_norm( + config.norm_num_groups, + b_channels, + config.norm_eps, + vs.pp("conv_norm_out"), + )?; + let conv_out = conv2d(b_channels, out_channels, 3, conv_cfg, vs.pp("conv_out"))?; + let span = tracing::span!(tracing::Level::TRACE, "unet2d"); + Ok(Self { + conv_in, + time_proj, + time_embedding, + down_blocks, + mid_block, + up_blocks, + conv_norm_out, + conv_out, + span, + config, + }) + } + + pub fn forward( + &self, + xs: &Tensor, + timestep: f64, + encoder_hidden_states: &Tensor, + ) -> Result<Tensor> { + let _enter = self.span.enter(); + self.forward_with_additional_residuals(xs, timestep, encoder_hidden_states, None, None) + } + + pub fn forward_with_additional_residuals( + &self, + xs: &Tensor, + timestep: f64, + encoder_hidden_states: &Tensor, + down_block_additional_residuals: Option<&[Tensor]>, + mid_block_additional_residual: Option<&Tensor>, + ) -> Result<Tensor> { + let (bsize, _channels, height, width) = xs.dims4()?; + let device = xs.device(); + let n_blocks = self.config.blocks.len(); + let num_upsamplers = n_blocks - 1; + let default_overall_up_factor = 2usize.pow(num_upsamplers as u32); + let forward_upsample_size = + height % default_overall_up_factor != 0 || width % default_overall_up_factor != 0; + // 0. center input if necessary + let xs = if self.config.center_input_sample { + ((xs * 2.0)? - 1.0)? + } else { + xs.clone() + }; + // 1. time + let emb = (Tensor::ones(bsize, xs.dtype(), device)? * timestep)?; + let emb = self.time_proj.forward(&emb)?; + let emb = self.time_embedding.forward(&emb)?; + // 2. pre-process + let xs = self.conv_in.forward(&xs)?; + // 3. down + let mut down_block_res_xs = vec![xs.clone()]; + let mut xs = xs; + for down_block in self.down_blocks.iter() { + let (_xs, res_xs) = match down_block { + UNetDownBlock::Basic(b) => b.forward(&xs, Some(&emb))?, + UNetDownBlock::CrossAttn(b) => { + b.forward(&xs, Some(&emb), Some(encoder_hidden_states))? 
+ } + }; + down_block_res_xs.extend(res_xs); + xs = _xs; + } + + let new_down_block_res_xs = + if let Some(down_block_additional_residuals) = down_block_additional_residuals { + let mut v = vec![]; + // A previous version of this code had a bug because of the addition being made + // in place via += hence modifying the input of the mid block. + for (i, residuals) in down_block_additional_residuals.iter().enumerate() { + v.push((&down_block_res_xs[i] + residuals)?) + } + v + } else { + down_block_res_xs + }; + let mut down_block_res_xs = new_down_block_res_xs; + + // 4. mid + let xs = self + .mid_block + .forward(&xs, Some(&emb), Some(encoder_hidden_states))?; + let xs = match mid_block_additional_residual { + None => xs, + Some(m) => (m + xs)?, + }; + // 5. up + let mut xs = xs; + let mut upsample_size = None; + for (i, up_block) in self.up_blocks.iter().enumerate() { + let n_resnets = match up_block { + UNetUpBlock::Basic(b) => b.resnets.len(), + UNetUpBlock::CrossAttn(b) => b.upblock.resnets.len(), + }; + let res_xs = down_block_res_xs.split_off(down_block_res_xs.len() - n_resnets); + if i < n_blocks - 1 && forward_upsample_size { + let (_, _, h, w) = down_block_res_xs.last().unwrap().dims4()?; + upsample_size = Some((h, w)) + } + xs = match up_block { + UNetUpBlock::Basic(b) => b.forward(&xs, &res_xs, Some(&emb), upsample_size)?, + UNetUpBlock::CrossAttn(b) => b.forward( + &xs, + &res_xs, + Some(&emb), + upsample_size, + Some(encoder_hidden_states), + )?, + }; + } + // 6. post-process + let xs = self.conv_norm_out.forward(&xs)?; + let xs = nn::ops::silu(&xs)?; + self.conv_out.forward(&xs) + } +} diff --git a/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs b/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs new file mode 100644 index 00000000..29510cef --- /dev/null +++ b/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs @@ -0,0 +1,868 @@ +//! 2D UNet Building Blocks +//! +use super::attention::{ + AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig, +}; +use super::resnet::{ResnetBlock2D, ResnetBlock2DConfig}; +use super::utils::{conv2d, Conv2d}; +use candle::{Module, Result, Tensor, D}; +use candle_nn as nn; + +#[derive(Debug)] +struct Downsample2D { + conv: Option<Conv2d>, + padding: usize, + span: tracing::Span, +} + +impl Downsample2D { + fn new( + vs: nn::VarBuilder, + in_channels: usize, + use_conv: bool, + out_channels: usize, + padding: usize, + ) -> Result<Self> { + let conv = if use_conv { + let config = nn::Conv2dConfig { + stride: 2, + padding, + ..Default::default() + }; + let conv = conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?; + Some(conv) + } else { + None + }; + let span = tracing::span!(tracing::Level::TRACE, "downsample2d"); + Ok(Self { + conv, + padding, + span, + }) + } +} + +impl Module for Downsample2D { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + match &self.conv { + None => xs.avg_pool2d(2), + Some(conv) => { + if self.padding == 0 { + let xs = xs + .pad_with_zeros(D::Minus1, 0, 1)? + .pad_with_zeros(D::Minus2, 0, 1)?; + conv.forward(&xs) + } else { + conv.forward(xs) + } + } + } + } +} + +// This does not support the conv-transpose mode. 
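+// Instead, upsampling is done with a nearest-neighbour resize to twice the input's spatial size (or to an explicitly requested size) followed by a 3x3 convolution.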
+#[derive(Debug)] +struct Upsample2D { + conv: Conv2d, + span: tracing::Span, +} + +impl Upsample2D { + fn new(vs: nn::VarBuilder, in_channels: usize, out_channels: usize) -> Result<Self> { + let config = nn::Conv2dConfig { + padding: 1, + ..Default::default() + }; + let conv = conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?; + let span = tracing::span!(tracing::Level::TRACE, "upsample2d"); + Ok(Self { conv, span }) + } +} + +impl Upsample2D { + fn forward(&self, xs: &Tensor, size: Option<(usize, usize)>) -> Result<Tensor> { + let _enter = self.span.enter(); + let xs = match size { + None => { + let (_bsize, _channels, h, w) = xs.dims4()?; + xs.upsample_nearest2d(2 * h, 2 * w)? + } + Some((h, w)) => xs.upsample_nearest2d(h, w)?, + }; + self.conv.forward(&xs) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct DownEncoderBlock2DConfig { + pub num_layers: usize, + pub resnet_eps: f64, + pub resnet_groups: usize, + pub output_scale_factor: f64, + pub add_downsample: bool, + pub downsample_padding: usize, +} + +impl Default for DownEncoderBlock2DConfig { + fn default() -> Self { + Self { + num_layers: 1, + resnet_eps: 1e-6, + resnet_groups: 32, + output_scale_factor: 1., + add_downsample: true, + downsample_padding: 1, + } + } +} + +#[derive(Debug)] +pub struct DownEncoderBlock2D { + resnets: Vec<ResnetBlock2D>, + downsampler: Option<Downsample2D>, + span: tracing::Span, + pub config: DownEncoderBlock2DConfig, +} + +impl DownEncoderBlock2D { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + out_channels: usize, + config: DownEncoderBlock2DConfig, + ) -> Result<Self> { + let resnets: Vec<_> = { + let vs = vs.pp("resnets"); + let conv_cfg = ResnetBlock2DConfig { + eps: config.resnet_eps, + out_channels: Some(out_channels), + groups: config.resnet_groups, + output_scale_factor: config.output_scale_factor, + temb_channels: None, + ..Default::default() + }; + (0..(config.num_layers)) + .map(|i| { + let in_channels = if i == 0 { in_channels } else { out_channels }; + ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg) + }) + .collect::<Result<Vec<_>>>()? + }; + let downsampler = if config.add_downsample { + let downsample = Downsample2D::new( + vs.pp("downsamplers").pp("0"), + out_channels, + true, + out_channels, + config.downsample_padding, + )?; + Some(downsample) + } else { + None + }; + let span = tracing::span!(tracing::Level::TRACE, "down-enc2d"); + Ok(Self { + resnets, + downsampler, + span, + config, + }) + } +} + +impl Module for DownEncoderBlock2D { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let mut xs = xs.clone(); + for resnet in self.resnets.iter() { + xs = resnet.forward(&xs, None)? 
+ } + match &self.downsampler { + Some(downsampler) => downsampler.forward(&xs), + None => Ok(xs), + } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct UpDecoderBlock2DConfig { + pub num_layers: usize, + pub resnet_eps: f64, + pub resnet_groups: usize, + pub output_scale_factor: f64, + pub add_upsample: bool, +} + +impl Default for UpDecoderBlock2DConfig { + fn default() -> Self { + Self { + num_layers: 1, + resnet_eps: 1e-6, + resnet_groups: 32, + output_scale_factor: 1., + add_upsample: true, + } + } +} + +#[derive(Debug)] +pub struct UpDecoderBlock2D { + resnets: Vec<ResnetBlock2D>, + upsampler: Option<Upsample2D>, + span: tracing::Span, + pub config: UpDecoderBlock2DConfig, +} + +impl UpDecoderBlock2D { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + out_channels: usize, + config: UpDecoderBlock2DConfig, + ) -> Result<Self> { + let resnets: Vec<_> = { + let vs = vs.pp("resnets"); + let conv_cfg = ResnetBlock2DConfig { + out_channels: Some(out_channels), + eps: config.resnet_eps, + groups: config.resnet_groups, + output_scale_factor: config.output_scale_factor, + temb_channels: None, + ..Default::default() + }; + (0..(config.num_layers)) + .map(|i| { + let in_channels = if i == 0 { in_channels } else { out_channels }; + ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg) + }) + .collect::<Result<Vec<_>>>()? + }; + let upsampler = if config.add_upsample { + let upsample = + Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?; + Some(upsample) + } else { + None + }; + let span = tracing::span!(tracing::Level::TRACE, "up-dec2d"); + Ok(Self { + resnets, + upsampler, + span, + config, + }) + } +} + +impl Module for UpDecoderBlock2D { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let mut xs = xs.clone(); + for resnet in self.resnets.iter() { + xs = resnet.forward(&xs, None)? 
+ } + match &self.upsampler { + Some(upsampler) => upsampler.forward(&xs, None), + None => Ok(xs), + } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct UNetMidBlock2DConfig { + pub num_layers: usize, + pub resnet_eps: f64, + pub resnet_groups: Option<usize>, + pub attn_num_head_channels: Option<usize>, + // attention_type "default" + pub output_scale_factor: f64, +} + +impl Default for UNetMidBlock2DConfig { + fn default() -> Self { + Self { + num_layers: 1, + resnet_eps: 1e-6, + resnet_groups: Some(32), + attn_num_head_channels: Some(1), + output_scale_factor: 1., + } + } +} + +#[derive(Debug)] +pub struct UNetMidBlock2D { + resnet: ResnetBlock2D, + attn_resnets: Vec<(AttentionBlock, ResnetBlock2D)>, + span: tracing::Span, + pub config: UNetMidBlock2DConfig, +} + +impl UNetMidBlock2D { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + temb_channels: Option<usize>, + config: UNetMidBlock2DConfig, + ) -> Result<Self> { + let vs_resnets = vs.pp("resnets"); + let vs_attns = vs.pp("attentions"); + let resnet_groups = config + .resnet_groups + .unwrap_or_else(|| usize::min(in_channels / 4, 32)); + let resnet_cfg = ResnetBlock2DConfig { + eps: config.resnet_eps, + groups: resnet_groups, + output_scale_factor: config.output_scale_factor, + temb_channels, + ..Default::default() + }; + let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?; + let attn_cfg = AttentionBlockConfig { + num_head_channels: config.attn_num_head_channels, + num_groups: resnet_groups, + rescale_output_factor: config.output_scale_factor, + eps: config.resnet_eps, + }; + let mut attn_resnets = vec![]; + for index in 0..config.num_layers { + let attn = AttentionBlock::new(vs_attns.pp(&index.to_string()), in_channels, attn_cfg)?; + let resnet = ResnetBlock2D::new( + vs_resnets.pp(&(index + 1).to_string()), + in_channels, + resnet_cfg, + )?; + attn_resnets.push((attn, resnet)) + } + let span = tracing::span!(tracing::Level::TRACE, "mid2d"); + Ok(Self { + resnet, + attn_resnets, + span, + config, + }) + } + + pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> { + let _enter = self.span.enter(); + let mut xs = self.resnet.forward(xs, temb)?; + for (attn, resnet) in self.attn_resnets.iter() { + xs = resnet.forward(&attn.forward(&xs)?, temb)? 
+ } + Ok(xs) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct UNetMidBlock2DCrossAttnConfig { + pub num_layers: usize, + pub resnet_eps: f64, + pub resnet_groups: Option<usize>, + pub attn_num_head_channels: usize, + // attention_type "default" + pub output_scale_factor: f64, + pub cross_attn_dim: usize, + pub sliced_attention_size: Option<usize>, + pub use_linear_projection: bool, + pub transformer_layers_per_block: usize, +} + +impl Default for UNetMidBlock2DCrossAttnConfig { + fn default() -> Self { + Self { + num_layers: 1, + resnet_eps: 1e-6, + resnet_groups: Some(32), + attn_num_head_channels: 1, + output_scale_factor: 1., + cross_attn_dim: 1280, + sliced_attention_size: None, // Sliced attention disabled + use_linear_projection: false, + transformer_layers_per_block: 1, + } + } +} + +#[derive(Debug)] +pub struct UNetMidBlock2DCrossAttn { + resnet: ResnetBlock2D, + attn_resnets: Vec<(SpatialTransformer, ResnetBlock2D)>, + span: tracing::Span, + pub config: UNetMidBlock2DCrossAttnConfig, +} + +impl UNetMidBlock2DCrossAttn { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + temb_channels: Option<usize>, + use_flash_attn: bool, + config: UNetMidBlock2DCrossAttnConfig, + ) -> Result<Self> { + let vs_resnets = vs.pp("resnets"); + let vs_attns = vs.pp("attentions"); + let resnet_groups = config + .resnet_groups + .unwrap_or_else(|| usize::min(in_channels / 4, 32)); + let resnet_cfg = ResnetBlock2DConfig { + eps: config.resnet_eps, + groups: resnet_groups, + output_scale_factor: config.output_scale_factor, + temb_channels, + ..Default::default() + }; + let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?; + let n_heads = config.attn_num_head_channels; + let attn_cfg = SpatialTransformerConfig { + depth: config.transformer_layers_per_block, + num_groups: resnet_groups, + context_dim: Some(config.cross_attn_dim), + sliced_attention_size: config.sliced_attention_size, + use_linear_projection: config.use_linear_projection, + }; + let mut attn_resnets = vec![]; + for index in 0..config.num_layers { + let attn = SpatialTransformer::new( + vs_attns.pp(&index.to_string()), + in_channels, + n_heads, + in_channels / n_heads, + use_flash_attn, + attn_cfg, + )?; + let resnet = ResnetBlock2D::new( + vs_resnets.pp(&(index + 1).to_string()), + in_channels, + resnet_cfg, + )?; + attn_resnets.push((attn, resnet)) + } + let span = tracing::span!(tracing::Level::TRACE, "xa-mid2d"); + Ok(Self { + resnet, + attn_resnets, + span, + config, + }) + } + + pub fn forward( + &self, + xs: &Tensor, + temb: Option<&Tensor>, + encoder_hidden_states: Option<&Tensor>, + ) -> Result<Tensor> { + let _enter = self.span.enter(); + let mut xs = self.resnet.forward(xs, temb)?; + for (attn, resnet) in self.attn_resnets.iter() { + xs = resnet.forward(&attn.forward(&xs, encoder_hidden_states)?, temb)? 
+ } + Ok(xs) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct DownBlock2DConfig { + pub num_layers: usize, + pub resnet_eps: f64, + // resnet_time_scale_shift: "default" + // resnet_act_fn: "swish" + pub resnet_groups: usize, + pub output_scale_factor: f64, + pub add_downsample: bool, + pub downsample_padding: usize, +} + +impl Default for DownBlock2DConfig { + fn default() -> Self { + Self { + num_layers: 1, + resnet_eps: 1e-6, + resnet_groups: 32, + output_scale_factor: 1., + add_downsample: true, + downsample_padding: 1, + } + } +} + +#[derive(Debug)] +pub struct DownBlock2D { + resnets: Vec<ResnetBlock2D>, + downsampler: Option<Downsample2D>, + span: tracing::Span, + pub config: DownBlock2DConfig, +} + +impl DownBlock2D { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + out_channels: usize, + temb_channels: Option<usize>, + config: DownBlock2DConfig, + ) -> Result<Self> { + let vs_resnets = vs.pp("resnets"); + let resnet_cfg = ResnetBlock2DConfig { + out_channels: Some(out_channels), + eps: config.resnet_eps, + output_scale_factor: config.output_scale_factor, + temb_channels, + ..Default::default() + }; + let resnets = (0..config.num_layers) + .map(|i| { + let in_channels = if i == 0 { in_channels } else { out_channels }; + ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg) + }) + .collect::<Result<Vec<_>>>()?; + let downsampler = if config.add_downsample { + let downsampler = Downsample2D::new( + vs.pp("downsamplers").pp("0"), + out_channels, + true, + out_channels, + config.downsample_padding, + )?; + Some(downsampler) + } else { + None + }; + let span = tracing::span!(tracing::Level::TRACE, "down2d"); + Ok(Self { + resnets, + downsampler, + span, + config, + }) + } + + pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<(Tensor, Vec<Tensor>)> { + let _enter = self.span.enter(); + let mut xs = xs.clone(); + let mut output_states = vec![]; + for resnet in self.resnets.iter() { + xs = resnet.forward(&xs, temb)?; + output_states.push(xs.clone()); + } + let xs = match &self.downsampler { + Some(downsampler) => { + let xs = downsampler.forward(&xs)?; + output_states.push(xs.clone()); + xs + } + None => xs, + }; + Ok((xs, output_states)) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct CrossAttnDownBlock2DConfig { + pub downblock: DownBlock2DConfig, + pub attn_num_head_channels: usize, + pub cross_attention_dim: usize, + // attention_type: "default" + pub sliced_attention_size: Option<usize>, + pub use_linear_projection: bool, + pub transformer_layers_per_block: usize, +} + +impl Default for CrossAttnDownBlock2DConfig { + fn default() -> Self { + Self { + downblock: Default::default(), + attn_num_head_channels: 1, + cross_attention_dim: 1280, + sliced_attention_size: None, + use_linear_projection: false, + transformer_layers_per_block: 1, + } + } +} + +#[derive(Debug)] +pub struct CrossAttnDownBlock2D { + downblock: DownBlock2D, + attentions: Vec<SpatialTransformer>, + span: tracing::Span, + pub config: CrossAttnDownBlock2DConfig, +} + +impl CrossAttnDownBlock2D { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + out_channels: usize, + temb_channels: Option<usize>, + use_flash_attn: bool, + config: CrossAttnDownBlock2DConfig, + ) -> Result<Self> { + let downblock = DownBlock2D::new( + vs.clone(), + in_channels, + out_channels, + temb_channels, + config.downblock, + )?; + let n_heads = config.attn_num_head_channels; + let cfg = SpatialTransformerConfig { + depth: config.transformer_layers_per_block, + context_dim: 
Some(config.cross_attention_dim), + num_groups: config.downblock.resnet_groups, + sliced_attention_size: config.sliced_attention_size, + use_linear_projection: config.use_linear_projection, + }; + let vs_attn = vs.pp("attentions"); + let attentions = (0..config.downblock.num_layers) + .map(|i| { + SpatialTransformer::new( + vs_attn.pp(&i.to_string()), + out_channels, + n_heads, + out_channels / n_heads, + use_flash_attn, + cfg, + ) + }) + .collect::<Result<Vec<_>>>()?; + let span = tracing::span!(tracing::Level::TRACE, "xa-down2d"); + Ok(Self { + downblock, + attentions, + span, + config, + }) + } + + pub fn forward( + &self, + xs: &Tensor, + temb: Option<&Tensor>, + encoder_hidden_states: Option<&Tensor>, + ) -> Result<(Tensor, Vec<Tensor>)> { + let _enter = self.span.enter(); + let mut output_states = vec![]; + let mut xs = xs.clone(); + for (resnet, attn) in self.downblock.resnets.iter().zip(self.attentions.iter()) { + xs = resnet.forward(&xs, temb)?; + xs = attn.forward(&xs, encoder_hidden_states)?; + output_states.push(xs.clone()); + } + let xs = match &self.downblock.downsampler { + Some(downsampler) => { + let xs = downsampler.forward(&xs)?; + output_states.push(xs.clone()); + xs + } + None => xs, + }; + Ok((xs, output_states)) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct UpBlock2DConfig { + pub num_layers: usize, + pub resnet_eps: f64, + // resnet_time_scale_shift: "default" + // resnet_act_fn: "swish" + pub resnet_groups: usize, + pub output_scale_factor: f64, + pub add_upsample: bool, +} + +impl Default for UpBlock2DConfig { + fn default() -> Self { + Self { + num_layers: 1, + resnet_eps: 1e-6, + resnet_groups: 32, + output_scale_factor: 1., + add_upsample: true, + } + } +} + +#[derive(Debug)] +pub struct UpBlock2D { + pub resnets: Vec<ResnetBlock2D>, + upsampler: Option<Upsample2D>, + span: tracing::Span, + pub config: UpBlock2DConfig, +} + +impl UpBlock2D { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + prev_output_channels: usize, + out_channels: usize, + temb_channels: Option<usize>, + config: UpBlock2DConfig, + ) -> Result<Self> { + let vs_resnets = vs.pp("resnets"); + let resnet_cfg = ResnetBlock2DConfig { + out_channels: Some(out_channels), + temb_channels, + eps: config.resnet_eps, + output_scale_factor: config.output_scale_factor, + ..Default::default() + }; + let resnets = (0..config.num_layers) + .map(|i| { + let res_skip_channels = if i == config.num_layers - 1 { + in_channels + } else { + out_channels + }; + let resnet_in_channels = if i == 0 { + prev_output_channels + } else { + out_channels + }; + let in_channels = resnet_in_channels + res_skip_channels; + ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg) + }) + .collect::<Result<Vec<_>>>()?; + let upsampler = if config.add_upsample { + let upsampler = + Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?; + Some(upsampler) + } else { + None + }; + let span = tracing::span!(tracing::Level::TRACE, "up2d"); + Ok(Self { + resnets, + upsampler, + span, + config, + }) + } + + pub fn forward( + &self, + xs: &Tensor, + res_xs: &[Tensor], + temb: Option<&Tensor>, + upsample_size: Option<(usize, usize)>, + ) -> Result<Tensor> { + let _enter = self.span.enter(); + let mut xs = xs.clone(); + for (index, resnet) in self.resnets.iter().enumerate() { + xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?; + xs = xs.contiguous()?; + xs = resnet.forward(&xs, temb)?; + } + match &self.upsampler { + Some(upsampler) => upsampler.forward(&xs, upsample_size), 
+ None => Ok(xs), + } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct CrossAttnUpBlock2DConfig { + pub upblock: UpBlock2DConfig, + pub attn_num_head_channels: usize, + pub cross_attention_dim: usize, + // attention_type: "default" + pub sliced_attention_size: Option<usize>, + pub use_linear_projection: bool, + pub transformer_layers_per_block: usize, +} + +impl Default for CrossAttnUpBlock2DConfig { + fn default() -> Self { + Self { + upblock: Default::default(), + attn_num_head_channels: 1, + cross_attention_dim: 1280, + sliced_attention_size: None, + use_linear_projection: false, + transformer_layers_per_block: 1, + } + } +} + +#[derive(Debug)] +pub struct CrossAttnUpBlock2D { + pub upblock: UpBlock2D, + pub attentions: Vec<SpatialTransformer>, + span: tracing::Span, + pub config: CrossAttnUpBlock2DConfig, +} + +impl CrossAttnUpBlock2D { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + prev_output_channels: usize, + out_channels: usize, + temb_channels: Option<usize>, + use_flash_attn: bool, + config: CrossAttnUpBlock2DConfig, + ) -> Result<Self> { + let upblock = UpBlock2D::new( + vs.clone(), + in_channels, + prev_output_channels, + out_channels, + temb_channels, + config.upblock, + )?; + let n_heads = config.attn_num_head_channels; + let cfg = SpatialTransformerConfig { + depth: config.transformer_layers_per_block, + context_dim: Some(config.cross_attention_dim), + num_groups: config.upblock.resnet_groups, + sliced_attention_size: config.sliced_attention_size, + use_linear_projection: config.use_linear_projection, + }; + let vs_attn = vs.pp("attentions"); + let attentions = (0..config.upblock.num_layers) + .map(|i| { + SpatialTransformer::new( + vs_attn.pp(&i.to_string()), + out_channels, + n_heads, + out_channels / n_heads, + use_flash_attn, + cfg, + ) + }) + .collect::<Result<Vec<_>>>()?; + let span = tracing::span!(tracing::Level::TRACE, "xa-up2d"); + Ok(Self { + upblock, + attentions, + span, + config, + }) + } + + pub fn forward( + &self, + xs: &Tensor, + res_xs: &[Tensor], + temb: Option<&Tensor>, + upsample_size: Option<(usize, usize)>, + encoder_hidden_states: Option<&Tensor>, + ) -> Result<Tensor> { + let _enter = self.span.enter(); + let mut xs = xs.clone(); + for (index, resnet) in self.upblock.resnets.iter().enumerate() { + xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?; + xs = xs.contiguous()?; + xs = resnet.forward(&xs, temb)?; + xs = self.attentions[index].forward(&xs, encoder_hidden_states)?; + } + match &self.upblock.upsampler { + Some(upsampler) => upsampler.forward(&xs, upsample_size), + None => Ok(xs), + } + } +} diff --git a/candle-transformers/src/models/stable_diffusion/utils.rs b/candle-transformers/src/models/stable_diffusion/utils.rs new file mode 100644 index 00000000..c62f17af --- /dev/null +++ b/candle-transformers/src/models/stable_diffusion/utils.rs @@ -0,0 +1,39 @@ +use candle::{Device, Result, Tensor}; +use candle_nn::Module; + +pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> { + if steps < 1 { + candle::bail!("cannot use linspace with steps {steps} <= 1") + } + let delta = (stop - start) / (steps - 1) as f64; + let vs = (0..steps) + .map(|step| start + step as f64 * delta) + .collect::<Vec<_>>(); + Tensor::from_vec(vs, steps, &Device::Cpu) +} + +// Wrap the conv2d op to provide some tracing. 
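+// Each call to `forward` enters a TRACE-level "conv2d" span, so the time spent in individual convolutions can be inspected once a tracing subscriber is installed.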
+#[derive(Debug)] +pub struct Conv2d { + inner: candle_nn::Conv2d, + span: tracing::Span, +} + +impl Conv2d { + pub fn forward(&self, x: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(x) + } +} + +pub fn conv2d( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + cfg: candle_nn::Conv2dConfig, + vs: candle_nn::VarBuilder, +) -> Result<Conv2d> { + let span = tracing::span!(tracing::Level::TRACE, "conv2d"); + let inner = candle_nn::conv2d(in_channels, out_channels, kernel_size, cfg, vs)?; + Ok(Conv2d { inner, span }) +} diff --git a/candle-transformers/src/models/stable_diffusion/vae.rs b/candle-transformers/src/models/stable_diffusion/vae.rs new file mode 100644 index 00000000..21709afe --- /dev/null +++ b/candle-transformers/src/models/stable_diffusion/vae.rs @@ -0,0 +1,380 @@ +#![allow(dead_code)] +//! # Variational Auto-Encoder (VAE) Models. +//! +//! Auto-encoder models compress their input to a usually smaller latent space +//! before expanding it back to its original shape. This results in the latent values +//! compressing the original information. +use super::unet_2d_blocks::{ + DownEncoderBlock2D, DownEncoderBlock2DConfig, UNetMidBlock2D, UNetMidBlock2DConfig, + UpDecoderBlock2D, UpDecoderBlock2DConfig, +}; +use candle::{Result, Tensor}; +use candle_nn as nn; +use candle_nn::Module; + +#[derive(Debug, Clone)] +struct EncoderConfig { + // down_block_types: DownEncoderBlock2D + block_out_channels: Vec<usize>, + layers_per_block: usize, + norm_num_groups: usize, + double_z: bool, +} + +impl Default for EncoderConfig { + fn default() -> Self { + Self { + block_out_channels: vec![64], + layers_per_block: 2, + norm_num_groups: 32, + double_z: true, + } + } +} + +#[derive(Debug)] +struct Encoder { + conv_in: nn::Conv2d, + down_blocks: Vec<DownEncoderBlock2D>, + mid_block: UNetMidBlock2D, + conv_norm_out: nn::GroupNorm, + conv_out: nn::Conv2d, + #[allow(dead_code)] + config: EncoderConfig, +} + +impl Encoder { + fn new( + vs: nn::VarBuilder, + in_channels: usize, + out_channels: usize, + config: EncoderConfig, + ) -> Result<Self> { + let conv_cfg = nn::Conv2dConfig { + padding: 1, + ..Default::default() + }; + let conv_in = nn::conv2d( + in_channels, + config.block_out_channels[0], + 3, + conv_cfg, + vs.pp("conv_in"), + )?; + let mut down_blocks = vec![]; + let vs_down_blocks = vs.pp("down_blocks"); + for index in 0..config.block_out_channels.len() { + let out_channels = config.block_out_channels[index]; + let in_channels = if index > 0 { + config.block_out_channels[index - 1] + } else { + config.block_out_channels[0] + }; + let is_final = index + 1 == config.block_out_channels.len(); + let cfg = DownEncoderBlock2DConfig { + num_layers: config.layers_per_block, + resnet_eps: 1e-6, + resnet_groups: config.norm_num_groups, + add_downsample: !is_final, + downsample_padding: 0, + ..Default::default() + }; + let down_block = DownEncoderBlock2D::new( + vs_down_blocks.pp(&index.to_string()), + in_channels, + out_channels, + cfg, + )?; + down_blocks.push(down_block) + } + let last_block_out_channels = *config.block_out_channels.last().unwrap(); + let mid_cfg = UNetMidBlock2DConfig { + resnet_eps: 1e-6, + output_scale_factor: 1., + attn_num_head_channels: None, + resnet_groups: Some(config.norm_num_groups), + ..Default::default() + }; + let mid_block = + UNetMidBlock2D::new(vs.pp("mid_block"), last_block_out_channels, None, mid_cfg)?; + let conv_norm_out = nn::group_norm( + config.norm_num_groups, + last_block_out_channels, + 1e-6, + 
vs.pp("conv_norm_out"), + )?; + let conv_out_channels = if config.double_z { + 2 * out_channels + } else { + out_channels + }; + let conv_cfg = nn::Conv2dConfig { + padding: 1, + ..Default::default() + }; + let conv_out = nn::conv2d( + last_block_out_channels, + conv_out_channels, + 3, + conv_cfg, + vs.pp("conv_out"), + )?; + Ok(Self { + conv_in, + down_blocks, + mid_block, + conv_norm_out, + conv_out, + config, + }) + } +} + +impl Encoder { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let mut xs = xs.apply(&self.conv_in)?; + for down_block in self.down_blocks.iter() { + xs = xs.apply(down_block)? + } + let xs = self + .mid_block + .forward(&xs, None)? + .apply(&self.conv_norm_out)?; + nn::ops::silu(&xs)?.apply(&self.conv_out) + } +} + +#[derive(Debug, Clone)] +struct DecoderConfig { + // up_block_types: UpDecoderBlock2D + block_out_channels: Vec<usize>, + layers_per_block: usize, + norm_num_groups: usize, +} + +impl Default for DecoderConfig { + fn default() -> Self { + Self { + block_out_channels: vec![64], + layers_per_block: 2, + norm_num_groups: 32, + } + } +} + +#[derive(Debug)] +struct Decoder { + conv_in: nn::Conv2d, + up_blocks: Vec<UpDecoderBlock2D>, + mid_block: UNetMidBlock2D, + conv_norm_out: nn::GroupNorm, + conv_out: nn::Conv2d, + #[allow(dead_code)] + config: DecoderConfig, +} + +impl Decoder { + fn new( + vs: nn::VarBuilder, + in_channels: usize, + out_channels: usize, + config: DecoderConfig, + ) -> Result<Self> { + let n_block_out_channels = config.block_out_channels.len(); + let last_block_out_channels = *config.block_out_channels.last().unwrap(); + let conv_cfg = nn::Conv2dConfig { + padding: 1, + ..Default::default() + }; + let conv_in = nn::conv2d( + in_channels, + last_block_out_channels, + 3, + conv_cfg, + vs.pp("conv_in"), + )?; + let mid_cfg = UNetMidBlock2DConfig { + resnet_eps: 1e-6, + output_scale_factor: 1., + attn_num_head_channels: None, + resnet_groups: Some(config.norm_num_groups), + ..Default::default() + }; + let mid_block = + UNetMidBlock2D::new(vs.pp("mid_block"), last_block_out_channels, None, mid_cfg)?; + let mut up_blocks = vec![]; + let vs_up_blocks = vs.pp("up_blocks"); + let reversed_block_out_channels: Vec<_> = + config.block_out_channels.iter().copied().rev().collect(); + for index in 0..n_block_out_channels { + let out_channels = reversed_block_out_channels[index]; + let in_channels = if index > 0 { + reversed_block_out_channels[index - 1] + } else { + reversed_block_out_channels[0] + }; + let is_final = index + 1 == n_block_out_channels; + let cfg = UpDecoderBlock2DConfig { + num_layers: config.layers_per_block + 1, + resnet_eps: 1e-6, + resnet_groups: config.norm_num_groups, + add_upsample: !is_final, + ..Default::default() + }; + let up_block = UpDecoderBlock2D::new( + vs_up_blocks.pp(&index.to_string()), + in_channels, + out_channels, + cfg, + )?; + up_blocks.push(up_block) + } + let conv_norm_out = nn::group_norm( + config.norm_num_groups, + config.block_out_channels[0], + 1e-6, + vs.pp("conv_norm_out"), + )?; + let conv_cfg = nn::Conv2dConfig { + padding: 1, + ..Default::default() + }; + let conv_out = nn::conv2d( + config.block_out_channels[0], + out_channels, + 3, + conv_cfg, + vs.pp("conv_out"), + )?; + Ok(Self { + conv_in, + up_blocks, + mid_block, + conv_norm_out, + conv_out, + config, + }) + } +} + +impl Decoder { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let mut xs = self.mid_block.forward(&self.conv_in.forward(xs)?, None)?; + for up_block in self.up_blocks.iter() { + xs = up_block.forward(&xs)? 
+ } + let xs = self.conv_norm_out.forward(&xs)?; + let xs = nn::ops::silu(&xs)?; + self.conv_out.forward(&xs) + } +} + +#[derive(Debug, Clone)] +pub struct AutoEncoderKLConfig { + pub block_out_channels: Vec<usize>, + pub layers_per_block: usize, + pub latent_channels: usize, + pub norm_num_groups: usize, +} + +impl Default for AutoEncoderKLConfig { + fn default() -> Self { + Self { + block_out_channels: vec![64], + layers_per_block: 1, + latent_channels: 4, + norm_num_groups: 32, + } + } +} + +pub struct DiagonalGaussianDistribution { + mean: Tensor, + std: Tensor, +} + +impl DiagonalGaussianDistribution { + pub fn new(parameters: &Tensor) -> Result<Self> { + let mut parameters = parameters.chunk(2, 1)?.into_iter(); + let mean = parameters.next().unwrap(); + let logvar = parameters.next().unwrap(); + let std = (logvar * 0.5)?.exp()?; + Ok(DiagonalGaussianDistribution { mean, std }) + } + + pub fn sample(&self) -> Result<Tensor> { + let sample = self.mean.randn_like(0., 1.); + &self.mean + &self.std * sample + } +} + +// https://github.com/huggingface/diffusers/blob/970e30606c2944e3286f56e8eb6d3dc6d1eb85f7/src/diffusers/models/vae.py#L485 +// This implementation is specific to the config used in stable-diffusion-v1-5 +// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json +#[derive(Debug)] +pub struct AutoEncoderKL { + encoder: Encoder, + decoder: Decoder, + quant_conv: nn::Conv2d, + post_quant_conv: nn::Conv2d, + pub config: AutoEncoderKLConfig, +} + +impl AutoEncoderKL { + pub fn new( + vs: nn::VarBuilder, + in_channels: usize, + out_channels: usize, + config: AutoEncoderKLConfig, + ) -> Result<Self> { + let latent_channels = config.latent_channels; + let encoder_cfg = EncoderConfig { + block_out_channels: config.block_out_channels.clone(), + layers_per_block: config.layers_per_block, + norm_num_groups: config.norm_num_groups, + double_z: true, + }; + let encoder = Encoder::new(vs.pp("encoder"), in_channels, latent_channels, encoder_cfg)?; + let decoder_cfg = DecoderConfig { + block_out_channels: config.block_out_channels.clone(), + layers_per_block: config.layers_per_block, + norm_num_groups: config.norm_num_groups, + }; + let decoder = Decoder::new(vs.pp("decoder"), latent_channels, out_channels, decoder_cfg)?; + let conv_cfg = Default::default(); + let quant_conv = nn::conv2d( + 2 * latent_channels, + 2 * latent_channels, + 1, + conv_cfg, + vs.pp("quant_conv"), + )?; + let post_quant_conv = nn::conv2d( + latent_channels, + latent_channels, + 1, + conv_cfg, + vs.pp("post_quant_conv"), + )?; + Ok(Self { + encoder, + decoder, + quant_conv, + post_quant_conv, + config, + }) + } + + /// Returns the distribution in the latent space. + pub fn encode(&self, xs: &Tensor) -> Result<DiagonalGaussianDistribution> { + let xs = self.encoder.forward(xs)?; + let parameters = self.quant_conv.forward(&xs)?; + DiagonalGaussianDistribution::new(&parameters) + } + + /// Takes as input some sampled values.
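+ /// These are typically latents drawn from the distribution returned by `encode`; they are passed through the post-quantization convolution and the decoder to produce an image-space tensor. + /// + /// A rough usage sketch (the `vae` and `images` bindings are illustrative, not defined in this file): + /// ```ignore + /// let dist = vae.encode(&images)?; + /// let latents = dist.sample()?; + /// let images_out = vae.decode(&latents)?; + /// ```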
+ pub fn decode(&self, xs: &Tensor) -> Result<Tensor> { + let xs = self.post_quant_conv.forward(xs)?; + self.decoder.forward(&xs) + } +} diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs new file mode 100644 index 00000000..539ae89b --- /dev/null +++ b/candle-transformers/src/models/t5.rs @@ -0,0 +1,841 @@ +// T5 Text Model +// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py + +use candle::{DType, Device, Module, Result, Tensor, D}; +use candle_nn::{Activation, VarBuilder}; +use serde::Deserialize; +use std::sync::Arc; + +#[derive(Debug)] +struct Embedding { + inner: candle_nn::Embedding, + span: tracing::Span, +} + +impl Embedding { + fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> { + let inner = candle_nn::embedding(d1, d2, vb)?; + let span = tracing::span!(tracing::Level::TRACE, "embedding"); + Ok(Self { inner, span }) + } + + fn embeddings(&self) -> &Tensor { + self.inner.embeddings() + } +} + +impl Module for Embedding { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(xs) + } +} + +#[derive(Debug)] +struct Linear { + inner: candle_nn::Linear, + span: tracing::Span, +} + +impl Linear { + fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> { + let inner = candle_nn::linear_no_bias(d1, d2, vb)?; + let span = tracing::span!(tracing::Level::TRACE, "linear"); + Ok(Self { inner, span }) + } +} + +impl Module for Linear { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(xs) + } +} + +fn default_relative_attention_max_distance() -> usize { + 128 +} + +fn default_is_decoder() -> bool { + false +} + +fn default_use_cache() -> bool { + true +} + +fn default_tie_word_embeddings() -> bool { + true +} + +fn get_mask(size: usize, device: &Device) -> Result<Tensor> { + let mask: Vec<_> = (0..size) + .flat_map(|i| (0..size).map(move |j| u8::from(j > i))) + .collect(); + Tensor::from_slice(&mask, (size, size), device) +} + +fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> { + let shape = mask.shape(); + let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; + let m = mask.where_cond(&on_true, on_false)?; + Ok(m) +} + +#[derive(Debug, Clone, PartialEq, Deserialize)] +pub struct Config { + vocab_size: usize, + d_model: usize, + d_kv: usize, + d_ff: usize, + num_layers: usize, + num_decoder_layers: Option<usize>, + num_heads: usize, + relative_attention_num_buckets: usize, + #[serde(default = "default_relative_attention_max_distance")] + relative_attention_max_distance: usize, + dropout_rate: f64, + layer_norm_epsilon: f64, + initializer_factor: f64, + #[serde(default)] + feed_forward_proj: Activation, + #[serde(default = "default_tie_word_embeddings")] + tie_word_embeddings: bool, + #[serde(default = "default_is_decoder")] + is_decoder: bool, + is_encoder_decoder: bool, + #[serde(default = "default_use_cache")] + pub use_cache: bool, + pub pad_token_id: usize, + pub eos_token_id: usize, +} + +impl Default for Config { + fn default() -> Self { + Self { + vocab_size: 32128, + d_model: 512, + d_kv: 64, + d_ff: 2048, + num_layers: 6, + num_decoder_layers: None, + num_heads: 8, + relative_attention_num_buckets: 32, + relative_attention_max_distance: 128, + dropout_rate: 0.1, + layer_norm_epsilon: 1e-6, + initializer_factor: 1.0, + feed_forward_proj: Activation::Relu, + tie_word_embeddings: true, + is_decoder: false, + 
is_encoder_decoder: true, + use_cache: true, + pad_token_id: 0, + eos_token_id: 1, + } + } +} + +impl Config { + // https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L184 + pub fn musicgen_small() -> Self { + Self { + d_ff: 3072, + d_kv: 64, + d_model: 768, + dropout_rate: 0.1, + eos_token_id: 1, + feed_forward_proj: Activation::Relu, + tie_word_embeddings: true, + initializer_factor: 1.0, + is_decoder: false, + is_encoder_decoder: true, + layer_norm_epsilon: 1e-6, + num_decoder_layers: Some(12), + num_heads: 12, + num_layers: 12, + pad_token_id: 0, + relative_attention_max_distance: 128, + relative_attention_num_buckets: 32, + use_cache: true, + vocab_size: 32128, + } + } +} + +#[derive(Debug)] +struct T5LayerNorm { + weight: Tensor, + variance_epsilon: f64, + span: tracing::Span, +} + +impl T5LayerNorm { + fn load(h: usize, eps: f64, vb: VarBuilder) -> Result<Self> { + let weight = vb.get(h, "weight")?; + Ok(Self { + weight, + variance_epsilon: eps, + span: tracing::span!(tracing::Level::TRACE, "layer-norm"), + }) + } +} + +impl Module for T5LayerNorm { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let dtype = xs.dtype(); + let xs_f32 = xs.to_dtype(DType::F32)?; + // variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?; + let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?; + let xs = xs.to_dtype(dtype)?; + let xs = xs.broadcast_mul(&self.weight)?; + Ok(xs) + } +} + +#[derive(Debug)] +struct T5DenseActDense { + wi: Linear, + wo: Linear, + act: Activation, + span: tracing::Span, +} + +impl T5DenseActDense { + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let wi = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi"))?; + let wo = Linear::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?; + Ok(Self { + wi, + wo, + act: Activation::Relu, + span: tracing::span!(tracing::Level::TRACE, "dense-act-dense"), + }) + } +} + +impl Module for T5DenseActDense { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let xs = self.wi.forward(xs)?; + let xs = self.act.forward(&xs)?; + let xs = self.wo.forward(&xs)?; + Ok(xs) + } +} + +#[derive(Debug)] +struct T5DenseGatedActDense { + wi_0: Linear, + wi_1: Linear, + wo: Linear, + act: Activation, + span: tracing::Span, +} + +impl T5DenseGatedActDense { + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let wi_0 = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?; + let wi_1 = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?; + let wo = Linear::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?; + Ok(Self { + wi_0, + wi_1, + wo, + act: Activation::NewGelu, + span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"), + }) + } +} + +impl Module for T5DenseGatedActDense { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let hidden_gelu = self.act.forward(&self.wi_0.forward(xs)?)?; + let hidden_linear = self.wi_1.forward(xs)?; + let xs = hidden_gelu.broadcast_mul(&hidden_linear)?; + let xs = self.wo.forward(&xs)?; + Ok(xs) + } +} + +#[derive(Debug)] +struct T5LayerFF { + dense_act: Option<T5DenseActDense>, + gated_dense_act: Option<T5DenseGatedActDense>, + layer_norm: T5LayerNorm, + span: tracing::Span, +} + +impl T5LayerFF { + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let layer_norm = + T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?; + let 
(dense_act, gated_dense_act) = if cfg.feed_forward_proj == Activation::NewGelu { + ( + None, + Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?), + ) + } else { + ( + Some(T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?), + None, + ) + }; + Ok(Self { + dense_act, + gated_dense_act, + layer_norm, + span: tracing::span!(tracing::Level::TRACE, "layer-ff"), + }) + } +} + +impl Module for T5LayerFF { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let ys = self.layer_norm.forward(xs)?; + let ys = match &self.dense_act { + Some(dense_act) => dense_act.forward(&ys)?, + None => self.gated_dense_act.as_ref().unwrap().forward(&ys)?, + }; + let xs = (xs + ys)?; + Ok(xs) + } +} + +#[derive(Debug)] +struct T5Attention { + q: Linear, + k: Linear, + v: Linear, + o: Linear, + n_heads: usize, + d_kv: usize, + relative_attention_bias: Option<Embedding>, + relative_attention_num_buckets: usize, + relative_attention_max_distance: usize, + inner_dim: usize, + use_cache: bool, + kv_cache: Option<(Tensor, Tensor)>, + span: tracing::Span, + span_cache: tracing::Span, + span_mm: tracing::Span, + span_sm: tracing::Span, +} + +impl T5Attention { + fn load( + has_relative_attention_bias: bool, + decoder: bool, + vb: VarBuilder, + cfg: &Config, + ) -> Result<Self> { + let inner_dim = cfg.num_heads * cfg.d_kv; + let q = Linear::new(cfg.d_model, inner_dim, vb.pp("q"))?; + let k = Linear::new(cfg.d_model, inner_dim, vb.pp("k"))?; + let v = Linear::new(cfg.d_model, inner_dim, vb.pp("v"))?; + let o = Linear::new(inner_dim, cfg.d_model, vb.pp("o"))?; + let relative_attention_bias = if has_relative_attention_bias { + let emb = Embedding::new( + cfg.relative_attention_num_buckets, + cfg.num_heads, + vb.pp("relative_attention_bias"), + )?; + Some(emb) + } else { + None + }; + Ok(Self { + q, + k, + v, + o, + n_heads: cfg.num_heads, + d_kv: cfg.d_kv, + relative_attention_bias, + relative_attention_num_buckets: cfg.relative_attention_num_buckets, + relative_attention_max_distance: cfg.relative_attention_max_distance, + inner_dim, + use_cache: cfg.use_cache && decoder, + kv_cache: None, + span: tracing::span!(tracing::Level::TRACE, "attention"), + span_cache: tracing::span!(tracing::Level::TRACE, "attention-cache"), + span_mm: tracing::span!(tracing::Level::TRACE, "attention-mm"), + span_sm: tracing::span!(tracing::Level::TRACE, "attention-sm"), + }) + } + + fn forward( + &mut self, + xs: &Tensor, + position_bias: Option<&Tensor>, + key_value_states: Option<&Tensor>, + mask: Option<&Tensor>, + ) -> Result<(Tensor, Option<Tensor>)> { + // Performs Self-attention (if key_value_states is None) or attention + // over source sentence (provided by key_value_states). + let _enter = self.span.enter(); + let kv_input = match key_value_states { + None => xs, + Some(key_value_states) => key_value_states, + }; + let (b_sz, q_len) = (xs.dim(0)?, xs.dim(1)?); + let kv_len = kv_input.dim(1)?; + let q = self.q.forward(xs)?; + let k = self.k.forward(kv_input)?; + let v = self.v.forward(kv_input)?; + let q = q + .reshape((b_sz, q_len, self.n_heads, self.d_kv))? + .transpose(1, 2)? + .contiguous()?; + let mut k = k + .reshape((b_sz, kv_len, self.n_heads, self.d_kv))? + .transpose(1, 2)? + .contiguous()?; + let mut v = v + .reshape((b_sz, kv_len, self.n_heads, self.d_kv))? + .transpose(1, 2)? 
+ .contiguous()?; + + if self.use_cache { + let _enter = self.span_cache.enter(); + if let Some((kv_cache_k, kv_cache_v)) = &self.kv_cache { + k = Tensor::cat(&[kv_cache_k, &k], 2)?.contiguous()?; + v = Tensor::cat(&[kv_cache_v, &v], 2)?.contiguous()?; + }; + self.kv_cache = Some((k.clone(), v.clone())); + }; + // TODO: Use flash_attn. + let scores = { + let _enter = self.span_mm.enter(); + q.matmul(&k.t()?)? + }; + let scores = match mask { + None => scores, + Some(mask) => masked_fill( + &scores, + &mask + .unsqueeze(0)? + .unsqueeze(0)? + .repeat((b_sz, self.n_heads))?, + f32::NEG_INFINITY, + )?, + }; + + let (scores, position_bias) = match position_bias { + Some(position_bias) => ( + scores.broadcast_add(position_bias)?, + Some(position_bias.clone()), + ), + None => match &self.relative_attention_bias { + None => (scores, None), + Some(relative_attention_bias) => { + // This only handles the bidirectional case. + let kv_len = k.dim(2)?; + let (q_start, q_end) = match self.use_cache { + true => ((kv_len - q_len) as u32, kv_len as u32), + false => (0_u32, kv_len as u32), + }; + let num_buckets = self.relative_attention_num_buckets as u32 / 2; + let max_exact = num_buckets / 2; + let relative_position = (q_start..q_end) + .map(|i| { + (0..kv_len as u32) + .map(|j| { + if i < j { + if j - i < max_exact { + j - i + num_buckets + } else { + let b = f32::log( + (j - i) as f32 / max_exact as f32, + self.relative_attention_max_distance as f32 + / max_exact as f32, + ) * (num_buckets - max_exact) as f32; + u32::min( + max_exact + num_buckets + b as u32, + self.relative_attention_num_buckets as u32 - 1, + ) + } + } else if i - j < max_exact { + i - j + } else { + let b = f32::log( + (i - j) as f32 / max_exact as f32, + self.relative_attention_max_distance as f32 + / max_exact as f32, + ) * (num_buckets - max_exact) as f32; + max_exact + b as u32 + } + }) + .collect::<Vec<u32>>() + }) + .collect::<Vec<Vec<_>>>(); + let relative_buckets = Tensor::new(relative_position, q.device())?; + let position_bias = relative_attention_bias + .forward(&relative_buckets)? + .permute((2, 0, 1))? + .unsqueeze(0)?; + (scores.broadcast_add(&position_bias)?, Some(position_bias)) + // TODO: position_bias_masked? + } + }, + }; + + let attn_weights = { + let _enter = self.span_sm.enter(); + candle_nn::ops::softmax(&scores, D::Minus1)? + }; + let attn_output = attn_weights.matmul(&v)?; + let attn_output = attn_output + .transpose(1, 2)? 
+ .reshape((b_sz, q_len, self.inner_dim))?; + let attn_output = self.o.forward(&attn_output)?; + Ok((attn_output, position_bias)) + } + + fn clear_kv_cache(&mut self) { + self.kv_cache = None + } +} + +#[derive(Debug)] +struct T5LayerSelfAttention { + self_attention: T5Attention, + layer_norm: T5LayerNorm, + span: tracing::Span, +} + +impl T5LayerSelfAttention { + fn load(h: bool, d: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> { + let self_attention = T5Attention::load(h, d, vb.pp("SelfAttention"), cfg)?; + let layer_norm = + T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?; + Ok(Self { + self_attention, + layer_norm, + span: tracing::span!(tracing::Level::TRACE, "self-attn"), + }) + } + + fn forward( + &mut self, + xs: &Tensor, + position_bias: Option<&Tensor>, + mask: Option<&Tensor>, + ) -> Result<(Tensor, Option<Tensor>)> { + let _enter = self.span.enter(); + let normed_xs = self.layer_norm.forward(xs)?; + let (ys, position_bias) = + self.self_attention + .forward(&normed_xs, position_bias, None, mask)?; + let ys = (xs + ys)?; + Ok((ys, position_bias)) + } + + fn clear_kv_cache(&mut self) { + self.self_attention.clear_kv_cache() + } +} + +#[derive(Debug)] +struct T5LayerCrossAttention { + cross_attention: T5Attention, + layer_norm: T5LayerNorm, + span: tracing::Span, +} + +impl T5LayerCrossAttention { + fn load(decoder: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> { + let cross_attention = T5Attention::load(false, decoder, vb.pp("EncDecAttention"), cfg)?; + let layer_norm = + T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?; + Ok(Self { + cross_attention, + layer_norm, + span: tracing::span!(tracing::Level::TRACE, "cross-attn"), + }) + } + + fn forward( + &mut self, + hidden_states: &Tensor, + position_bias: Option<&Tensor>, + key_value_states: &Tensor, + ) -> Result<(Tensor, Option<Tensor>)> { + let _enter = self.span.enter(); + let normed_hidden_states = self.layer_norm.forward(hidden_states)?; + let (ys, position_bias) = self.cross_attention.forward( + &normed_hidden_states, + position_bias, + Some(key_value_states), + None, + )?; + let ys = (hidden_states + ys)?; + Ok((ys, position_bias)) + } + + fn clear_kv_cache(&mut self) { + self.cross_attention.clear_kv_cache() + } +} + +#[derive(Debug)] +struct T5Block { + self_attn: T5LayerSelfAttention, + cross_attn: Option<T5LayerCrossAttention>, + ff: T5LayerFF, + span: tracing::Span, +} + +impl T5Block { + fn load( + has_relative_attention_bias: bool, + decoder: bool, + vb: VarBuilder, + cfg: &Config, + ) -> Result<Self> { + let vb = vb.pp("layer"); + let self_attn = + T5LayerSelfAttention::load(has_relative_attention_bias, decoder, vb.pp("0"), cfg)?; + let cross_attn = if cfg.is_decoder { + Some(T5LayerCrossAttention::load(decoder, vb.pp("1"), cfg)?) + } else { + None + }; + let ff_i = if cross_attn.is_some() { 2 } else { 1 }; + let ff = T5LayerFF::load(vb.pp(&ff_i.to_string()), cfg)?; + Ok(Self { + self_attn, + cross_attn, + ff, + span: tracing::span!(tracing::Level::TRACE, "block"), + }) + } + + fn forward( + &mut self, + xs: &Tensor, + position_bias: Option<&Tensor>, + encoder_hidden_states: Option<&Tensor>, + ) -> Result<(Tensor, Option<Tensor>)> { + let _enter = self.span.enter(); + // TODO: Cache masks + let mask = match self.cross_attn.is_some() { + true => { + let mask_len = xs.dim(1)?; + // If the input seq length is 1, no need for a mask, this is also helpful to avoid shape + // issues when using the KV cache in the decoder. 
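+ // (With the KV cache enabled, the decoder is fed a single token at a time: the query length is 1 while the cached keys already cover the whole prefix, so a 1x1 causal mask would not line up with the key length anyway.)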
+ if mask_len <= 1 { + None + } else { + Some(get_mask(mask_len, xs.device())?) + } + } + false => None, + }; + let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias, mask.as_ref())?; + // TODO: clamp for f16? + if let Some(cross_attn) = &mut self.cross_attn { + (xs, _) = cross_attn.forward(&xs, None, encoder_hidden_states.unwrap())?; + // TODO: clamp for f16? + } + let xs = self.ff.forward(&xs)?; + // TODO: clamp for f16? + Ok((xs, position_bias)) + } + + fn clear_kv_cache(&mut self) { + self.self_attn.clear_kv_cache(); + self.cross_attn.iter_mut().for_each(|c| c.clear_kv_cache()); + } +} + +#[derive(Debug)] +struct T5Stack { + block: Vec<T5Block>, + shared: Arc<Embedding>, + final_layer_norm: T5LayerNorm, + span: tracing::Span, +} + +impl T5Stack { + fn load(decoder: bool, vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> { + let block = (0..cfg.num_layers) + .map(|i| T5Block::load(i == 0, decoder, vb.pp(&format!("block.{i}")), cfg)) + .collect::<Result<Vec<_>>>()?; + let final_layer_norm = T5LayerNorm::load( + cfg.d_model, + cfg.layer_norm_epsilon, + vb.pp("final_layer_norm"), + )?; + Ok(Self { + block, + shared: shared.clone(), + final_layer_norm, + span: tracing::span!(tracing::Level::TRACE, "stack"), + }) + } + + fn forward( + &mut self, + input_ids: &Tensor, + encoder_hidden_states: Option<&Tensor>, + ) -> Result<Tensor> { + let _enter = self.span.enter(); + let input_embeds = self.shared.as_ref().forward(input_ids)?; + let mut hidden_states = input_embeds; + let mut position_bias = None; + for block in self.block.iter_mut() { + (hidden_states, position_bias) = block.forward( + &hidden_states, + position_bias.as_ref(), + encoder_hidden_states, + )? + } + self.final_layer_norm.forward(&hidden_states) + } + + fn clear_kv_cache(&mut self) { + self.block.iter_mut().for_each(|b| b.clear_kv_cache()) + } +} + +#[derive(Debug)] +pub struct T5EncoderModel { + encoder: T5Stack, + device: Device, + span: tracing::Span, +} + +impl T5EncoderModel { + pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?; + let shared = Arc::new(shared); + let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?; + Ok(Self { + encoder, + device: vb.device().clone(), + span: tracing::span!(tracing::Level::TRACE, "encoder"), + }) + } + + pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.encoder.forward(input_ids, None) + } + + pub fn device(&self) -> &Device { + &self.device + } + + pub fn clear_kv_cache(&mut self) { + self.encoder.clear_kv_cache() + } +} + +#[derive(Debug)] +pub struct T5ForConditionalGeneration { + encoder: T5Stack, + decoder: T5Stack, + d_model: usize, + tie_word_embeddings: bool, + lm_head: Option<Linear>, + shared: Arc<Embedding>, + device: Device, + span_decode: tracing::Span, + span_decode_head: tracing::Span, +} + +impl T5ForConditionalGeneration { + pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + assert!(cfg.is_encoder_decoder); + let d_model = cfg.d_model; + let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?; + let shared = Arc::new(shared); + + let mut encoder_cfg = cfg.clone(); + encoder_cfg.is_decoder = false; + encoder_cfg.use_cache = false; + encoder_cfg.is_encoder_decoder = false; + let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, &encoder_cfg)?; + + let mut decoder_cfg = cfg.clone(); + decoder_cfg.is_decoder = true; + 
decoder_cfg.is_encoder_decoder = false; + decoder_cfg.num_layers = cfg.num_decoder_layers.unwrap_or(cfg.num_layers); + let decoder = T5Stack::load(true, vb.pp("decoder"), &shared, &decoder_cfg)?; + + let tie_word_embeddings = cfg.tie_word_embeddings; + let lm_head = if tie_word_embeddings { + None + } else { + Some(Linear::new(cfg.d_model, cfg.vocab_size, vb.pp("lm_head"))?) + }; + + Ok(Self { + encoder, + decoder, + d_model, + tie_word_embeddings, + lm_head, + shared, + device: vb.device().clone(), + span_decode: tracing::span!(tracing::Level::TRACE, "decode"), + span_decode_head: tracing::span!(tracing::Level::TRACE, "decode-head"), + }) + } + + pub fn encode(&mut self, input_ids: &Tensor) -> Result<Tensor> { + self.encoder.forward(input_ids, None) + } + + pub fn decode( + &mut self, + decoder_input_ids: &Tensor, + encoder_output: &Tensor, + ) -> Result<Tensor> { + let _enter = self.span_decode.enter(); + let decoder_output = self + .decoder + .forward(decoder_input_ids, Some(encoder_output))?; + + let scaling_factor = if self.tie_word_embeddings { + // Rescale output before projecting on vocab + // See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + (self.d_model as f64).sqrt() + } else { + 1.0 + }; + let sequence_output = ((decoder_output + .narrow(1, decoder_output.dim(1)? - 1, 1)? + .squeeze(1)?) + * scaling_factor)?; + let output = { + let _enter = self.span_decode_head.enter(); + match self.lm_head { + None => sequence_output.matmul(&self.shared.embeddings().t()?)?, + Some(ref lm_head) => lm_head.forward(&sequence_output)?, + } + }; + + // TODO: Rescale output before projecting on vocab? * (self.model_dim**-0.5) + Ok(output) + } + + pub fn forward(&mut self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> { + let encoder_output = self.encode(input_ids)?; + self.decode(decoder_input_ids, &encoder_output) + } + + pub fn device(&self) -> &Device { + &self.device + } + + pub fn clear_kv_cache(&mut self) { + self.encoder.clear_kv_cache(); + self.decoder.clear_kv_cache(); + } +} diff --git a/candle-transformers/src/models/whisper/audio.rs b/candle-transformers/src/models/whisper/audio.rs new file mode 100644 index 00000000..4e01de32 --- /dev/null +++ b/candle-transformers/src/models/whisper/audio.rs @@ -0,0 +1,210 @@ +// Audio processing code, adapted from whisper.cpp +// https://github.com/ggerganov/whisper.cpp + +pub trait Float: num_traits::Float + num_traits::FloatConst + num_traits::NumAssign {} + +impl Float for f32 {} +impl Float for f64 {} + +// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2357 +fn fft<T: Float>(inp: &[T]) -> Vec<T> { + let n = inp.len(); + let zero = T::zero(); + if n == 1 { + return vec![inp[0], zero]; + } + if n % 2 == 1 { + return dft(inp); + } + let mut out = vec![zero; n * 2]; + + let mut even = Vec::with_capacity(n / 2); + let mut odd = Vec::with_capacity(n / 2); + + for (i, &inp) in inp.iter().enumerate() { + if i % 2 == 0 { + even.push(inp) + } else { + odd.push(inp); + } + } + + let even_fft = fft(&even); + let odd_fft = fft(&odd); + + let two_pi = T::PI() + T::PI(); + let n_t = T::from(n).unwrap(); + for k in 0..n / 2 { + let k_t = T::from(k).unwrap(); + let theta = two_pi * k_t / n_t; + let re = theta.cos(); + let im = -theta.sin(); + + let re_odd = odd_fft[2 * k]; + let im_odd = odd_fft[2 * k + 1]; + + out[2 * k] = even_fft[2 * k] + re * re_odd - im * im_odd; + out[2 * k + 1] = even_fft[2 * k + 1] 
+ re * im_odd + im * re_odd; + + out[2 * (k + n / 2)] = even_fft[2 * k] - re * re_odd + im * im_odd; + out[2 * (k + n / 2) + 1] = even_fft[2 * k + 1] - re * im_odd - im * re_odd; + } + out +} + +// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2337 +fn dft<T: Float>(inp: &[T]) -> Vec<T> { + let zero = T::zero(); + let n = inp.len(); + let two_pi = T::PI() + T::PI(); + + let mut out = Vec::new(); + out.reserve(2 * n); + let n_t = T::from(n).unwrap(); + for k in 0..n { + let k_t = T::from(k).unwrap(); + let mut re = zero; + let mut im = zero; + + for (j, &inp) in inp.iter().enumerate() { + let j_t = T::from(j).unwrap(); + let angle = two_pi * k_t * j_t / n_t; + re += inp * angle.cos(); + im -= inp * angle.sin(); + } + + out.push(re); + out.push(im); + } + out +} + +#[allow(clippy::too_many_arguments)] +// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2414 +fn log_mel_spectrogram_w<T: Float>( + ith: usize, + hann: &[T], + samples: &[T], + filters: &[T], + fft_size: usize, + fft_step: usize, + speed_up: bool, + n_len: usize, + n_mel: usize, + n_threads: usize, +) -> Vec<T> { + let n_fft = if speed_up { + 1 + fft_size / 4 + } else { + 1 + fft_size / 2 + }; + + let zero = T::zero(); + let half = T::from(0.5).unwrap(); + let mut fft_in = vec![zero; fft_size]; + let mut mel = vec![zero; n_len * n_mel]; + + for i in (ith..n_len).step_by(n_threads) { + let offset = i * fft_step; + + // apply Hanning window + for j in 0..fft_size { + fft_in[j] = if offset + j < samples.len() { + hann[j] * samples[offset + j] + } else { + zero + } + } + + // FFT -> mag^2 + let mut fft_out: Vec<T> = fft(&fft_in); + + for j in 0..fft_size { + fft_out[j] = fft_out[2 * j] * fft_out[2 * j] + fft_out[2 * j + 1] * fft_out[2 * j + 1]; + } + for j in 1..fft_size / 2 { + let v = fft_out[fft_size - j]; + fft_out[j] += v; + } + + if speed_up { + // scale down in the frequency domain results in a speed up in the time domain + for j in 0..n_fft { + fft_out[j] = half * (fft_out[2 * j] + fft_out[2 * j + 1]); + } + } + + // mel spectrogram + for j in 0..n_mel { + let mut sum = zero; + for k in 0..n_fft { + sum += fft_out[k] * filters[j * n_fft + k]; + } + mel[j * n_len + i] = T::max(sum, T::from(1e-10).unwrap()).log10(); + } + } + mel +} + +fn log_mel_spectrogram_<T: Float + std::fmt::Display>( + samples: &[T], + filters: &[T], + fft_size: usize, + fft_step: usize, + n_mel: usize, + speed_up: bool, +) -> Vec<T> { + let zero = T::zero(); + let two_pi = T::PI() + T::PI(); + let half = T::from(0.5).unwrap(); + let one = T::from(1.0).unwrap(); + let four = T::from(4.0).unwrap(); + let fft_size_t = T::from(fft_size).unwrap(); + + let hann: Vec<T> = (0..fft_size) + .map(|i| half * (one - ((two_pi * T::from(i).unwrap()) / fft_size_t).cos())) + .collect(); + let n_len = samples.len() / fft_step; + + // pad audio with at least one extra chunk of zeros + let pad = 100 * super::CHUNK_LENGTH / 2; + let n_len = if n_len % pad != 0 { + (n_len / pad + 1) * pad + } else { + n_len + }; + let n_len = n_len + pad; + let samples = { + let mut samples_padded = samples.to_vec(); + let to_add = n_len * fft_step - samples.len(); + samples_padded.extend(std::iter::repeat(zero).take(to_add)); + samples_padded + }; + + // Use a single thread for now. 
+ let mut mel = log_mel_spectrogram_w( + 0, &hann, &samples, filters, fft_size, fft_step, speed_up, n_len, n_mel, 1, + ); + let mmax = mel + .iter() + .max_by(|&u, &v| u.partial_cmp(v).unwrap_or(std::cmp::Ordering::Greater)) + .copied() + .unwrap_or(zero) + - T::from(8).unwrap(); + for m in mel.iter_mut() { + let v = T::max(*m, mmax); + *m = v / four + one + } + mel +} + +pub fn pcm_to_mel<T: Float + std::fmt::Display>(samples: &[T], filters: &[T]) -> Vec<T> { + log_mel_spectrogram_( + samples, + filters, + super::N_FFT, + super::HOP_LENGTH, + super::N_MELS, + false, + ) +} diff --git a/candle-transformers/src/models/whisper/mod.rs b/candle-transformers/src/models/whisper/mod.rs new file mode 100644 index 00000000..7dc8107b --- /dev/null +++ b/candle-transformers/src/models/whisper/mod.rs @@ -0,0 +1,26 @@ +pub mod audio; +pub mod model; + +pub const DTYPE: candle::DType = candle::DType::F32; + +// Audio parameters. +pub const SAMPLE_RATE: usize = 16000; +pub const N_FFT: usize = 400; +pub const N_MELS: usize = 80; +pub const HOP_LENGTH: usize = 160; +pub const CHUNK_LENGTH: usize = 30; +pub const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk +pub const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input + +pub const NO_SPEECH_THRESHOLD: f64 = 0.6; +pub const LOGPROB_THRESHOLD: f64 = -1.0; +pub const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]; +pub const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4; + +// Tokenizer dependent bits. +pub const SOT_TOKEN: &str = "<|startoftranscript|>"; +pub const TRANSCRIBE_TOKEN: &str = "<|transcribe|>"; +pub const TRANSLATE_TOKEN: &str = "<|translate|>"; +pub const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>"; +pub const EOT_TOKEN: &str = "<|endoftext|>"; +pub const NO_SPEECH_TOKEN: &str = "<|nocaptions|>"; diff --git a/candle-transformers/src/models/whisper/model.rs b/candle-transformers/src/models/whisper/model.rs new file mode 100644 index 00000000..d2eda796 --- /dev/null +++ b/candle-transformers/src/models/whisper/model.rs @@ -0,0 +1,416 @@ +use candle::{Device, IndexOp, Result, Tensor, D}; +use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder}; +use serde::Deserialize; + +// The names in comments correspond to the original implementation: +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17 +#[derive(Debug, Clone, PartialEq, Deserialize)] +pub struct Config { + pub num_mel_bins: usize, // n_mels + pub max_source_positions: usize, // n_audio_ctx + pub d_model: usize, // n_audio_state + pub encoder_attention_heads: usize, // n_audio_head + pub encoder_layers: usize, // n_audio_layer + pub vocab_size: usize, // n_vocab + pub max_target_positions: usize, // n_text_ctx + // pub n_text_state: usize, + pub decoder_attention_heads: usize, // n_text_head + pub decoder_layers: usize, // n_text_layer + pub suppress_tokens: Vec<u32>, +} + +fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> { + let embeddings = vb.get((vocab_size, hidden_size), "weight")?; + Ok(Embedding::new(embeddings, hidden_size)) +} +// +// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting +// model. 
+#[derive(Debug)] +pub struct Linear { + inner: candle_nn::Linear, + span: tracing::Span, +} + +impl Linear { + fn forward(&self, x: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(x) + } +} + +fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> { + let span = tracing::span!(tracing::Level::TRACE, "linear"); + let inner = candle_nn::linear(size1, size2, vb)?; + Ok(Linear { inner, span }) +} + +fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> { + let span = tracing::span!(tracing::Level::TRACE, "linear"); + let inner = candle_nn::linear_no_bias(size1, size2, vb)?; + Ok(Linear { inner, span }) +} + +fn conv1d( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + config: Conv1dConfig, + vb: VarBuilder, +) -> Result<Conv1d> { + let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?; + let bias = vb.get(out_channels, "bias")?; + Ok(Conv1d::new(weight, Some(bias), config)) +} + +fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> { + let weight = vb.get(size, "weight")?; + let bias = vb.get(size, "bias")?; + Ok(LayerNorm::new(weight, bias, 1e-5)) +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62 +struct MultiHeadAttention { + query: Linear, + key: Linear, + value: Linear, + out: Linear, + n_head: usize, + span: tracing::Span, + softmax_span: tracing::Span, + matmul_span: tracing::Span, + kv_cache: Option<(Tensor, Tensor)>, +} + +impl MultiHeadAttention { + fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "multi-head-attn"); + let softmax_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-softmax"); + let matmul_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-matmul"); + let query = linear(n_state, n_state, vb.pp("q_proj"))?; + let value = linear(n_state, n_state, vb.pp("v_proj"))?; + let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?; + let out = linear(n_state, n_state, vb.pp("out_proj"))?; + Ok(Self { + query, + key, + value, + out, + n_head, + span, + softmax_span, + matmul_span, + kv_cache: None, + }) + } + + fn forward( + &mut self, + x: &Tensor, + xa: Option<&Tensor>, + mask: Option<&Tensor>, + flush_cache: bool, + ) -> Result<Tensor> { + let _enter = self.span.enter(); + let q = self.query.forward(x)?; + let (k, v) = match xa { + None => { + let k = self.key.forward(x)?; + let v = self.value.forward(x)?; + (k, v) + } + Some(x) => { + if flush_cache { + self.kv_cache = None; + } + if let Some((k, v)) = &self.kv_cache { + (k.clone(), v.clone()) + } else { + let k = self.key.forward(x)?; + let v = self.value.forward(x)?; + self.kv_cache = Some((k.clone(), v.clone())); + (k, v) + } + } + }; + let wv = self.qkv_attention(&q, &k, &v, mask)?; + let out = self.out.forward(&wv)?; + Ok(out) + } + + fn reshape_head(&self, x: &Tensor) -> Result<Tensor> { + let (n_batch, n_ctx, n_state) = x.dims3()?; + let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head]; + x.reshape(target_dims)?.transpose(1, 2) + } + + fn qkv_attention( + &self, + q: &Tensor, + k: &Tensor, + v: &Tensor, + mask: Option<&Tensor>, + ) -> Result<Tensor> { + let (_, n_ctx, n_state) = q.dims3()?; + let scale = ((n_state / self.n_head) as f64).powf(-0.25); + let q = (self.reshape_head(q)? * scale)?; + let k = (self.reshape_head(k)?.transpose(2, 3)? 
* scale)?; + let v = self.reshape_head(v)?.contiguous()?; + let mut qk = { + let _enter = self.matmul_span.enter(); + q.matmul(&k)? + }; + if let Some(mask) = mask { + let mask = mask.i((0..n_ctx, 0..n_ctx))?; + qk = qk.broadcast_add(&mask)? + } + let w = { + let _enter = self.softmax_span.enter(); + candle_nn::ops::softmax_last_dim(&qk)? + }; + let wv = { + let _enter = self.matmul_span.enter(); + w.matmul(&v)? + } + .transpose(1, 2)? + .flatten_from(2)?; + Ok(wv) + } +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111 +struct ResidualAttentionBlock { + attn: MultiHeadAttention, + attn_ln: LayerNorm, + cross_attn: Option<(MultiHeadAttention, LayerNorm)>, + mlp_linear1: Linear, + mlp_linear2: Linear, + mlp_ln: LayerNorm, + span: tracing::Span, +} + +impl ResidualAttentionBlock { + fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "residual-attn"); + let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?; + let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?; + let cross_attn = if ca { + let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?; + let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?; + Some((cross_attn, cross_attn_ln)) + } else { + None + }; + let n_mlp = n_state * 4; + let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?; + let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?; + let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?; + Ok(Self { + attn, + attn_ln, + cross_attn, + mlp_linear1, + mlp_linear2, + mlp_ln, + span, + }) + } + + fn forward( + &mut self, + x: &Tensor, + xa: Option<&Tensor>, + mask: Option<&Tensor>, + flush_kv_cache: bool, + ) -> Result<Tensor> { + let _enter = self.span.enter(); + let attn = self + .attn + .forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?; + let mut x = (x + attn)?; + if let Some((attn, ln)) = &mut self.cross_attn { + x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?; + } + let mlp = self.mlp_linear2.forward( + &self + .mlp_linear1 + .forward(&self.mlp_ln.forward(&x)?)? + .gelu()?, + )?; + x + mlp + } +} + +fn sinusoids(length: usize, channels: usize) -> Result<Tensor> { + let max_timescale = 10000f32; + let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32; + let inv_timescales: Vec<_> = (0..channels / 2) + .map(|i| (i as f32 * (-log_timescale_increment)).exp()) + .collect(); + let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?; + let arange = Tensor::arange(0, length as u32, &Device::Cpu)? + .to_dtype(candle::DType::F32)? + .unsqueeze(1)?; + let sh = (length, channels / 2); + let scaled_time = (arange.broadcast_as(sh)? 
* inv_timescales.broadcast_as(sh)?)?; + let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?; + Ok(sincos) +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143 +pub struct AudioEncoder { + conv1: Conv1d, + conv2: Conv1d, + positional_embedding: Tensor, + blocks: Vec<ResidualAttentionBlock>, + ln_post: LayerNorm, + span: tracing::Span, + conv1_span: tracing::Span, + conv2_span: tracing::Span, +} + +impl AudioEncoder { + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "audio-encoder"); + let conv1_span = tracing::span!(tracing::Level::TRACE, "conv1"); + let conv2_span = tracing::span!(tracing::Level::TRACE, "conv2"); + let n_state = cfg.d_model; + let n_head = cfg.encoder_attention_heads; + let n_ctx = cfg.max_source_positions; + let cfg1 = Conv1dConfig { + padding: 1, + stride: 1, + groups: 1, + dilation: 1, + }; + let cfg2 = Conv1dConfig { + padding: 1, + stride: 2, + groups: 1, + dilation: 1, + }; + let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?; + let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?; + let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?; + let blocks = (0..cfg.encoder_layers) + .map(|i| { + ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}"))) + }) + .collect::<Result<Vec<_>>>()?; + let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?; + Ok(Self { + conv1, + conv2, + positional_embedding, + blocks, + ln_post, + conv1_span, + conv2_span, + span, + }) + } + + pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result<Tensor> { + let _enter = self.span.enter(); + let x = { + let _enter = self.conv1_span.enter(); + self.conv1.forward(x)?.gelu()? + }; + let x = { + let _enter = self.conv2_span.enter(); + self.conv2.forward(&x)?.gelu()? + }; + let x = x.transpose(1, 2)?; + let (_bsize, seq_len, _hidden) = x.dims3()?; + let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?; + let mut x = x.broadcast_add(&positional_embedding)?; + for block in self.blocks.iter_mut() { + x = block.forward(&x, None, None, flush_kv_cache)? 
+ } + let x = self.ln_post.forward(&x)?; + Ok(x) + } +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176 +pub struct TextDecoder { + token_embedding: Embedding, + positional_embedding: Tensor, + blocks: Vec<ResidualAttentionBlock>, + ln: LayerNorm, + mask: Tensor, + span: tracing::Span, + span_final: tracing::Span, +} + +impl TextDecoder { + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "text-decoder"); + let span_final = tracing::span!(tracing::Level::TRACE, "text-decoder-final"); + let n_state = cfg.d_model; + let n_head = cfg.decoder_attention_heads; + let n_ctx = cfg.max_target_positions; + let token_embedding = embedding(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?; + let positional_embedding = vb.get((n_ctx, n_state), "embed_positions.weight")?; + let blocks = (0..cfg.decoder_layers) + .map(|i| { + ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(&format!("layers.{i}"))) + }) + .collect::<Result<Vec<_>>>()?; + let ln = layer_norm(n_state, vb.pp("layer_norm"))?; + let mask: Vec<_> = (0..n_ctx) + .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 })) + .collect(); + let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?; + Ok(Self { + token_embedding, + positional_embedding, + blocks, + ln, + mask, + span, + span_final, + }) + } + + pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> { + let _enter = self.span.enter(); + let last = x.dim(D::Minus1)?; + let token_embedding = self.token_embedding.forward(x)?; + let positional_embedding = self.positional_embedding.narrow(0, 0, last)?; + let mut x = token_embedding.broadcast_add(&positional_embedding)?; + for block in self.blocks.iter_mut() { + x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?; + } + self.ln.forward(&x) + } + + pub fn final_linear(&self, x: &Tensor) -> Result<Tensor> { + let b_size = x.dim(0)?; + let w = self.token_embedding.embeddings().broadcast_left(b_size)?; + let logits = { + let _enter = self.span_final.enter(); + x.matmul(&w.t()?)? 
+ }; + Ok(logits) + } +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221 +pub struct Whisper { + pub encoder: AudioEncoder, + pub decoder: TextDecoder, + pub config: Config, +} + +impl Whisper { + pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> { + let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?; + let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?; + Ok(Self { + encoder, + decoder, + config, + }) + } +} diff --git a/candle-transformers/src/models/wuerstchen/attention_processor.rs b/candle-transformers/src/models/wuerstchen/attention_processor.rs new file mode 100644 index 00000000..0b90cb9d --- /dev/null +++ b/candle-transformers/src/models/wuerstchen/attention_processor.rs @@ -0,0 +1,118 @@ +use candle::{Module, Result, Tensor}; +use candle_nn::{linear, Linear, VarBuilder}; + +// A simplified version of: +// https://github.com/huggingface/diffusers/blob/119ad2c3dc8a8fb8446a83f4bf6f20929487b47f/src/diffusers/models/attention_processor.py#L38 +#[derive(Debug)] +pub struct Attention { + to_q: Linear, + to_k: Linear, + to_v: Linear, + to_out: Linear, + heads: usize, + scale: f64, + use_flash_attn: bool, +} + +#[cfg(feature = "flash-attn")] +fn flash_attn( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + causal: bool, +) -> Result<Tensor> { + candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal) +} + +#[cfg(not(feature = "flash-attn"))] +fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> { + unimplemented!("compile with '--features flash-attn'") +} + +impl Attention { + pub fn new( + query_dim: usize, + heads: usize, + dim_head: usize, + use_flash_attn: bool, + vb: VarBuilder, + ) -> Result<Self> { + let inner_dim = dim_head * heads; + let scale = 1.0 / f64::sqrt(dim_head as f64); + let to_q = linear(query_dim, inner_dim, vb.pp("to_q"))?; + let to_k = linear(query_dim, inner_dim, vb.pp("to_k"))?; + let to_v = linear(query_dim, inner_dim, vb.pp("to_v"))?; + let to_out = linear(inner_dim, query_dim, vb.pp("to_out.0"))?; + Ok(Self { + to_q, + to_k, + to_v, + to_out, + scale, + heads, + use_flash_attn, + }) + } + + fn batch_to_head_dim(&self, xs: &Tensor) -> Result<Tensor> { + let (b_size, seq_len, dim) = xs.dims3()?; + xs.reshape((b_size / self.heads, self.heads, seq_len, dim))? + .permute((0, 2, 1, 3))? + .reshape((b_size / self.heads, seq_len, dim * self.heads)) + } + + fn head_to_batch_dim(&self, xs: &Tensor) -> Result<Tensor> { + let (b_size, seq_len, dim) = xs.dims3()?; + xs.reshape((b_size, seq_len, self.heads, dim / self.heads))? + .permute((0, 2, 1, 3))? + .reshape((b_size * self.heads, seq_len, dim / self.heads)) + } + + fn get_attention_scores(&self, query: &Tensor, key: &Tensor) -> Result<Tensor> { + let attn_probs = (query.matmul(&key.t()?)? * self.scale)?; + candle_nn::ops::softmax_last_dim(&attn_probs) + } + + pub fn forward(&self, xs: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> { + let (b_size, channel, h, w) = xs.dims4()?; + let xs = xs.reshape((b_size, channel, h * w))?.t()?; + + let query = self.to_q.forward(&xs)?; + let key = self.to_k.forward(encoder_hidden_states)?; + let value = self.to_v.forward(encoder_hidden_states)?; + + let query = self.head_to_batch_dim(&query)?; + let key = self.head_to_batch_dim(&key)?; + let value = self.head_to_batch_dim(&value)?; + + let xs = if self.use_flash_attn { + let init_dtype = query.dtype(); + let q = query + .to_dtype(candle::DType::F16)? + .unsqueeze(0)? 
+ .transpose(1, 2)?; + let k = key + .to_dtype(candle::DType::F16)? + .unsqueeze(0)? + .transpose(1, 2)?; + let v = value + .to_dtype(candle::DType::F16)? + .unsqueeze(0)? + .transpose(1, 2)?; + flash_attn(&q, &k, &v, self.scale as f32, false)? + .transpose(1, 2)? + .squeeze(0)? + .to_dtype(init_dtype)? + } else { + let attn_prs = self.get_attention_scores(&query, &key)?; + attn_prs.matmul(&value)? + }; + let xs = self.batch_to_head_dim(&xs)?; + + self.to_out + .forward(&xs)? + .t()? + .reshape((b_size, channel, h, w)) + } +} diff --git a/candle-transformers/src/models/wuerstchen/common.rs b/candle-transformers/src/models/wuerstchen/common.rs new file mode 100644 index 00000000..c89ec919 --- /dev/null +++ b/candle-transformers/src/models/wuerstchen/common.rs @@ -0,0 +1,203 @@ +use candle::{DType, Module, Result, Tensor, D}; +use candle_nn::VarBuilder; + +// https://github.com/huggingface/diffusers/blob/19edca82f1ff194c07317369a92b470dbae97f34/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py#L22 +#[derive(Debug)] +pub struct WLayerNorm { + eps: f64, +} + +impl WLayerNorm { + pub fn new(_size: usize) -> Result<Self> { + Ok(Self { eps: 1e-6 }) + } +} + +impl Module for WLayerNorm { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let xs = xs.permute((0, 2, 3, 1))?; + + let x_dtype = xs.dtype(); + let internal_dtype = match x_dtype { + DType::F16 | DType::BF16 => DType::F32, + d => d, + }; + + let hidden_size = xs.dim(D::Minus1)?; + let xs = xs.to_dtype(internal_dtype)?; + let mean_x = (xs.sum_keepdim(D::Minus1)? / hidden_size as f64)?; + let xs = xs.broadcast_sub(&mean_x)?; + let norm_x = (xs.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?; + xs.broadcast_div(&(norm_x + self.eps)?.sqrt()?)? + .to_dtype(x_dtype)? + .permute((0, 3, 1, 2)) + } +} + +#[derive(Debug)] +pub struct LayerNormNoWeights { + eps: f64, +} + +impl LayerNormNoWeights { + pub fn new(_size: usize) -> Result<Self> { + Ok(Self { eps: 1e-6 }) + } +} + +impl Module for LayerNormNoWeights { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let x_dtype = xs.dtype(); + let internal_dtype = match x_dtype { + DType::F16 | DType::BF16 => DType::F32, + d => d, + }; + let hidden_size = xs.dim(D::Minus1)?; + let xs = xs.to_dtype(internal_dtype)?; + let mean_x = (xs.sum_keepdim(D::Minus1)? / hidden_size as f64)?; + let xs = xs.broadcast_sub(&mean_x)?; + let norm_x = (xs.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?; + xs.broadcast_div(&(norm_x + self.eps)?.sqrt()?)? + .to_dtype(x_dtype) + } +} + +#[derive(Debug)] +pub struct TimestepBlock { + mapper: candle_nn::Linear, +} + +impl TimestepBlock { + pub fn new(c: usize, c_timestep: usize, vb: VarBuilder) -> Result<Self> { + let mapper = candle_nn::linear(c_timestep, c * 2, vb.pp("mapper"))?; + Ok(Self { mapper }) + } + + pub fn forward(&self, xs: &Tensor, t: &Tensor) -> Result<Tensor> { + let ab = self + .mapper + .forward(t)? + .unsqueeze(2)? + .unsqueeze(3)? 
+ .chunk(2, 1)?; + xs.broadcast_mul(&(&ab[0] + 1.)?)?.broadcast_add(&ab[1]) + } +} + +#[derive(Debug)] +pub struct GlobalResponseNorm { + gamma: Tensor, + beta: Tensor, +} + +impl GlobalResponseNorm { + pub fn new(dim: usize, vb: VarBuilder) -> Result<Self> { + let gamma = vb.get((1, 1, 1, dim), "gamma")?; + let beta = vb.get((1, 1, 1, dim), "beta")?; + Ok(Self { gamma, beta }) + } +} + +impl Module for GlobalResponseNorm { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let agg_norm = xs.sqr()?.sum_keepdim((1, 2))?.sqrt()?; + let stand_div_norm = + agg_norm.broadcast_div(&(agg_norm.mean_keepdim(D::Minus1)? + 1e-6)?)?; + xs.broadcast_mul(&stand_div_norm)? + .broadcast_mul(&self.gamma)? + .broadcast_add(&self.beta)? + + xs + } +} + +#[derive(Debug)] +pub struct ResBlock { + depthwise: candle_nn::Conv2d, + norm: WLayerNorm, + channelwise_lin1: candle_nn::Linear, + channelwise_grn: GlobalResponseNorm, + channelwise_lin2: candle_nn::Linear, +} + +impl ResBlock { + pub fn new(c: usize, c_skip: usize, ksize: usize, vb: VarBuilder) -> Result<Self> { + let cfg = candle_nn::Conv2dConfig { + padding: ksize / 2, + groups: c, + ..Default::default() + }; + let depthwise = candle_nn::conv2d(c + c_skip, c, ksize, cfg, vb.pp("depthwise"))?; + let norm = WLayerNorm::new(c)?; + let channelwise_lin1 = candle_nn::linear(c, c * 4, vb.pp("channelwise.0"))?; + let channelwise_grn = GlobalResponseNorm::new(c * 4, vb.pp("channelwise.2"))?; + let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?; + Ok(Self { + depthwise, + norm, + channelwise_lin1, + channelwise_grn, + channelwise_lin2, + }) + } + + pub fn forward(&self, xs: &Tensor, x_skip: Option<&Tensor>) -> Result<Tensor> { + let x_res = xs; + let xs = match x_skip { + None => xs.clone(), + Some(x_skip) => Tensor::cat(&[xs, x_skip], 1)?, + }; + let xs = xs + .apply(&self.depthwise)? + .apply(&self.norm)? + .permute((0, 2, 3, 1))?; + let xs = xs + .apply(&self.channelwise_lin1)? + .gelu_erf()? + .apply(&self.channelwise_grn)? + .apply(&self.channelwise_lin2)? + .permute((0, 3, 1, 2))?; + xs + x_res + } +} +use super::attention_processor::Attention; +#[derive(Debug)] +pub struct AttnBlock { + self_attn: bool, + norm: WLayerNorm, + attention: Attention, + kv_mapper_lin: candle_nn::Linear, +} + +impl AttnBlock { + pub fn new( + c: usize, + c_cond: usize, + nhead: usize, + self_attn: bool, + use_flash_attn: bool, + vb: VarBuilder, + ) -> Result<Self> { + let norm = WLayerNorm::new(c)?; + let attention = Attention::new(c, nhead, c / nhead, use_flash_attn, vb.pp("attention"))?; + let kv_mapper_lin = candle_nn::linear(c_cond, c, vb.pp("kv_mapper.1"))?; + Ok(Self { + self_attn, + norm, + attention, + kv_mapper_lin, + }) + } + + pub fn forward(&self, xs: &Tensor, kv: &Tensor) -> Result<Tensor> { + let kv = candle_nn::ops::silu(kv)?.apply(&self.kv_mapper_lin)?; + let norm_xs = self.norm.forward(xs)?; + let kv = if self.self_attn { + let (b_size, channel, _, _) = xs.dims4()?; + let norm_xs = norm_xs.reshape((b_size, channel, ()))?.transpose(1, 2)?; + Tensor::cat(&[&norm_xs, &kv], 1)?.contiguous()? 
+ } else { + kv + }; + xs + self.attention.forward(&norm_xs, &kv) + } +} diff --git a/candle-transformers/src/models/wuerstchen/ddpm.rs b/candle-transformers/src/models/wuerstchen/ddpm.rs new file mode 100644 index 00000000..9e69b868 --- /dev/null +++ b/candle-transformers/src/models/wuerstchen/ddpm.rs @@ -0,0 +1,103 @@ +use candle::{Result, Tensor}; + +#[derive(Debug, Clone)] +pub struct DDPMWSchedulerConfig { + scaler: f64, + s: f64, +} + +impl Default for DDPMWSchedulerConfig { + fn default() -> Self { + Self { + scaler: 1f64, + s: 0.008f64, + } + } +} + +pub struct DDPMWScheduler { + init_alpha_cumprod: f64, + init_noise_sigma: f64, + timesteps: Vec<f64>, + pub config: DDPMWSchedulerConfig, +} + +impl DDPMWScheduler { + pub fn new(inference_steps: usize, config: DDPMWSchedulerConfig) -> Result<Self> { + let init_alpha_cumprod = (config.s / (1. + config.s) * std::f64::consts::PI) + .cos() + .powi(2); + let timesteps = (0..=inference_steps) + .map(|i| 1. - i as f64 / inference_steps as f64) + .collect::<Vec<_>>(); + Ok(Self { + init_alpha_cumprod, + init_noise_sigma: 1.0, + timesteps, + config, + }) + } + + pub fn timesteps(&self) -> &[f64] { + &self.timesteps + } + + fn alpha_cumprod(&self, t: f64) -> f64 { + let scaler = self.config.scaler; + let s = self.config.s; + let t = if scaler > 1. { + 1. - (1. - t).powf(scaler) + } else if scaler < 1. { + t.powf(scaler) + } else { + t + }; + let alpha_cumprod = ((t + s) / (1. + s) * std::f64::consts::PI * 0.5) + .cos() + .powi(2) + / self.init_alpha_cumprod; + alpha_cumprod.clamp(0.0001, 0.9999) + } + + fn previous_timestep(&self, ts: f64) -> f64 { + let index = self + .timesteps + .iter() + .enumerate() + .map(|(idx, v)| (idx, (v - ts).abs())) + .min_by(|x, y| x.1.total_cmp(&y.1)) + .unwrap() + .0; + self.timesteps[index + 1] + } + + /// Ensures interchangeability with schedulers that need to scale the denoising model input + /// depending on the current timestep. + pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor { + sample + } + + pub fn step(&self, model_output: &Tensor, ts: f64, sample: &Tensor) -> Result<Tensor> { + let prev_t = self.previous_timestep(ts); + + let alpha_cumprod = self.alpha_cumprod(ts); + let alpha_cumprod_prev = self.alpha_cumprod(prev_t); + let alpha = alpha_cumprod / alpha_cumprod_prev; + + let mu = (sample - model_output * ((1. - alpha) / (1. - alpha_cumprod).sqrt()))?; + let mu = (mu * (1. / alpha).sqrt())?; + + let std_noise = mu.randn_like(0., 1.)?; + let std = + std_noise * ((1. - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt(); + if prev_t == 0. 
{ + Ok(mu) + } else { + mu + std + } + } + + pub fn init_noise_sigma(&self) -> f64 { + self.init_noise_sigma + } +} diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs new file mode 100644 index 00000000..64a48c8a --- /dev/null +++ b/candle-transformers/src/models/wuerstchen/diffnext.rs @@ -0,0 +1,396 @@ +use super::common::{AttnBlock, GlobalResponseNorm, LayerNormNoWeights, TimestepBlock, WLayerNorm}; +use candle::{DType, Module, Result, Tensor, D}; +use candle_nn::VarBuilder; + +#[derive(Debug)] +pub struct ResBlockStageB { + depthwise: candle_nn::Conv2d, + norm: WLayerNorm, + channelwise_lin1: candle_nn::Linear, + channelwise_grn: GlobalResponseNorm, + channelwise_lin2: candle_nn::Linear, +} + +impl ResBlockStageB { + pub fn new(c: usize, c_skip: usize, ksize: usize, vb: VarBuilder) -> Result<Self> { + let cfg = candle_nn::Conv2dConfig { + groups: c, + padding: ksize / 2, + ..Default::default() + }; + let depthwise = candle_nn::conv2d(c, c, ksize, cfg, vb.pp("depthwise"))?; + let norm = WLayerNorm::new(c)?; + let channelwise_lin1 = candle_nn::linear(c + c_skip, c * 4, vb.pp("channelwise.0"))?; + let channelwise_grn = GlobalResponseNorm::new(4 * c, vb.pp("channelwise.2"))?; + let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?; + Ok(Self { + depthwise, + norm, + channelwise_lin1, + channelwise_grn, + channelwise_lin2, + }) + } + + pub fn forward(&self, xs: &Tensor, x_skip: Option<&Tensor>) -> Result<Tensor> { + let x_res = xs; + let xs = xs.apply(&self.depthwise)?.apply(&self.norm)?; + let xs = match x_skip { + None => xs.clone(), + Some(x_skip) => Tensor::cat(&[&xs, x_skip], 1)?, + }; + let xs = xs + .permute((0, 2, 3, 1))? + .contiguous()? + .apply(&self.channelwise_lin1)? + .gelu()? + .apply(&self.channelwise_grn)? + .apply(&self.channelwise_lin2)? 
+ .permute((0, 3, 1, 2))?; + xs + x_res + } +} + +#[derive(Debug)] +struct SubBlock { + res_block: ResBlockStageB, + ts_block: TimestepBlock, + attn_block: Option<AttnBlock>, +} + +#[derive(Debug)] +struct DownBlock { + layer_norm: Option<WLayerNorm>, + conv: Option<candle_nn::Conv2d>, + sub_blocks: Vec<SubBlock>, +} + +#[derive(Debug)] +struct UpBlock { + sub_blocks: Vec<SubBlock>, + layer_norm: Option<WLayerNorm>, + conv: Option<candle_nn::ConvTranspose2d>, +} + +#[derive(Debug)] +pub struct WDiffNeXt { + clip_mapper: candle_nn::Linear, + effnet_mappers: Vec<Option<candle_nn::Conv2d>>, + seq_norm: LayerNormNoWeights, + embedding_conv: candle_nn::Conv2d, + embedding_ln: WLayerNorm, + down_blocks: Vec<DownBlock>, + up_blocks: Vec<UpBlock>, + clf_ln: WLayerNorm, + clf_conv: candle_nn::Conv2d, + c_r: usize, + patch_size: usize, +} + +impl WDiffNeXt { + #[allow(clippy::too_many_arguments)] + pub fn new( + c_in: usize, + c_out: usize, + c_r: usize, + c_cond: usize, + clip_embd: usize, + patch_size: usize, + use_flash_attn: bool, + vb: VarBuilder, + ) -> Result<Self> { + const C_HIDDEN: [usize; 4] = [320, 640, 1280, 1280]; + const BLOCKS: [usize; 4] = [4, 4, 14, 4]; + const NHEAD: [usize; 4] = [1, 10, 20, 20]; + const INJECT_EFFNET: [bool; 4] = [false, true, true, true]; + const EFFNET_EMBD: usize = 16; + + let clip_mapper = candle_nn::linear(clip_embd, c_cond, vb.pp("clip_mapper"))?; + let mut effnet_mappers = Vec::with_capacity(2 * INJECT_EFFNET.len()); + let vb_e = vb.pp("effnet_mappers"); + for (i, &inject) in INJECT_EFFNET.iter().enumerate() { + let c = if inject { + Some(candle_nn::conv2d( + EFFNET_EMBD, + c_cond, + 1, + Default::default(), + vb_e.pp(i), + )?) + } else { + None + }; + effnet_mappers.push(c) + } + for (i, &inject) in INJECT_EFFNET.iter().rev().enumerate() { + let c = if inject { + Some(candle_nn::conv2d( + EFFNET_EMBD, + c_cond, + 1, + Default::default(), + vb_e.pp(i + INJECT_EFFNET.len()), + )?) 
+ } else { + None + }; + effnet_mappers.push(c) + } + let seq_norm = LayerNormNoWeights::new(c_cond)?; + let embedding_ln = WLayerNorm::new(C_HIDDEN[0])?; + let embedding_conv = candle_nn::conv2d( + c_in * patch_size * patch_size, + C_HIDDEN[0], + 1, + Default::default(), + vb.pp("embedding.1"), + )?; + + let mut down_blocks = Vec::with_capacity(C_HIDDEN.len()); + for (i, &c_hidden) in C_HIDDEN.iter().enumerate() { + let vb = vb.pp("down_blocks").pp(i); + let (layer_norm, conv, start_layer_i) = if i > 0 { + let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?; + let cfg = candle_nn::Conv2dConfig { + stride: 2, + ..Default::default() + }; + let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp("0.1"))?; + (Some(layer_norm), Some(conv), 1) + } else { + (None, None, 0) + }; + let mut sub_blocks = Vec::with_capacity(BLOCKS[i]); + let mut layer_i = start_layer_i; + for _j in 0..BLOCKS[i] { + let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 }; + let res_block = ResBlockStageB::new(c_hidden, c_skip, 3, vb.pp(layer_i))?; + layer_i += 1; + let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?; + layer_i += 1; + let attn_block = if i == 0 { + None + } else { + let attn_block = AttnBlock::new( + c_hidden, + c_cond, + NHEAD[i], + true, + use_flash_attn, + vb.pp(layer_i), + )?; + layer_i += 1; + Some(attn_block) + }; + let sub_block = SubBlock { + res_block, + ts_block, + attn_block, + }; + sub_blocks.push(sub_block) + } + let down_block = DownBlock { + layer_norm, + conv, + sub_blocks, + }; + down_blocks.push(down_block) + } + + let mut up_blocks = Vec::with_capacity(C_HIDDEN.len()); + for (i, &c_hidden) in C_HIDDEN.iter().enumerate().rev() { + let vb = vb.pp("up_blocks").pp(C_HIDDEN.len() - 1 - i); + let mut sub_blocks = Vec::with_capacity(BLOCKS[i]); + let mut layer_i = 0; + for j in 0..BLOCKS[i] { + let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 }; + let c_skip_res = if i < BLOCKS.len() - 1 && j == 0 { + c_hidden + c_skip + } else { + c_skip + }; + let res_block = ResBlockStageB::new(c_hidden, c_skip_res, 3, vb.pp(layer_i))?; + layer_i += 1; + let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?; + layer_i += 1; + let attn_block = if i == 0 { + None + } else { + let attn_block = AttnBlock::new( + c_hidden, + c_cond, + NHEAD[i], + true, + use_flash_attn, + vb.pp(layer_i), + )?; + layer_i += 1; + Some(attn_block) + }; + let sub_block = SubBlock { + res_block, + ts_block, + attn_block, + }; + sub_blocks.push(sub_block) + } + let (layer_norm, conv) = if i > 0 { + let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?; + let cfg = candle_nn::ConvTranspose2dConfig { + stride: 2, + ..Default::default() + }; + let conv = candle_nn::conv_transpose2d( + c_hidden, + C_HIDDEN[i - 1], + 2, + cfg, + vb.pp(layer_i).pp(1), + )?; + (Some(layer_norm), Some(conv)) + } else { + (None, None) + }; + let up_block = UpBlock { + layer_norm, + conv, + sub_blocks, + }; + up_blocks.push(up_block) + } + + let clf_ln = WLayerNorm::new(C_HIDDEN[0])?; + let clf_conv = candle_nn::conv2d( + C_HIDDEN[0], + 2 * c_out * patch_size * patch_size, + 1, + Default::default(), + vb.pp("clf.1"), + )?; + Ok(Self { + clip_mapper, + effnet_mappers, + seq_norm, + embedding_conv, + embedding_ln, + down_blocks, + up_blocks, + clf_ln, + clf_conv, + c_r, + patch_size, + }) + } + + fn gen_r_embedding(&self, r: &Tensor) -> Result<Tensor> { + const MAX_POSITIONS: usize = 10000; + let r = (r * MAX_POSITIONS as f64)?; + let half_dim = self.c_r / 2; + let emb = (MAX_POSITIONS as f64).ln() / (half_dim - 1) as 
f64; + let emb = (Tensor::arange(0u32, half_dim as u32, r.device())?.to_dtype(DType::F32)? + * -emb)? + .exp()?; + let emb = r.unsqueeze(1)?.broadcast_mul(&emb.unsqueeze(0)?)?; + let emb = Tensor::cat(&[emb.sin()?, emb.cos()?], 1)?; + let emb = if self.c_r % 2 == 1 { + emb.pad_with_zeros(D::Minus1, 0, 1)? + } else { + emb + }; + emb.to_dtype(r.dtype()) + } + + fn gen_c_embeddings(&self, clip: &Tensor) -> Result<Tensor> { + clip.apply(&self.clip_mapper)?.apply(&self.seq_norm) + } + + pub fn forward( + &self, + xs: &Tensor, + r: &Tensor, + effnet: &Tensor, + clip: Option<&Tensor>, + ) -> Result<Tensor> { + const EPS: f64 = 1e-3; + + let r_embed = self.gen_r_embedding(r)?; + let clip = match clip { + None => None, + Some(clip) => Some(self.gen_c_embeddings(clip)?), + }; + let x_in = xs; + + let mut xs = xs + .apply(&|xs: &_| candle_nn::ops::pixel_unshuffle(xs, self.patch_size))? + .apply(&self.embedding_conv)? + .apply(&self.embedding_ln)?; + + let mut level_outputs = Vec::new(); + for (i, down_block) in self.down_blocks.iter().enumerate() { + if let Some(ln) = &down_block.layer_norm { + xs = xs.apply(ln)? + } + if let Some(conv) = &down_block.conv { + xs = xs.apply(conv)? + } + let skip = match &self.effnet_mappers[i] { + None => None, + Some(m) => { + let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?; + Some(m.forward(&effnet)?) + } + }; + for block in down_block.sub_blocks.iter() { + xs = block.res_block.forward(&xs, skip.as_ref())?; + xs = block.ts_block.forward(&xs, &r_embed)?; + if let Some(attn_block) = &block.attn_block { + xs = attn_block.forward(&xs, clip.as_ref().unwrap())?; + } + } + level_outputs.push(xs.clone()) + } + level_outputs.reverse(); + let mut xs = level_outputs[0].clone(); + + for (i, up_block) in self.up_blocks.iter().enumerate() { + let effnet_c = match &self.effnet_mappers[self.down_blocks.len() + i] { + None => None, + Some(m) => { + let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?; + Some(m.forward(&effnet)?) + } + }; + for (j, block) in up_block.sub_blocks.iter().enumerate() { + let skip = if j == 0 && i > 0 { + Some(&level_outputs[i]) + } else { + None + }; + let skip = match (skip, effnet_c.as_ref()) { + (Some(skip), Some(effnet_c)) => Some(Tensor::cat(&[skip, effnet_c], 1)?), + (None, Some(skip)) | (Some(skip), None) => Some(skip.clone()), + (None, None) => None, + }; + xs = block.res_block.forward(&xs, skip.as_ref())?; + xs = block.ts_block.forward(&xs, &r_embed)?; + if let Some(attn_block) = &block.attn_block { + xs = attn_block.forward(&xs, clip.as_ref().unwrap())?; + } + } + if let Some(ln) = &up_block.layer_norm { + xs = xs.apply(ln)? + } + if let Some(conv) = &up_block.conv { + xs = xs.apply(conv)? + } + } + + let ab = xs + .apply(&self.clf_ln)? + .apply(&self.clf_conv)? + .apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, self.patch_size))? + .chunk(2, 1)?; + let b = ((candle_nn::ops::sigmoid(&ab[1])? * (1. - EPS * 2.))? + EPS)?; + (x_in - &ab[0])? 
/ b + } +} diff --git a/candle-transformers/src/models/wuerstchen/mod.rs b/candle-transformers/src/models/wuerstchen/mod.rs new file mode 100644 index 00000000..7b076f06 --- /dev/null +++ b/candle-transformers/src/models/wuerstchen/mod.rs @@ -0,0 +1,6 @@ +pub mod attention_processor; +pub mod common; +pub mod ddpm; +pub mod diffnext; +pub mod paella_vq; +pub mod prior; diff --git a/candle-transformers/src/models/wuerstchen/paella_vq.rs b/candle-transformers/src/models/wuerstchen/paella_vq.rs new file mode 100644 index 00000000..4a69cca0 --- /dev/null +++ b/candle-transformers/src/models/wuerstchen/paella_vq.rs @@ -0,0 +1,211 @@ +use super::common::LayerNormNoWeights; +use candle::{Module, Result, Tensor}; +use candle_nn::VarBuilder; + +#[derive(Debug)] +pub struct MixingResidualBlock { + norm1: LayerNormNoWeights, + depthwise_conv: candle_nn::Conv2d, + norm2: LayerNormNoWeights, + channelwise_lin1: candle_nn::Linear, + channelwise_lin2: candle_nn::Linear, + gammas: Vec<f32>, +} + +impl MixingResidualBlock { + pub fn new(inp: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> { + let norm1 = LayerNormNoWeights::new(inp)?; + let norm2 = LayerNormNoWeights::new(inp)?; + let cfg = candle_nn::Conv2dConfig { + groups: inp, + ..Default::default() + }; + let depthwise_conv = candle_nn::conv2d(inp, inp, 3, cfg, vb.pp("depthwise.1"))?; + let channelwise_lin1 = candle_nn::linear(inp, embed_dim, vb.pp("channelwise.0"))?; + let channelwise_lin2 = candle_nn::linear(embed_dim, inp, vb.pp("channelwise.2"))?; + let gammas = vb.get(6, "gammas")?.to_vec1::<f32>()?; + Ok(Self { + norm1, + depthwise_conv, + norm2, + channelwise_lin1, + channelwise_lin2, + gammas, + }) + } +} + +impl Module for MixingResidualBlock { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let mods = &self.gammas; + let x_temp = xs + .permute((0, 2, 3, 1))? + .apply(&self.norm1)? + .permute((0, 3, 1, 2))? + .affine(1. + mods[0] as f64, mods[1] as f64)?; + let x_temp = candle_nn::ops::replication_pad2d(&x_temp, 1)?; + let xs = (xs + x_temp.apply(&self.depthwise_conv)? * mods[2] as f64)?; + let x_temp = xs + .permute((0, 2, 3, 1))? + .apply(&self.norm2)? + .permute((0, 3, 1, 2))? + .affine(1. + mods[3] as f64, mods[4] as f64)?; + let x_temp = x_temp + .permute((0, 2, 3, 1))? + .contiguous()? + .apply(&self.channelwise_lin1)? + .gelu()? + .apply(&self.channelwise_lin2)? 
+ .permute((0, 3, 1, 2))?; + xs + x_temp * mods[5] as f64 + } +} + +#[derive(Debug)] +pub struct PaellaVQ { + in_block_conv: candle_nn::Conv2d, + out_block_conv: candle_nn::Conv2d, + down_blocks: Vec<(Option<candle_nn::Conv2d>, MixingResidualBlock)>, + down_blocks_conv: candle_nn::Conv2d, + down_blocks_bn: candle_nn::BatchNorm, + up_blocks_conv: candle_nn::Conv2d, + up_blocks: Vec<(Vec<MixingResidualBlock>, Option<candle_nn::ConvTranspose2d>)>, +} + +impl PaellaVQ { + pub fn new(vb: VarBuilder) -> Result<Self> { + const IN_CHANNELS: usize = 3; + const OUT_CHANNELS: usize = 3; + const LATENT_CHANNELS: usize = 4; + const EMBED_DIM: usize = 384; + const BOTTLENECK_BLOCKS: usize = 12; + const C_LEVELS: [usize; 2] = [EMBED_DIM / 2, EMBED_DIM]; + + let in_block_conv = candle_nn::conv2d( + IN_CHANNELS * 4, + C_LEVELS[0], + 1, + Default::default(), + vb.pp("in_block.1"), + )?; + let out_block_conv = candle_nn::conv2d( + C_LEVELS[0], + OUT_CHANNELS * 4, + 1, + Default::default(), + vb.pp("out_block.0"), + )?; + + let mut down_blocks = Vec::new(); + let vb_d = vb.pp("down_blocks"); + let mut d_idx = 0; + for (i, &c_level) in C_LEVELS.iter().enumerate() { + let conv_block = if i > 0 { + let cfg = candle_nn::Conv2dConfig { + padding: 1, + stride: 2, + ..Default::default() + }; + let block = candle_nn::conv2d(C_LEVELS[i - 1], c_level, 4, cfg, vb_d.pp(d_idx))?; + d_idx += 1; + Some(block) + } else { + None + }; + let res_block = MixingResidualBlock::new(c_level, c_level * 4, vb_d.pp(d_idx))?; + d_idx += 1; + down_blocks.push((conv_block, res_block)) + } + let vb_d = vb_d.pp(d_idx); + let down_blocks_conv = candle_nn::conv2d_no_bias( + C_LEVELS[1], + LATENT_CHANNELS, + 1, + Default::default(), + vb_d.pp(0), + )?; + let down_blocks_bn = candle_nn::batch_norm(LATENT_CHANNELS, 1e-5, vb_d.pp(1))?; + + let mut up_blocks = Vec::new(); + let vb_u = vb.pp("up_blocks"); + let mut u_idx = 0; + let up_blocks_conv = candle_nn::conv2d( + LATENT_CHANNELS, + C_LEVELS[1], + 1, + Default::default(), + vb_u.pp(u_idx).pp(0), + )?; + u_idx += 1; + for (i, &c_level) in C_LEVELS.iter().rev().enumerate() { + let mut res_blocks = Vec::new(); + let n_bottleneck_blocks = if i == 0 { BOTTLENECK_BLOCKS } else { 1 }; + for _j in 0..n_bottleneck_blocks { + let res_block = MixingResidualBlock::new(c_level, c_level * 4, vb_u.pp(u_idx))?; + u_idx += 1; + res_blocks.push(res_block) + } + let conv_block = if i < C_LEVELS.len() - 1 { + let cfg = candle_nn::ConvTranspose2dConfig { + padding: 1, + stride: 2, + ..Default::default() + }; + let block = candle_nn::conv_transpose2d( + c_level, + C_LEVELS[C_LEVELS.len() - i - 2], + 4, + cfg, + vb_u.pp(u_idx), + )?; + u_idx += 1; + Some(block) + } else { + None + }; + up_blocks.push((res_blocks, conv_block)) + } + Ok(Self { + in_block_conv, + down_blocks, + down_blocks_conv, + down_blocks_bn, + up_blocks, + up_blocks_conv, + out_block_conv, + }) + } + + pub fn encode(&self, xs: &Tensor) -> Result<Tensor> { + let mut xs = candle_nn::ops::pixel_unshuffle(xs, 2)?.apply(&self.in_block_conv)?; + for down_block in self.down_blocks.iter() { + if let Some(conv) = &down_block.0 { + xs = xs.apply(conv)? + } + xs = xs.apply(&down_block.1)? + } + xs.apply(&self.down_blocks_conv)? + .apply(&self.down_blocks_bn) + } + + pub fn decode(&self, xs: &Tensor) -> Result<Tensor> { + // TODO: quantizer if we want to support `force_not_quantize=False`. 
+ let mut xs = xs.apply(&self.up_blocks_conv)?; + for up_block in self.up_blocks.iter() { + for b in up_block.0.iter() { + xs = xs.apply(b)?; + } + if let Some(conv) = &up_block.1 { + xs = xs.apply(conv)? + } + } + xs.apply(&self.out_block_conv)? + .apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, 2)) + } +} + +impl Module for PaellaVQ { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + self.decode(&self.encode(xs)?) + } +} diff --git a/candle-transformers/src/models/wuerstchen/prior.rs b/candle-transformers/src/models/wuerstchen/prior.rs new file mode 100644 index 00000000..97ccf0e2 --- /dev/null +++ b/candle-transformers/src/models/wuerstchen/prior.rs @@ -0,0 +1,103 @@ +use super::common::{AttnBlock, ResBlock, TimestepBlock}; +use candle::{DType, Result, Tensor, D}; +use candle_nn::VarBuilder; + +#[derive(Debug)] +struct Block { + res_block: ResBlock, + ts_block: TimestepBlock, + attn_block: AttnBlock, +} + +#[derive(Debug)] +pub struct WPrior { + projection: candle_nn::Conv2d, + cond_mapper_lin1: candle_nn::Linear, + cond_mapper_lin2: candle_nn::Linear, + blocks: Vec<Block>, + out_ln: super::common::WLayerNorm, + out_conv: candle_nn::Conv2d, + c_r: usize, +} + +impl WPrior { + #[allow(clippy::too_many_arguments)] + pub fn new( + c_in: usize, + c: usize, + c_cond: usize, + c_r: usize, + depth: usize, + nhead: usize, + use_flash_attn: bool, + vb: VarBuilder, + ) -> Result<Self> { + let projection = candle_nn::conv2d(c_in, c, 1, Default::default(), vb.pp("projection"))?; + let cond_mapper_lin1 = candle_nn::linear(c_cond, c, vb.pp("cond_mapper.0"))?; + let cond_mapper_lin2 = candle_nn::linear(c, c, vb.pp("cond_mapper.2"))?; + let out_ln = super::common::WLayerNorm::new(c)?; + let out_conv = candle_nn::conv2d(c, c_in * 2, 1, Default::default(), vb.pp("out.1"))?; + let mut blocks = Vec::with_capacity(depth); + for index in 0..depth { + let res_block = ResBlock::new(c, 0, 3, vb.pp(format!("blocks.{}", 3 * index)))?; + let ts_block = TimestepBlock::new(c, c_r, vb.pp(format!("blocks.{}", 3 * index + 1)))?; + let attn_block = AttnBlock::new( + c, + c, + nhead, + true, + use_flash_attn, + vb.pp(format!("blocks.{}", 3 * index + 2)), + )?; + blocks.push(Block { + res_block, + ts_block, + attn_block, + }) + } + Ok(Self { + projection, + cond_mapper_lin1, + cond_mapper_lin2, + blocks, + out_ln, + out_conv, + c_r, + }) + } + + pub fn gen_r_embedding(&self, r: &Tensor) -> Result<Tensor> { + const MAX_POSITIONS: usize = 10000; + let r = (r * MAX_POSITIONS as f64)?; + let half_dim = self.c_r / 2; + let emb = (MAX_POSITIONS as f64).ln() / (half_dim - 1) as f64; + let emb = (Tensor::arange(0u32, half_dim as u32, r.device())?.to_dtype(DType::F32)? + * -emb)? + .exp()?; + let emb = r.unsqueeze(1)?.broadcast_mul(&emb.unsqueeze(0)?)?; + let emb = Tensor::cat(&[emb.sin()?, emb.cos()?], 1)?; + let emb = if self.c_r % 2 == 1 { + emb.pad_with_zeros(D::Minus1, 0, 1)? + } else { + emb + }; + emb.to_dtype(r.dtype()) + } + + pub fn forward(&self, xs: &Tensor, r: &Tensor, c: &Tensor) -> Result<Tensor> { + let x_in = xs; + let mut xs = xs.apply(&self.projection)?; + let c_embed = c + .apply(&self.cond_mapper_lin1)? + .apply(&|xs: &_| candle_nn::ops::leaky_relu(xs, 0.2))? + .apply(&self.cond_mapper_lin2)?; + let r_embed = self.gen_r_embedding(r)?; + for block in self.blocks.iter() { + xs = block.res_block.forward(&xs, None)?; + xs = block.ts_block.forward(&xs, &r_embed)?; + xs = block.attn_block.forward(&xs, &c_embed)?; + } + let ab = xs.apply(&self.out_ln)?.apply(&self.out_conv)?.chunk(2, 1)?; + (x_in - &ab[0])? 
/ ((&ab[1] - 1.)?.abs()? + 1e-5) + } +} diff --git a/candle-transformers/src/object_detection.rs b/candle-transformers/src/object_detection.rs new file mode 100644 index 00000000..ce579316 --- /dev/null +++ b/candle-transformers/src/object_detection.rs @@ -0,0 +1,52 @@ +/// A bounding box around an object. +#[derive(Debug, Clone)] +pub struct Bbox<D> { + pub xmin: f32, + pub ymin: f32, + pub xmax: f32, + pub ymax: f32, + pub confidence: f32, + pub data: D, +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct KeyPoint { + pub x: f32, + pub y: f32, + pub mask: f32, +} + +/// Intersection over union of two bounding boxes. +pub fn iou<D>(b1: &Bbox<D>, b2: &Bbox<D>) -> f32 { + let b1_area = (b1.xmax - b1.xmin + 1.) * (b1.ymax - b1.ymin + 1.); + let b2_area = (b2.xmax - b2.xmin + 1.) * (b2.ymax - b2.ymin + 1.); + let i_xmin = b1.xmin.max(b2.xmin); + let i_xmax = b1.xmax.min(b2.xmax); + let i_ymin = b1.ymin.max(b2.ymin); + let i_ymax = b1.ymax.min(b2.ymax); + let i_area = (i_xmax - i_xmin + 1.).max(0.) * (i_ymax - i_ymin + 1.).max(0.); + i_area / (b1_area + b2_area - i_area) +} + +pub fn non_maximum_suppression<D>(bboxes: &mut [Vec<Bbox<D>>], threshold: f32) { + // Perform non-maximum suppression. + for bboxes_for_class in bboxes.iter_mut() { + bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap()); + let mut current_index = 0; + for index in 0..bboxes_for_class.len() { + let mut drop = false; + for prev_index in 0..current_index { + let iou = iou(&bboxes_for_class[prev_index], &bboxes_for_class[index]); + if iou > threshold { + drop = true; + break; + } + } + if !drop { + bboxes_for_class.swap(current_index, index); + current_index += 1; + } + } + bboxes_for_class.truncate(current_index); + } +} |
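A minimal usage sketch for the object_detection helpers added above, assuming the crate path candle_transformers::object_detection that this diff introduces; the boxes, single-class layout and the 0.45 threshold are made up for illustration, not taken from the diff.

use candle_transformers::object_detection::{non_maximum_suppression, Bbox};

fn main() {
    // One inner Vec per class; `data` can carry an extra payload (e.g. keypoints), here just ().
    let mut per_class: Vec<Vec<Bbox<()>>> = vec![vec![
        Bbox { xmin: 10., ymin: 10., xmax: 110., ymax: 110., confidence: 0.9, data: () },
        Bbox { xmin: 12., ymin: 8., xmax: 108., ymax: 112., confidence: 0.6, data: () },
    ]];
    // Sorts each class by descending confidence, drops any box whose IoU with an
    // already-kept box exceeds the threshold, and truncates the Vec in place.
    non_maximum_suppression(&mut per_class, 0.45);
    assert_eq!(per_class[0].len(), 1); // the lower-confidence near-duplicate is removed
}

Note that iou uses the inclusive "+ 1." pixel convention when computing areas, so boxes expressed in continuous coordinates will report slightly higher overlap than the strict geometric ratio.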
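A sketch of how the DDPMWScheduler and WPrior from this diff are meant to be driven together, assuming the module paths added here (models::wuerstchen::{ddpm, prior}); `prior`, `text_embeddings` and the starting `latents` are assumed to be set up by the caller, classifier-free guidance is omitted, and 60 steps is an arbitrary choice — illustrative only, not the exact example code.

use candle::{Result, Tensor};
use candle_transformers::models::wuerstchen::ddpm::{DDPMWScheduler, DDPMWSchedulerConfig};
use candle_transformers::models::wuerstchen::prior::WPrior;

fn denoise_prior(prior: &WPrior, text_embeddings: &Tensor, mut latents: Tensor) -> Result<Tensor> {
    let scheduler = DDPMWScheduler::new(60, DDPMWSchedulerConfig::default())?;
    let timesteps = scheduler.timesteps().to_vec();
    // The final timestep is 0.; calling `step` on the one before it already returns the
    // final `mu`, so the last entry is skipped (it has no "previous" timestep to look up).
    for &t in timesteps.iter().take(timesteps.len() - 1) {
        // Condition the prior on the current noise ratio, broadcast over the batch.
        let ratio = (Tensor::ones(latents.dim(0)?, candle::DType::F32, latents.device())? * t)?;
        let noise_pred = prior.forward(&latents, &ratio, text_embeddings)?;
        latents = scheduler.step(&noise_pred, t, &latents)?;
    }
    Ok(latents)
}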