Diffstat (limited to 'candle-transformers/src')
13 files changed, 1183 insertions, 60 deletions
diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs index d6826a16..51c524f5 100644 --- a/candle-transformers/src/models/bert.rs +++ b/candle-transformers/src/models/bert.rs @@ -7,8 +7,9 @@ pub const DTYPE: DType = DType::F32; #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)] #[serde(rename_all = "lowercase")] -enum HiddenAct { +pub enum HiddenAct { Gelu, + GeluApproximate, Relu, } @@ -28,6 +29,7 @@ impl HiddenActLayer { match self.act { // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213 HiddenAct::Gelu => xs.gelu_erf(), + HiddenAct::GeluApproximate => xs.gelu(), HiddenAct::Relu => xs.relu(), } } @@ -48,7 +50,7 @@ pub struct Config { num_hidden_layers: usize, num_attention_heads: usize, intermediate_size: usize, - hidden_act: HiddenAct, + pub hidden_act: HiddenAct, hidden_dropout_prob: f64, max_position_embeddings: usize, type_vocab_size: usize, diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs index e822ca14..b0e2fb88 100644 --- a/candle-transformers/src/models/mixformer.rs +++ b/candle-transformers/src/models/mixformer.rs @@ -57,6 +57,22 @@ impl Config { } } + pub fn v2() -> Self { + Self { + vocab_size: 51200, + n_positions: 2048, + n_embd: 2560, + n_layer: 32, + n_inner: None, + n_head: 32, + rotary_dim: usize::min(32, 2560 / 32), + activation_function: Activation::Gelu, + layer_norm_epsilon: 1e-5, + tie_word_embeddings: false, + pad_vocab_size_multiple: 64, + } + } + // https://huggingface.co/teknium/Puffin-Phi-v2/blob/main/config.json pub fn puffin_phi_v2() -> Self { Self { @@ -372,6 +388,24 @@ pub struct MixFormerSequentialForCausalLM { } impl MixFormerSequentialForCausalLM { + pub fn new_v2(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let vb_head = vb.pp("lm_head"); + let vb = vb.pp("transformer"); + let embedding = Embedding::new(cfg, vb.pp("embd"))?; + let mut blocks = Vec::new(); + for i in 0..cfg.n_layer { + let block = ParallelBlock::new(cfg, vb.pp("h").pp(i))?; + blocks.push(block) + } + let head = CausalLMHead::new(cfg, vb_head)?; + Ok(Self { + embedding, + blocks, + head, + span: tracing::span!(tracing::Level::TRACE, "mixformer"), + }) + } + pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { let vb = vb.pp("layers"); let embedding = Embedding::new(cfg, vb.pp(0))?; diff --git a/candle-transformers/src/models/mixtral.rs b/candle-transformers/src/models/mixtral.rs new file mode 100644 index 00000000..ede74d3f --- /dev/null +++ b/candle-transformers/src/models/mixtral.rs @@ -0,0 +1,499 @@ +use crate::models::with_tracing::{linear_no_bias, Linear}; +/// Mixtral Model +/// https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py +/// https://mistral.ai/news/mixtral-of-experts/ +use candle::{DType, Device, Module, Result, Tensor, D}; +use candle_nn::{Activation, VarBuilder}; +use serde::Deserialize; +use std::sync::Arc; + +/// https://github.com/huggingface/transformers/blob/1a585c1222a56bcaecc070966d558d4a9d862e83/src/transformers/models/mixtral/configuration_mixtral.py#L113 +#[derive(Debug, Clone, PartialEq, Deserialize)] +pub struct Config { + pub(crate) vocab_size: usize, + pub(crate) hidden_size: usize, + pub(crate) intermediate_size: usize, + pub(crate) num_hidden_layers: usize, + pub(crate) num_attention_heads: usize, + pub(crate) num_key_value_heads: usize, + pub(crate) hidden_act: Activation, + pub(crate) max_position_embeddings: 
usize, + pub(crate) rms_norm_eps: f64, + pub(crate) rope_theta: f64, + pub(crate) sliding_window: usize, + pub(crate) num_experts_per_tok: usize, + pub(crate) num_local_experts: usize, + pub(crate) use_flash_attn: bool, +} + +impl Config { + /// https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/blob/main/config.json + pub fn v0_1_8x7b(use_flash_attn: bool) -> Self { + Self { + vocab_size: 32000, + hidden_size: 4096, + intermediate_size: 14336, + num_hidden_layers: 32, + num_attention_heads: 32, + num_key_value_heads: 8, + hidden_act: Activation::Silu, + max_position_embeddings: 32768, + rms_norm_eps: 1e-5, + rope_theta: 1e6, + sliding_window: 4096, + num_experts_per_tok: 2, + num_local_experts: 8, + use_flash_attn, + } + } +} + +#[derive(Debug, Clone)] +struct RmsNorm { + inner: candle_nn::RmsNorm, + span: tracing::Span, +} + +impl RmsNorm { + fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); + let inner = candle_nn::rms_norm(size, eps, vb)?; + Ok(Self { inner, span }) + } +} + +impl Module for RmsNorm { + fn forward(&self, x: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(x) + } +} + +#[derive(Debug, Clone)] +struct RotaryEmbedding { + sin: Tensor, + cos: Tensor, +} + +fn rotate_half(xs: &Tensor) -> Result<Tensor> { + let last_dim = xs.dim(D::Minus1)?; + let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?; + let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?; + Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1) +} + +impl RotaryEmbedding { + fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> { + let dim = cfg.hidden_size / cfg.num_attention_heads; + let max_seq_len = cfg.max_position_embeddings; + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / (cfg.rope_theta as f32).powf(i as f32 / dim as f32)) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; + let t = Tensor::arange(0u32, max_seq_len as u32, dev)? + .to_dtype(dtype)? + .reshape((max_seq_len, 1))?; + let freqs = t.matmul(&inv_freq)?; + let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?; + Ok(Self { + sin: freqs.sin()?, + cos: freqs.cos()?, + }) + } + + fn apply_rotary_emb_qkv( + &self, + q: &Tensor, + k: &Tensor, + seqlen_offset: usize, + ) -> Result<(Tensor, Tensor)> { + let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; + let cos = self.cos.narrow(0, seqlen_offset, seq_len)?; + let sin = self.sin.narrow(0, seqlen_offset, seq_len)?; + let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) + let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) + let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?; + let k_embed = (k.broadcast_mul(&cos)? 
+ rotate_half(k)?.broadcast_mul(&sin))?; + Ok((q_embed, k_embed)) + } +} + +#[cfg(feature = "flash-attn")] +fn flash_attn( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + causal: bool, +) -> Result<Tensor> { + candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal) +} + +#[cfg(not(feature = "flash-attn"))] +fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> { + unimplemented!("compile with '--features flash-attn'") +} + +#[derive(Debug, Clone)] +struct Attention { + q_proj: Linear, + k_proj: Linear, + v_proj: Linear, + o_proj: Linear, + num_heads: usize, + num_kv_heads: usize, + num_kv_groups: usize, + head_dim: usize, + hidden_size: usize, + rotary_emb: Arc<RotaryEmbedding>, + kv_cache: Option<(Tensor, Tensor)>, + use_flash_attn: bool, +} + +impl Attention { + fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> { + let hidden_sz = cfg.hidden_size; + let num_heads = cfg.num_attention_heads; + let num_kv_heads = cfg.num_key_value_heads; + let num_kv_groups = num_heads / num_kv_heads; + let head_dim = hidden_sz / num_heads; + let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?; + let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?; + let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?; + let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?; + Ok(Self { + q_proj, + k_proj, + v_proj, + o_proj, + num_heads, + num_kv_heads, + num_kv_groups, + head_dim, + hidden_size: hidden_sz, + rotary_emb, + kv_cache: None, + use_flash_attn: cfg.use_flash_attn, + }) + } + + fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> { + let n_rep = self.num_kv_groups; + if n_rep == 1 { + Ok(xs) + } else { + let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?; + xs.unsqueeze(2)? + .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))? + .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim)) + } + } + + fn forward( + &mut self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offset: usize, + ) -> Result<Tensor> { + let (b_sz, q_len, _) = xs.dims3()?; + + let query_states = self.q_proj.forward(xs)?; + let key_states = self.k_proj.forward(xs)?; + let value_states = self.v_proj.forward(xs)?; + + let query_states = query_states + .reshape((b_sz, q_len, self.num_heads, self.head_dim))? + .transpose(1, 2)?; + let key_states = key_states + .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + let value_states = value_states + .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? 
+ .transpose(1, 2)?; + + let (query_states, key_states) = + self.rotary_emb + .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?; + + let (key_states, value_states) = match &self.kv_cache { + None => (key_states, value_states), + Some((prev_k, prev_v)) => { + let key_states = Tensor::cat(&[prev_k, &key_states], 2)?; + let value_states = Tensor::cat(&[prev_v, &value_states], 2)?; + (key_states, value_states) + } + }; + self.kv_cache = Some((key_states.clone(), value_states.clone())); + + let key_states = self.repeat_kv(key_states)?; + let value_states = self.repeat_kv(value_states)?; + + let attn_output = if self.use_flash_attn { + // flash-attn expects (b_sz, seq_len, nheads, head_dim) + let q = query_states.transpose(1, 2)?; + let k = key_states.transpose(1, 2)?; + let v = value_states.transpose(1, 2)?; + let softmax_scale = 1f32 / (self.head_dim as f32).sqrt(); + flash_attn(&q, &k, &v, softmax_scale, q_len > 1)?.transpose(1, 2)? + } else { + let scale = 1f64 / f64::sqrt(self.head_dim as f64); + let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?; + + let attn_weights = match attention_mask { + None => attn_weights, + Some(mask) => attn_weights.broadcast_add(mask)?, + }; + let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; + attn_weights.matmul(&value_states)? + }; + attn_output + .transpose(1, 2)? + .reshape((b_sz, q_len, self.hidden_size))? + .apply(&self.o_proj) + } +} + +#[derive(Debug, Clone)] +struct BlockSparseTop2MLP { + w1: Linear, + w2: Linear, + w3: Linear, + act_fn: Activation, +} + +impl BlockSparseTop2MLP { + fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let hidden_sz = cfg.hidden_size; + let intermediate_sz = cfg.intermediate_size; + let w1 = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("w1"))?; + let w2 = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("w2"))?; + let w3 = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("w3"))?; + Ok(Self { + w1, + w2, + w3, + act_fn: cfg.hidden_act, + }) + } +} + +impl Module for BlockSparseTop2MLP { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let lhs = xs.apply(&self.w1)?.apply(&self.act_fn)?; + let rhs = xs.apply(&self.w3)?; + (lhs * rhs)?.apply(&self.w2) + } +} + +#[derive(Debug, Clone)] +struct SparseMoeBlock { + gate: Linear, + experts: Vec<BlockSparseTop2MLP>, + num_experts_per_tok: usize, +} + +impl SparseMoeBlock { + fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let gate = linear_no_bias(cfg.hidden_size, cfg.num_local_experts, vb.pp("gate"))?; + let mut experts = Vec::with_capacity(cfg.num_local_experts); + let vb = vb.pp("experts"); + for idx in 0..cfg.num_local_experts { + let expert = BlockSparseTop2MLP::new(cfg, vb.pp(idx))?; + experts.push(expert) + } + Ok(SparseMoeBlock { + gate, + experts, + num_experts_per_tok: cfg.num_experts_per_tok, + }) + } +} + +impl Module for SparseMoeBlock { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let (b_size, seq_len, hidden_dim) = xs.dims3()?; + let xs = xs.reshape(((), hidden_dim))?; + let router_logits = xs.apply(&self.gate)?; + let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?; + + // In order to extract topk, we extract the data from the tensor and manipulate it + // directly. Maybe we will want to use some custom ops instead at some point. 
+ let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?; + + // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + // top_x contains the row indexes to evaluate for each expert. + let mut top_x = vec![vec![]; self.experts.len()]; + let mut selected_rws = vec![vec![]; self.experts.len()]; + for (row_idx, rw) in routing_weights.iter().enumerate() { + let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>(); + dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize])); + let mut sum_routing_weights = 0f32; + for &expert_idx in dst.iter().take(self.num_experts_per_tok) { + let expert_idx = expert_idx as usize; + let routing_weight = rw[expert_idx]; + sum_routing_weights += routing_weight; + top_x[expert_idx].push(row_idx as u32); + } + for &expert_idx in dst.iter().take(self.num_experts_per_tok) { + let expert_idx = expert_idx as usize; + let routing_weight = rw[expert_idx]; + selected_rws[expert_idx].push(routing_weight / sum_routing_weights) + } + } + + // routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + // expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + let mut ys = xs.zeros_like()?; + for (expert_idx, expert_layer) in self.experts.iter().enumerate() { + let top_x = &top_x[expert_idx]; + if top_x.is_empty() { + continue; + } + let top_x = Tensor::new(top_x.as_slice(), xs.device())?; + let selected_rws = + Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?.reshape(((), 1))?; + // Index the correct hidden states and compute the expert hidden state for + // the current expert. We need to make sure to multiply the output hidden + // states by `routing_weights` on the corresponding tokens (top-1 and top-2) + let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?; + // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None]) + let current_hidden_states = expert_layer.forward(&current_state)?; + let current_hidden_states = current_hidden_states.broadcast_mul(&selected_rws)?; + ys = ys.index_add(&top_x, &current_hidden_states, 0)?; + } + + let ys = ys.reshape((b_size, seq_len, hidden_dim))?; + Ok(ys) + } +} + +#[derive(Debug, Clone)] +struct DecoderLayer { + self_attn: Attention, + block_sparse_moe: SparseMoeBlock, + input_layernorm: RmsNorm, + post_attention_layernorm: RmsNorm, +} + +impl DecoderLayer { + fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> { + let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?; + let block_sparse_moe = SparseMoeBlock::new(cfg, vb.pp("block_sparse_moe"))?; + let input_layernorm = + RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; + let post_attention_layernorm = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("post_attention_layernorm"), + )?; + Ok(Self { + self_attn, + block_sparse_moe, + input_layernorm, + post_attention_layernorm, + }) + } + + fn forward( + &mut self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offset: usize, + ) -> Result<Tensor> { + let residual = xs; + let xs = self.input_layernorm.forward(xs)?; + let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?; + let xs = (xs + residual)?; + let residual = &xs; + let xs = xs + .apply(&self.post_attention_layernorm)?
+ .apply(&self.block_sparse_moe)?; + residual + xs + } +} + +#[derive(Debug, Clone)] +pub struct Model { + embed_tokens: candle_nn::Embedding, + layers: Vec<DecoderLayer>, + norm: RmsNorm, + lm_head: Linear, + sliding_window: usize, + device: Device, + dtype: DType, +} + +impl Model { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let vb_m = vb.pp("model"); + let embed_tokens = + candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?; + let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?); + let mut layers = Vec::with_capacity(cfg.num_hidden_layers); + let vb_l = vb_m.pp("layers"); + for layer_idx in 0..cfg.num_hidden_layers { + let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?; + layers.push(layer) + } + let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?; + let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; + Ok(Self { + embed_tokens, + layers, + norm, + lm_head, + sliding_window: cfg.sliding_window, + device: vb.device().clone(), + dtype: vb.dtype(), + }) + } + + fn prepare_decoder_attention_mask( + &self, + b_size: usize, + tgt_len: usize, + seqlen_offset: usize, + ) -> Result<Tensor> { + // Sliding window mask? + let mask: Vec<_> = (0..tgt_len) + .flat_map(|i| { + (0..tgt_len).map(move |j| { + if i < j || j + self.sliding_window < i { + f32::NEG_INFINITY + } else { + 0. + } + }) + }) + .collect(); + let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?; + let mask = if seqlen_offset > 0 { + let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?; + Tensor::cat(&[&mask0, &mask], D::Minus1)? + } else { + mask + }; + mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))? + .to_dtype(self.dtype) + } + + pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> { + let (b_size, seq_len) = input_ids.dims2()?; + let attention_mask = if seq_len <= 1 { + None + } else { + let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?; + Some(mask) + }; + let mut xs = self.embed_tokens.forward(input_ids)?; + for layer in self.layers.iter_mut() { + xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)? + } + xs.narrow(1, seq_len - 1, 1)? + .apply(&self.norm)? + .apply(&self.lm_head) + } +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index a9a56673..94a3bd5b 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -14,6 +14,7 @@ pub mod llama2_c_weights; pub mod marian; pub mod mistral; pub mod mixformer; +pub mod mixtral; pub mod mpt; pub mod persimmon; pub mod quantized_blip; diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs index 44d89f40..1fb2d9e2 100644 --- a/candle-transformers/src/models/quantized_llama.rs +++ b/candle-transformers/src/models/quantized_llama.rs @@ -48,15 +48,109 @@ impl QMatMul { } #[derive(Debug, Clone)] +struct Mlp { + feed_forward_w1: QMatMul, + feed_forward_w2: QMatMul, + feed_forward_w3: QMatMul, +} + +impl Module for Mlp { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let w1 = self.feed_forward_w1.forward(xs)?; + let w3 = self.feed_forward_w3.forward(xs)?; + self.feed_forward_w2 + .forward(&(candle_nn::ops::silu(&w1)? * w3)?) 
+ } +} + +#[derive(Debug, Clone)] +enum MlpOrMoe { + Mlp(Mlp), + MoE { + n_expert_used: usize, + feed_forward_gate_inp: QMatMul, + experts: Vec<Mlp>, + }, +} + +impl Module for MlpOrMoe { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + match self { + Self::MoE { + feed_forward_gate_inp, + experts, + n_expert_used, + } => { + let (b_size, seq_len, hidden_dim) = xs.dims3()?; + let xs = xs.reshape(((), hidden_dim))?; + let router_logits = feed_forward_gate_inp.forward(&xs)?; + let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?; + + // In order to extract topk, we extract the data from the tensor and manipulate it + // directly. Maybe we will want to use some custom ops instead at some point. + let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?; + + // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + // top_x contains the row indexes to evaluate for each expert. + let mut top_x = vec![vec![]; experts.len()]; + let mut selected_rws = vec![vec![]; experts.len()]; + for (row_idx, rw) in routing_weights.iter().enumerate() { + let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>(); + dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize])); + let mut sum_routing_weights = 0f32; + for &expert_idx in dst.iter().take(*n_expert_used) { + let expert_idx = expert_idx as usize; + let routing_weight = rw[expert_idx]; + sum_routing_weights += routing_weight; + top_x[expert_idx].push(row_idx as u32); + } + for &expert_idx in dst.iter().take(*n_expert_used) { + let expert_idx = expert_idx as usize; + let routing_weight = rw[expert_idx]; + selected_rws[expert_idx].push(routing_weight / sum_routing_weights) + } + } + + // routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + // expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + let mut ys = xs.zeros_like()?; + for (expert_idx, expert_layer) in experts.iter().enumerate() { + let top_x = &top_x[expert_idx]; + if top_x.is_empty() { + continue; + } + let top_x = Tensor::new(top_x.as_slice(), xs.device())?; + let selected_rws = + Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())? + .reshape(((), 1))?; + // Index the correct hidden states and compute the expert hidden state for + // the current expert. 
 We need to make sure to multiply the output hidden + states by `routing_weights` on the corresponding tokens (top-1 and top-2) + let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?; + // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None]) + let current_hidden_states = expert_layer.forward(&current_state)?; + let current_hidden_states = + current_hidden_states.broadcast_mul(&selected_rws)?; + ys = ys.index_add(&top_x, &current_hidden_states, 0)?; + } + + let ys = ys.reshape((b_size, seq_len, hidden_dim))?; + Ok(ys) + } + Self::Mlp(mlp) => mlp.forward(xs), + } + } +} + +#[derive(Debug, Clone)] struct LayerWeights { attention_wq: QMatMul, attention_wk: QMatMul, attention_wv: QMatMul, attention_wo: QMatMul, attention_norm: RmsNorm, - feed_forward_w1: QMatMul, - feed_forward_w2: QMatMul, - feed_forward_w3: QMatMul, + mlp_or_moe: MlpOrMoe, ffn_norm: RmsNorm, n_head: usize, n_kv_head: usize, @@ -212,9 +306,16 @@ impl ModelWeights { let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?; let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?; let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?; - let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?; - let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?; - let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?; + let mlp_or_moe = { + let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?; + let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?; + let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?; + MlpOrMoe::Mlp(Mlp { + feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?, + feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?, + feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?, + }) + }; let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?; let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?; let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); @@ -226,9 +327,7 @@ impl ModelWeights { attention_wv: QMatMul::from_qtensor(attention_wv)?, attention_wo: QMatMul::from_qtensor(attention_wo)?, attention_norm: RmsNorm::new(attention_norm, 1e-5)?, - feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?, - feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?, - feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?, + mlp_or_moe, ffn_norm: RmsNorm::new(ffn_norm, 1e-5)?, n_head: ct.hparams.n_head as usize, n_kv_head: ct.hparams.n_head as usize / gqa, @@ -265,6 +364,12 @@ impl ModelWeights { }; // Parameter extraction from metadata. + let n_expert = md_get("llama.expert_count") + .and_then(|v| v.to_u32()) + .unwrap_or(0) as usize; + let n_expert_used = md_get("llama.expert_used_count") + .and_then(|v| v.to_u32()) + .unwrap_or(0) as usize; let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize; let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize; let block_count = md_get("llama.block_count")?.to_u32()?
as usize; @@ -289,9 +394,38 @@ impl ModelWeights { let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?; let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?; let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?; - let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?; - let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?; - let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?; + let mlp_or_moe = if n_expert <= 1 { + let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?; + let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?; + let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?; + MlpOrMoe::Mlp(Mlp { + feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?, + feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?, + feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?, + }) + } else { + let feed_forward_gate_inp = + ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"))?; + let mut experts = Vec::with_capacity(n_expert); + for i in 0..n_expert { + let feed_forward_w1 = + ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"))?; + let feed_forward_w2 = + ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"))?; + let feed_forward_w3 = + ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"))?; + experts.push(Mlp { + feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?, + feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?, + feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?, + }) + } + MlpOrMoe::MoE { + n_expert_used, + feed_forward_gate_inp: QMatMul::from_qtensor(feed_forward_gate_inp)?, + experts, + } + }; let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?; let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?; let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); @@ -303,9 +437,7 @@ impl ModelWeights { attention_wv: QMatMul::from_qtensor(attention_wv)?, attention_wo: QMatMul::from_qtensor(attention_wo)?, attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?, - feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?, - feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?, - feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?, + mlp_or_moe, ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?, n_head: head_count, n_kv_head: head_count_kv, @@ -360,12 +492,9 @@ impl ModelWeights { let _enter = layer.span_mlp.enter(); let residual = &x; let x = layer.ffn_norm.forward(&x)?; - let w1 = layer.feed_forward_w1.forward(&x)?; - let w3 = layer.feed_forward_w3.forward(&x)?; - let mlp = layer - .feed_forward_w2 - .forward(&(candle_nn::ops::silu(&w1)? 
* w3)?)?; - layer_in = (mlp + residual)?; + let x = layer.mlp_or_moe.forward(&x)?; + let x = (x + residual)?; + layer_in = x } let x = self.norm.forward(&layer_in)?; let x = x.i((.., seq_len - 1, ..))?; diff --git a/candle-transformers/src/models/quantized_mixformer.rs b/candle-transformers/src/models/quantized_mixformer.rs index f11f2036..1a3cd4ac 100644 --- a/candle-transformers/src/models/quantized_mixformer.rs +++ b/candle-transformers/src/models/quantized_mixformer.rs @@ -287,6 +287,24 @@ pub struct MixFormerSequentialForCausalLM { } impl MixFormerSequentialForCausalLM { + pub fn new_v2(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let vb_head = vb.pp("lm_head"); + let vb = vb.pp("transformer"); + let embedding = Embedding::new(cfg, vb.pp("embd"))?; + let mut blocks = Vec::new(); + for i in 0..cfg.n_layer { + let block = ParallelBlock::new(cfg, vb.pp("h").pp(i))?; + blocks.push(block) + } + let head = CausalLMHead::new(cfg, vb_head)?; + Ok(Self { + embedding, + blocks, + head, + span: tracing::span!(tracing::Level::TRACE, "mixformer"), + }) + } + pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { let vb = vb.pp("layers"); let embedding = Embedding::new(cfg, vb.pp(0))?; diff --git a/candle-transformers/src/models/segment_anything/mask_decoder.rs b/candle-transformers/src/models/segment_anything/mask_decoder.rs index 2a91cd44..1703c809 100644 --- a/candle-transformers/src/models/segment_anything/mask_decoder.rs +++ b/candle-transformers/src/models/segment_anything/mask_decoder.rs @@ -182,7 +182,7 @@ impl MaskDecoder { sparse_prompt_embeddings: &Tensor, dense_prompt_embeddings: &Tensor, ) -> Result<(Tensor, Tensor)> { - // Concatenate ouput tokens. + // Concatenate output tokens. let output_tokens = Tensor::cat( &[self.iou_token.embeddings(), self.mask_tokens.embeddings()], 0, diff --git a/candle-transformers/src/models/segment_anything/prompt_encoder.rs b/candle-transformers/src/models/segment_anything/prompt_encoder.rs index 9d0074b1..16e8a4e8 100644 --- a/candle-transformers/src/models/segment_anything/prompt_encoder.rs +++ b/candle-transformers/src/models/segment_anything/prompt_encoder.rs @@ -2,11 +2,11 @@ use candle::{DType, IndexOp, Result, Tensor, D}; use candle_nn::VarBuilder; #[derive(Debug)] -struct PostionEmbeddingRandom { +struct PositionEmbeddingRandom { positional_encoding_gaussian_matrix: Tensor, } -impl PostionEmbeddingRandom { +impl PositionEmbeddingRandom { fn new(num_pos_feats: usize, vb: VarBuilder) -> Result<Self> { let positional_encoding_gaussian_matrix = vb.get((2, num_pos_feats), "positional_encoding_gaussian_matrix")?; @@ -52,7 +52,7 @@ impl PostionEmbeddingRandom { #[derive(Debug)] pub struct PromptEncoder { - pe_layer: PostionEmbeddingRandom, + pe_layer: PositionEmbeddingRandom, point_embeddings: Vec<candle_nn::Embedding>, not_a_point_embed: candle_nn::Embedding, mask_downscaling_conv1: candle_nn::Conv2d, @@ -76,7 +76,7 @@ impl PromptEncoder { vb: VarBuilder, ) -> Result<Self> { let num_points_embeddings = 4; - let pe_layer = PostionEmbeddingRandom::new(embed_dim / 2, vb.pp("pe_layer"))?; + let pe_layer = PositionEmbeddingRandom::new(embed_dim / 2, vb.pp("pe_layer"))?; let not_a_point_embed = candle_nn::embedding(1, embed_dim, vb.pp("not_a_point_embed"))?; let no_mask_embed = candle_nn::embedding(1, embed_dim, vb.pp("no_mask_embed"))?; let cfg = candle_nn::Conv2dConfig { diff --git a/candle-transformers/src/models/stable_diffusion/ddim.rs b/candle-transformers/src/models/stable_diffusion/ddim.rs index 916b7349..d804ed56 100644 --- 
a/candle-transformers/src/models/stable_diffusion/ddim.rs +++ b/candle-transformers/src/models/stable_diffusion/ddim.rs @@ -7,7 +7,9 @@ //! //! Denoising Diffusion Implicit Models, J. Song et al, 2020. //! https://arxiv.org/abs/2010.02502 -use super::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType}; +use super::schedulers::{ + betas_for_alpha_bar, BetaSchedule, PredictionType, Scheduler, SchedulerConfig, TimestepSpacing, +}; use candle::{Result, Tensor}; /// The configuration for the DDIM scheduler. @@ -29,6 +31,8 @@ pub struct DDIMSchedulerConfig { pub prediction_type: PredictionType, /// number of diffusion steps used to train the model pub train_timesteps: usize, + /// time step spacing for the diffusion process + pub timestep_spacing: TimestepSpacing, } impl Default for DDIMSchedulerConfig { @@ -41,10 +45,17 @@ impl Default for DDIMSchedulerConfig { steps_offset: 1, prediction_type: PredictionType::Epsilon, train_timesteps: 1000, + timestep_spacing: TimestepSpacing::Leading, } } } +impl SchedulerConfig for DDIMSchedulerConfig { + fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>> { + Ok(Box::new(DDIMScheduler::new(inference_steps, *self)?)) + } +} + /// The DDIM scheduler. #[derive(Debug, Clone)] pub struct DDIMScheduler { @@ -60,12 +71,32 @@ impl DDIMScheduler { /// Creates a new DDIM scheduler given the number of steps to be /// used for inference as well as the number of steps that was used /// during training. - pub fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result<Self> { + fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result<Self> { let step_ratio = config.train_timesteps / inference_steps; - let timesteps: Vec<usize> = (0..(inference_steps)) - .map(|s| s * step_ratio + config.steps_offset) - .rev() - .collect(); + let timesteps: Vec<usize> = match config.timestep_spacing { + TimestepSpacing::Leading => (0..(inference_steps)) + .map(|s| s * step_ratio + config.steps_offset) + .rev() + .collect(), + TimestepSpacing::Trailing => std::iter::successors(Some(config.train_timesteps), |n| { + if *n > step_ratio { + Some(n - step_ratio) + } else { + None + } + }) + .map(|n| n - 1) + .collect(), + TimestepSpacing::Linspace => { + super::utils::linspace(0.0, (config.train_timesteps - 1) as f64, inference_steps)? + .to_vec1::<f64>()? + .iter() + .map(|&f| f as usize) + .rev() + .collect() + } + }; + let betas = match config.beta_schedule { BetaSchedule::ScaledLinear => super::utils::linspace( config.beta_start.sqrt(), @@ -92,19 +123,11 @@ impl DDIMScheduler { config, }) } +} - pub fn timesteps(&self) -> &[usize] { - self.timesteps.as_slice() - } - - /// Ensures interchangeability with schedulers that need to scale the denoising model input - /// depending on the current timestep. - pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> { - Ok(sample) - } - +impl Scheduler for DDIMScheduler { /// Performs a backward step during inference. - pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> { + fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> { let timestep = if timestep >= self.alphas_cumprod.len() { timestep - 1 } else { @@ -163,7 +186,17 @@ impl DDIMScheduler { } } - pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> { + /// Ensures interchangeability with schedulers that need to scale the denoising model input + /// depending on the current timestep. 
+ fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> { + Ok(sample) + } + + fn timesteps(&self) -> &[usize] { + self.timesteps.as_slice() + } + + fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> { let timestep = if timestep >= self.alphas_cumprod.len() { timestep - 1 } else { @@ -174,7 +207,7 @@ impl DDIMScheduler { (original * sqrt_alpha_prod)? + (noise * sqrt_one_minus_alpha_prod)? } - pub fn init_noise_sigma(&self) -> f64 { + fn init_noise_sigma(&self) -> f64 { self.init_noise_sigma } } diff --git a/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs b/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs new file mode 100644 index 00000000..9576c2de --- /dev/null +++ b/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs @@ -0,0 +1,235 @@ +//! Ancestral sampling with Euler method steps. +//! +//! Reference implementation in Rust: +//! +//! https://github.com/pykeio/diffusers/blob/250b9ad1898af41e76a74c0d8d4292652823338a/src/schedulers/euler_ancestral_discrete.rs +//! +//! Based on the original [`k-diffusion` implementation by Katherine Crowson][kd]. +/// +/// [kd]: https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72 +use super::{ + schedulers::{ + betas_for_alpha_bar, BetaSchedule, PredictionType, Scheduler, SchedulerConfig, + TimestepSpacing, + }, + utils::interp, +}; +use candle::{bail, Error, Result, Tensor}; + +/// The configuration for the EulerAncestral Discrete scheduler. +#[derive(Debug, Clone, Copy)] +pub struct EulerAncestralDiscreteSchedulerConfig { + /// The value of beta at the beginning of training.n + pub beta_start: f64, + /// The value of beta at the end of training. + pub beta_end: f64, + /// How beta evolved during training. + pub beta_schedule: BetaSchedule, + /// Adjust the indexes of the inference schedule by this value. + pub steps_offset: usize, + /// prediction type of the scheduler function, one of `epsilon` (predicting + /// the noise of the diffusion process), `sample` (directly predicting the noisy sample`) + /// or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) + pub prediction_type: PredictionType, + /// number of diffusion steps used to train the model + pub train_timesteps: usize, + /// time step spacing for the diffusion process + pub timestep_spacing: TimestepSpacing, +} + +impl Default for EulerAncestralDiscreteSchedulerConfig { + fn default() -> Self { + Self { + beta_start: 0.00085f64, + beta_end: 0.012f64, + beta_schedule: BetaSchedule::ScaledLinear, + steps_offset: 1, + prediction_type: PredictionType::Epsilon, + train_timesteps: 1000, + timestep_spacing: TimestepSpacing::Leading, + } + } +} + +impl SchedulerConfig for EulerAncestralDiscreteSchedulerConfig { + fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>> { + Ok(Box::new(EulerAncestralDiscreteScheduler::new( + inference_steps, + *self, + )?)) + } +} + +/// The EulerAncestral Discrete scheduler. 
+#[derive(Debug, Clone)] +pub struct EulerAncestralDiscreteScheduler { + timesteps: Vec<usize>, + sigmas: Vec<f64>, + init_noise_sigma: f64, + pub config: EulerAncestralDiscreteSchedulerConfig, +} + +// clip_sample: False, set_alpha_to_one: False +impl EulerAncestralDiscreteScheduler { + /// Creates a new EulerAncestral Discrete scheduler given the number of steps to be + /// used for inference as well as the number of steps that was used + /// during training. + pub fn new( + inference_steps: usize, + config: EulerAncestralDiscreteSchedulerConfig, + ) -> Result<Self> { + let step_ratio = config.train_timesteps / inference_steps; + let timesteps: Vec<usize> = match config.timestep_spacing { + TimestepSpacing::Leading => (0..(inference_steps)) + .map(|s| s * step_ratio + config.steps_offset) + .rev() + .collect(), + TimestepSpacing::Trailing => std::iter::successors(Some(config.train_timesteps), |n| { + if *n > step_ratio { + Some(n - step_ratio) + } else { + None + } + }) + .map(|n| n - 1) + .collect(), + TimestepSpacing::Linspace => { + super::utils::linspace(0.0, (config.train_timesteps - 1) as f64, inference_steps)? + .to_vec1::<f64>()? + .iter() + .map(|&f| f as usize) + .rev() + .collect() + } + }; + + let betas = match config.beta_schedule { + BetaSchedule::ScaledLinear => super::utils::linspace( + config.beta_start.sqrt(), + config.beta_end.sqrt(), + config.train_timesteps, + )? + .sqr()?, + BetaSchedule::Linear => { + super::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)? + } + BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?, + }; + let betas = betas.to_vec1::<f64>()?; + let mut alphas_cumprod = Vec::with_capacity(betas.len()); + for &beta in betas.iter() { + let alpha = 1.0 - beta; + alphas_cumprod.push(alpha * *alphas_cumprod.last().unwrap_or(&1f64)) + } + let sigmas: Vec<f64> = alphas_cumprod + .iter() + .map(|&f| ((1. - f) / f).sqrt()) + .collect(); + + let sigmas_xa: Vec<_> = (0..sigmas.len()).map(|i| i as f64).collect(); + + let mut sigmas_int = interp( + &timesteps.iter().map(|&t| t as f64).collect::<Vec<_>>(), + &sigmas_xa, + &sigmas, + ); + sigmas_int.push(0.0); + + // standard deviation of the initial noise distribution + // f64 does not implement Ord such that there is no `max`, so we need to use this workaround + let init_noise_sigma = *sigmas_int + .iter() + .chain(std::iter::once(&0.0)) + .reduce(|a, b| if a > b { a } else { b }) + .expect("init_noise_sigma could not be reduced from sigmas - this should never happen"); + + Ok(Self { + sigmas: sigmas_int, + timesteps, + init_noise_sigma, + config, + }) + } +} + +impl Scheduler for EulerAncestralDiscreteScheduler { + fn timesteps(&self) -> &[usize] { + self.timesteps.as_slice() + } + + /// Ensures interchangeability with schedulers that need to scale the denoising model input + /// depending on the current timestep. + /// + /// Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm + fn scale_model_input(&self, sample: Tensor, timestep: usize) -> Result<Tensor> { + let step_index = match self.timesteps.iter().position(|&t| t == timestep) { + Some(i) => i, + None => bail!("timestep out of this schedulers bounds: {timestep}"), + }; + + let sigma = self + .sigmas + .get(step_index) + .expect("step_index out of sigma bounds - this shouldn't happen"); + + sample / ((sigma.powi(2) + 1.).sqrt()) + } + + /// Performs a backward step during inference.
+ fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> { + let step_index = self + .timesteps + .iter() + .position(|&p| p == timestep) + .ok_or_else(|| Error::Msg("timestep out of this schedulers bounds".to_string()))?; + + let sigma_from = &self.sigmas[step_index]; + let sigma_to = &self.sigmas[step_index + 1]; + + // 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + let pred_original_sample = match self.config.prediction_type { + PredictionType::Epsilon => (sample - (model_output * *sigma_from))?, + PredictionType::VPrediction => { + ((model_output * (-sigma_from / (sigma_from.powi(2) + 1.0).sqrt()))? + + (sample / (sigma_from.powi(2) + 1.0))?)? + } + PredictionType::Sample => bail!("prediction_type not implemented yet: sample"), + }; + + let sigma_up = (sigma_to.powi(2) * (sigma_from.powi(2) - sigma_to.powi(2)) + / sigma_from.powi(2)) + .sqrt(); + let sigma_down = (sigma_to.powi(2) - sigma_up.powi(2)).sqrt(); + + // 2. convert to a ODE derivative + let derivative = ((sample - pred_original_sample)? / *sigma_from)?; + let dt = sigma_down - *sigma_from; + let prev_sample = (sample + derivative * dt)?; + + let noise = prev_sample.randn_like(0.0, 1.0)?; + + prev_sample + noise * sigma_up + } + + fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> { + let step_index = self + .timesteps + .iter() + .position(|&p| p == timestep) + .ok_or_else(|| Error::Msg("timestep out of this schedulers bounds".to_string()))?; + + let sigma = self + .sigmas + .get(step_index) + .expect("step_index out of sigma bounds - this shouldn't happen"); + + original + (noise * *sigma)? + } + + fn init_noise_sigma(&self) -> f64 { + match self.config.timestep_spacing { + TimestepSpacing::Trailing | TimestepSpacing::Linspace => self.init_noise_sigma, + TimestepSpacing::Leading => (self.init_noise_sigma.powi(2) + 1.0).sqrt(), + } + } +} diff --git a/candle-transformers/src/models/stable_diffusion/mod.rs b/candle-transformers/src/models/stable_diffusion/mod.rs index 66ef7149..30f23975 100644 --- a/candle-transformers/src/models/stable_diffusion/mod.rs +++ b/candle-transformers/src/models/stable_diffusion/mod.rs @@ -3,6 +3,7 @@ pub mod clip; pub mod ddim; pub mod ddpm; pub mod embeddings; +pub mod euler_ancestral_discrete; pub mod resnet; pub mod schedulers; pub mod unet_2d; @@ -10,9 +11,13 @@ pub mod unet_2d_blocks; pub mod utils; pub mod vae; +use std::sync::Arc; + use candle::{DType, Device, Result}; use candle_nn as nn; +use self::schedulers::{Scheduler, SchedulerConfig}; + #[derive(Clone, Debug)] pub struct StableDiffusionConfig { pub width: usize, @@ -21,7 +26,7 @@ pub struct StableDiffusionConfig { pub clip2: Option<clip::Config>, autoencoder: vae::AutoEncoderKLConfig, unet: unet_2d::UNet2DConditionModelConfig, - scheduler: ddim::DDIMSchedulerConfig, + scheduler: Arc<dyn SchedulerConfig>, } impl StableDiffusionConfig { @@ -75,13 +80,18 @@ impl StableDiffusionConfig { 512 }; - Self { + let scheduler = Arc::new(ddim::DDIMSchedulerConfig { + prediction_type: schedulers::PredictionType::Epsilon, + ..Default::default() + }); + + StableDiffusionConfig { width, height, clip: clip::Config::v1_5(), clip2: None, autoencoder, - scheduler: Default::default(), + scheduler, unet, } } @@ -124,10 +134,10 @@ impl StableDiffusionConfig { latent_channels: 4, norm_num_groups: 32, }; - let scheduler = ddim::DDIMSchedulerConfig { + let scheduler = Arc::new(ddim::DDIMSchedulerConfig { prediction_type, ..Default::default() - }; + }); let 
height = if let Some(height) = height { assert_eq!(height % 8, 0, "height has to be divisible by 8"); @@ -143,7 +153,7 @@ impl StableDiffusionConfig { 768 }; - Self { + StableDiffusionConfig { width, height, clip: clip::Config::v2_1(), @@ -205,10 +215,10 @@ impl StableDiffusionConfig { latent_channels: 4, norm_num_groups: 32, }; - let scheduler = ddim::DDIMSchedulerConfig { + let scheduler = Arc::new(ddim::DDIMSchedulerConfig { prediction_type, ..Default::default() - }; + }); let height = if let Some(height) = height { assert_eq!(height % 8, 0, "height has to be divisible by 8"); @@ -224,6 +234,76 @@ impl StableDiffusionConfig { 1024 }; + StableDiffusionConfig { + width, + height, + clip: clip::Config::sdxl(), + clip2: Some(clip::Config::sdxl2()), + autoencoder, + scheduler, + unet, + } + } + + fn sdxl_turbo_( + sliced_attention_size: Option<usize>, + height: Option<usize>, + width: Option<usize>, + prediction_type: schedulers::PredictionType, + ) -> Self { + let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig { + out_channels, + use_cross_attn, + attention_head_dim, + }; + // https://huggingface.co/stabilityai/sdxl-turbo/blob/main/unet/config.json + let unet = unet_2d::UNet2DConditionModelConfig { + blocks: vec![ + bc(320, None, 5), + bc(640, Some(2), 10), + bc(1280, Some(10), 20), + ], + center_input_sample: false, + cross_attention_dim: 2048, + downsample_padding: 1, + flip_sin_to_cos: true, + freq_shift: 0., + layers_per_block: 2, + mid_block_scale_factor: 1., + norm_eps: 1e-5, + norm_num_groups: 32, + sliced_attention_size, + use_linear_projection: true, + }; + // https://huggingface.co/stabilityai/sdxl-turbo/blob/main/vae/config.json + let autoencoder = vae::AutoEncoderKLConfig { + block_out_channels: vec![128, 256, 512, 512], + layers_per_block: 2, + latent_channels: 4, + norm_num_groups: 32, + }; + let scheduler = Arc::new( + euler_ancestral_discrete::EulerAncestralDiscreteSchedulerConfig { + prediction_type, + timestep_spacing: schedulers::TimestepSpacing::Trailing, + ..Default::default() + }, + ); + + let height = if let Some(height) = height { + assert_eq!(height % 8, 0, "height has to be divisible by 8"); + height + } else { + 512 + }; + + let width = if let Some(width) = width { + assert_eq!(width % 8, 0, "width has to be divisible by 8"); + width + } else { + 512 + }; + Self { width, height, @@ -249,6 +329,20 @@ impl StableDiffusionConfig { ) } + pub fn sdxl_turbo( + sliced_attention_size: Option<usize>, + height: Option<usize>, + width: Option<usize>, + ) -> Self { + Self::sdxl_turbo_( + sliced_attention_size, + height, + width, + // https://huggingface.co/stabilityai/sdxl-turbo/blob/main/scheduler/scheduler_config.json + schedulers::PredictionType::Epsilon, + ) + } + pub fn ssd1b( sliced_attention_size: Option<usize>, height: Option<usize>, @@ -285,9 +379,9 @@ impl StableDiffusionConfig { latent_channels: 4, norm_num_groups: 32, }; - let scheduler = ddim::DDIMSchedulerConfig { + let scheduler = Arc::new(ddim::DDIMSchedulerConfig { ..Default::default() - }; + }); let height = if let Some(height) = height { assert_eq!(height % 8, 0, "height has to be divisible by 8"); @@ -347,8 +441,8 @@ impl StableDiffusionConfig { Ok(unet) } - pub fn build_scheduler(&self, n_steps: usize) -> Result<ddim::DDIMScheduler> { - ddim::DDIMScheduler::new(n_steps, self.scheduler) + pub fn build_scheduler(&self, n_steps: usize) -> Result<Box<dyn Scheduler>> { + self.scheduler.build(n_steps) } } diff --git a/candle-transformers/src/models/stable_diffusion/schedulers.rs 
b/candle-transformers/src/models/stable_diffusion/schedulers.rs index 3f6a1d72..0f0441e0 100644 --- a/candle-transformers/src/models/stable_diffusion/schedulers.rs +++ b/candle-transformers/src/models/stable_diffusion/schedulers.rs @@ -3,9 +3,25 @@ //! //! Noise schedulers can be used to set the trade-off between //! inference speed and quality. - use candle::{Result, Tensor}; +pub trait SchedulerConfig: std::fmt::Debug { + fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>>; +} + +/// This trait represents a scheduler for the diffusion process. +pub trait Scheduler { + fn timesteps(&self) -> &[usize]; + + fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor>; + + fn init_noise_sigma(&self) -> f64; + + fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor>; + + fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor>; +} + /// This represents how beta ranges from its minimum value to the maximum /// during training. #[derive(Debug, Clone, Copy)] @@ -25,6 +41,22 @@ pub enum PredictionType { Sample, } +/// Time step spacing for the diffusion process. +/// +/// "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 +#[derive(Debug, Clone, Copy)] +pub enum TimestepSpacing { + Leading, + Linspace, + Trailing, +} + +impl Default for TimestepSpacing { + fn default() -> Self { + Self::Leading + } +} + /// Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of /// `(1-beta)` over time from `t = [0,1]`. /// diff --git a/candle-transformers/src/models/stable_diffusion/utils.rs b/candle-transformers/src/models/stable_diffusion/utils.rs index cef06f1c..5b5fa0f7 100644 --- a/candle-transformers/src/models/stable_diffusion/utils.rs +++ b/candle-transformers/src/models/stable_diffusion/utils.rs @@ -13,3 +13,49 @@ pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> { Tensor::from_vec(vs, steps, &Device::Cpu) } } + +/// A linear interpolator for a sorted array of x and y values. +struct LinearInterpolator<'x, 'y> { + xp: &'x [f64], + fp: &'y [f64], + cache: usize, +} + +impl<'x, 'y> LinearInterpolator<'x, 'y> { + fn accel_find(&mut self, x: f64) -> usize { + let xidx = self.cache; + if x < self.xp[xidx] { + self.cache = self.xp[0..xidx].partition_point(|o| *o < x); + self.cache = self.cache.saturating_sub(1); + } else if x >= self.xp[xidx + 1] { + self.cache = self.xp[xidx..self.xp.len()].partition_point(|o| *o < x) + xidx; + self.cache = self.cache.saturating_sub(1); + } + + self.cache + } + + fn eval(&mut self, x: f64) -> f64 { + if x < self.xp[0] || x > self.xp[self.xp.len() - 1] { + return f64::NAN; + } + + let idx = self.accel_find(x); + + let x_l = self.xp[idx]; + let x_h = self.xp[idx + 1]; + let y_l = self.fp[idx]; + let y_h = self.fp[idx + 1]; + let dx = x_h - x_l; + if dx > 0.0 { + y_l + (x - x_l) / dx * (y_h - y_l) + } else { + f64::NAN + } + } +} + +pub fn interp(x: &[f64], xp: &[f64], fp: &[f64]) -> Vec<f64> { + let mut interpolator = LinearInterpolator { xp, fp, cache: 0 }; + x.iter().map(|&x| interpolator.eval(x)).collect() +} |
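The sparse-MoE routing added in mixtral.rs (SparseMoeBlock) and in quantized_llama.rs (MlpOrMoe::MoE) follows the same recipe: softmax the router logits, pick the top num_experts_per_tok experts per token on the CPU, renormalize the selected weights, then gather each expert's rows with index_select and scatter its output back with index_add. Below is a minimal standalone sketch of just the selection and renormalization step on plain Vec<f32> rows; route_top_k is an illustrative name, not part of the crate.

// Illustrative sketch only: mirrors the per-row top-k grouping above, without candle tensors.
fn route_top_k(
    routing_weights: &[Vec<f32>],
    n_experts: usize,
    k: usize,
) -> (Vec<Vec<u32>>, Vec<Vec<f32>>) {
    // top_x[e] lists the token rows routed to expert e,
    // selected_rws[e] holds the matching renormalized weights.
    let mut top_x = vec![vec![]; n_experts];
    let mut selected_rws = vec![vec![]; n_experts];
    for (row_idx, rw) in routing_weights.iter().enumerate() {
        // Sort expert indices by descending routing weight.
        let mut dst: Vec<u32> = (0..rw.len() as u32).collect();
        dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
        let sum: f32 = dst.iter().take(k).map(|&e| rw[e as usize]).sum();
        for &e in dst.iter().take(k) {
            top_x[e as usize].push(row_idx as u32);
            selected_rws[e as usize].push(rw[e as usize] / sum);
        }
    }
    (top_x, selected_rws)
}

fn main() {
    // Two tokens, four experts, top-2 routing.
    let rows = vec![vec![0.3, 0.1, 0.5, 0.1], vec![0.05, 0.6, 0.05, 0.3]];
    let (top_x, rws) = route_top_k(&rows, 4, 2);
    // Token 0 goes to experts 2 and 0 with weights 0.625 and 0.375,
    // token 1 goes to experts 1 and 3 with weights 2/3 and 1/3.
    println!("{top_x:?} {rws:?}");
}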
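prepare_decoder_attention_mask in mixtral.rs combines the usual causal constraint with Mixtral's sliding-window constraint: position i may not attend to future positions (j > i) nor to positions more than sliding_window steps in the past. A tiny illustrative check of that predicate; mask_value is a hypothetical helper, not part of the crate.

fn mask_value(i: usize, j: usize, sliding_window: usize) -> f32 {
    // Same condition as in prepare_decoder_attention_mask: future tokens and
    // tokens older than the sliding window get -inf, everything else 0.
    if i < j || j + sliding_window < i {
        f32::NEG_INFINITY
    } else {
        0.
    }
}

fn main() {
    // With a window of 2, token 3 can attend to tokens 1..=3 but not token 0.
    let row: Vec<f32> = (0..4).map(|j| mask_value(3, j, 2)).collect();
    assert_eq!(row, vec![f32::NEG_INFINITY, 0., 0., 0.]);
}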
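The new TimestepSpacing enum controls how inference timesteps are drawn from the training schedule; the DDIM default stays Leading, while sdxl_turbo_ selects Trailing to match the sdxl-turbo scheduler config. A small illustrative sketch of the Leading and Trailing formulas used in ddim.rs and euler_ancestral_discrete.rs, assuming train_timesteps = 1000, steps_offset = 1 and 4 inference steps:

fn leading(train_timesteps: usize, inference_steps: usize, steps_offset: usize) -> Vec<usize> {
    let step_ratio = train_timesteps / inference_steps;
    (0..inference_steps)
        .map(|s| s * step_ratio + steps_offset)
        .rev()
        .collect()
}

fn trailing(train_timesteps: usize, inference_steps: usize) -> Vec<usize> {
    let step_ratio = train_timesteps / inference_steps;
    std::iter::successors(Some(train_timesteps), |n| {
        if *n > step_ratio {
            Some(n - step_ratio)
        } else {
            None
        }
    })
    .map(|n| n - 1)
    .collect()
}

fn main() {
    // Leading spacing never reaches the last training step; trailing spacing
    // starts at the very last step (999), as used by sdxl-turbo.
    assert_eq!(leading(1000, 4, 1), vec![751, 501, 251, 1]);
    assert_eq!(trailing(1000, 4), vec![999, 749, 499, 249]);
}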
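StableDiffusionConfig now stores an Arc<dyn SchedulerConfig> and build_scheduler returns a Box<dyn Scheduler>, so a sampling loop only depends on the trait methods regardless of whether a DDIM or EulerAncestral scheduler is configured. A hedged sketch of such a loop; the predict_noise closure stands in for the conditioned UNet call and is purely hypothetical, and the crate paths assume the aliases the repo itself uses (candle, candle_transformers).

use candle::{Result, Tensor};
use candle_transformers::models::stable_diffusion::schedulers::Scheduler;
use candle_transformers::models::stable_diffusion::StableDiffusionConfig;

fn sample(
    cfg: &StableDiffusionConfig,
    mut latents: Tensor,
    n_steps: usize,
    predict_noise: impl Fn(&Tensor, usize) -> Result<Tensor>, // hypothetical UNet wrapper
) -> Result<Tensor> {
    let scheduler = cfg.build_scheduler(n_steps)?;
    // Scale the initial noise by the scheduler-specific sigma.
    latents = (latents * scheduler.init_noise_sigma())?;
    for &t in scheduler.timesteps() {
        // Some schedulers (e.g. EulerAncestral) rescale the model input at each step.
        let input = scheduler.scale_model_input(latents.clone(), t)?;
        let noise_pred = predict_noise(&input, t)?;
        latents = scheduler.step(&noise_pred, t, &latents)?;
    }
    Ok(latents)
}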
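The new utils::interp helper performs 1-D linear interpolation over sorted xp/fp points and returns NaN outside the xp range; euler_ancestral_discrete.rs uses it to map the selected timesteps onto the sigma schedule. A small usage sketch (the import path is an assumption based on the module tree shown above):

use candle_transformers::models::stable_diffusion::utils::interp;

fn main() {
    let xp = [0.0, 1.0, 2.0]; // sorted x coordinates
    let fp = [0.0, 10.0, 40.0]; // corresponding y values
    let ys = interp(&[0.5, 1.5, 3.0], &xp, &fp);
    assert_eq!(&ys[..2], &[5.0, 25.0]);
    assert!(ys[2].is_nan()); // 3.0 lies outside [0, 2] -> NaN
}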