Diffstat (limited to 'candle-transformers/src')
-rw-r--r--  candle-transformers/src/models/bert.rs                                      |   6
-rw-r--r--  candle-transformers/src/models/mixformer.rs                                 |  34
-rw-r--r--  candle-transformers/src/models/mixtral.rs                                   | 499
-rw-r--r--  candle-transformers/src/models/mod.rs                                       |   1
-rw-r--r--  candle-transformers/src/models/quantized_llama.rs                           | 171
-rw-r--r--  candle-transformers/src/models/quantized_mixformer.rs                       |  18
-rw-r--r--  candle-transformers/src/models/segment_anything/mask_decoder.rs             |   2
-rw-r--r--  candle-transformers/src/models/segment_anything/prompt_encoder.rs           |   8
-rw-r--r--  candle-transformers/src/models/stable_diffusion/ddim.rs                     |  71
-rw-r--r--  candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs | 235
-rw-r--r--  candle-transformers/src/models/stable_diffusion/mod.rs                      | 118
-rw-r--r--  candle-transformers/src/models/stable_diffusion/schedulers.rs               |  34
-rw-r--r--  candle-transformers/src/models/stable_diffusion/utils.rs                    |  46
13 files changed, 1183 insertions, 60 deletions
diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs
index d6826a16..51c524f5 100644
--- a/candle-transformers/src/models/bert.rs
+++ b/candle-transformers/src/models/bert.rs
@@ -7,8 +7,9 @@ pub const DTYPE: DType = DType::F32;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
-enum HiddenAct {
+pub enum HiddenAct {
Gelu,
+ GeluApproximate,
Relu,
}
@@ -28,6 +29,7 @@ impl HiddenActLayer {
match self.act {
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
HiddenAct::Gelu => xs.gelu_erf(),
+ HiddenAct::GeluApproximate => xs.gelu(),
HiddenAct::Relu => xs.relu(),
}
}
@@ -48,7 +50,7 @@ pub struct Config {
num_hidden_layers: usize,
num_attention_heads: usize,
intermediate_size: usize,
- hidden_act: HiddenAct,
+ pub hidden_act: HiddenAct,
hidden_dropout_prob: f64,
max_position_embeddings: usize,
type_vocab_size: usize,
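
Note on the new variant: `HiddenAct::Gelu` keeps candle's exact erf-based `gelu_erf`, while
`GeluApproximate` maps to the tanh-based `gelu`, mirroring the `gelu`/`gelu_new` split in the
linked transformers activations table. A minimal comparison sketch (illustrative, not part of
this diff):

    use candle::{Device, Result, Tensor};

    fn main() -> Result<()> {
        let xs = Tensor::new(&[-1f32, 0.0, 1.0, 2.0], &Device::Cpu)?;
        // Exact GELU: x/2 * (1 + erf(x / sqrt(2))).
        let exact = xs.gelu_erf()?;
        // Tanh approximation: x/2 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
        let approx = xs.gelu()?;
        println!("exact: {exact}\napprox: {approx}");
        Ok(())
    }
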
diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs
index e822ca14..b0e2fb88 100644
--- a/candle-transformers/src/models/mixformer.rs
+++ b/candle-transformers/src/models/mixformer.rs
@@ -57,6 +57,22 @@ impl Config {
}
}
+ pub fn v2() -> Self {
+ Self {
+ vocab_size: 51200,
+ n_positions: 2048,
+ n_embd: 2560,
+ n_layer: 32,
+ n_inner: None,
+ n_head: 32,
+ rotary_dim: usize::min(32, 2560 / 32),
+ activation_function: Activation::Gelu,
+ layer_norm_epsilon: 1e-5,
+ tie_word_embeddings: false,
+ pad_vocab_size_multiple: 64,
+ }
+ }
+
// https://huggingface.co/teknium/Puffin-Phi-v2/blob/main/config.json
pub fn puffin_phi_v2() -> Self {
Self {
@@ -372,6 +388,24 @@ pub struct MixFormerSequentialForCausalLM {
}
impl MixFormerSequentialForCausalLM {
+ pub fn new_v2(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let vb_head = vb.pp("lm_head");
+ let vb = vb.pp("transformer");
+ let embedding = Embedding::new(cfg, vb.pp("embd"))?;
+ let mut blocks = Vec::new();
+ for i in 0..cfg.n_layer {
+ let block = ParallelBlock::new(cfg, vb.pp("h").pp(i))?;
+ blocks.push(block)
+ }
+ let head = CausalLMHead::new(cfg, vb_head)?;
+ Ok(Self {
+ embedding,
+ blocks,
+ head,
+ span: tracing::span!(tracing::Level::TRACE, "mixformer"),
+ })
+ }
+
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb = vb.pp("layers");
let embedding = Embedding::new(cfg, vb.pp(0))?;
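
`Config::v2` carries the phi-2 hyper-parameters (2560 hidden dims, 32 layers), and `new_v2`
reads the `transformer.*`/`lm_head` weight layout instead of the flat `layers.{i}` indexing
used by `new`. A hedged loading sketch (the file path, dtype, and helper name are assumptions):

    use candle::{DType, Device, Result};
    use candle_nn::VarBuilder;
    use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM};

    fn load_v2(weights: &std::path::Path, dev: &Device) -> Result<MixFormerSequentialForCausalLM> {
        let cfg = Config::v2();
        // Memory-mapping the safetensors file is why this constructor is unsafe.
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights], DType::F32, dev)? };
        MixFormerSequentialForCausalLM::new_v2(&cfg, vb)
    }
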
diff --git a/candle-transformers/src/models/mixtral.rs b/candle-transformers/src/models/mixtral.rs
new file mode 100644
index 00000000..ede74d3f
--- /dev/null
+++ b/candle-transformers/src/models/mixtral.rs
@@ -0,0 +1,499 @@
+use crate::models::with_tracing::{linear_no_bias, Linear};
+/// Mixtral Model
+/// https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
+/// https://mistral.ai/news/mixtral-of-experts/
+use candle::{DType, Device, Module, Result, Tensor, D};
+use candle_nn::{Activation, VarBuilder};
+use serde::Deserialize;
+use std::sync::Arc;
+
+/// https://github.com/huggingface/transformers/blob/1a585c1222a56bcaecc070966d558d4a9d862e83/src/transformers/models/mixtral/configuration_mixtral.py#L113
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+pub struct Config {
+ pub(crate) vocab_size: usize,
+ pub(crate) hidden_size: usize,
+ pub(crate) intermediate_size: usize,
+ pub(crate) num_hidden_layers: usize,
+ pub(crate) num_attention_heads: usize,
+ pub(crate) num_key_value_heads: usize,
+ pub(crate) hidden_act: Activation,
+ pub(crate) max_position_embeddings: usize,
+ pub(crate) rms_norm_eps: f64,
+ pub(crate) rope_theta: f64,
+ pub(crate) sliding_window: usize,
+ pub(crate) num_experts_per_tok: usize,
+ pub(crate) num_local_experts: usize,
+ pub(crate) use_flash_attn: bool,
+}
+
+impl Config {
+ /// https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/blob/main/config.json
+ pub fn v0_1_8x7b(use_flash_attn: bool) -> Self {
+ Self {
+ vocab_size: 32000,
+ hidden_size: 4096,
+ intermediate_size: 14336,
+ num_hidden_layers: 32,
+ num_attention_heads: 32,
+ num_key_value_heads: 8,
+ hidden_act: Activation::Silu,
+ max_position_embeddings: 32768,
+ rms_norm_eps: 1e-5,
+ rope_theta: 1e6,
+ sliding_window: 4096,
+ num_experts_per_tok: 2,
+ num_local_experts: 8,
+ use_flash_attn,
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+struct RmsNorm {
+ inner: candle_nn::RmsNorm,
+ span: tracing::Span,
+}
+
+impl RmsNorm {
+ fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
+ let inner = candle_nn::rms_norm(size, eps, vb)?;
+ Ok(Self { inner, span })
+ }
+}
+
+impl Module for RmsNorm {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ self.inner.forward(x)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct RotaryEmbedding {
+ sin: Tensor,
+ cos: Tensor,
+}
+
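+/// Concatenates the negated second half of the last dimension with the first
+/// half: the standard `rotate_half` helper from the rotary-embedding formulation.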
+fn rotate_half(xs: &Tensor) -> Result<Tensor> {
+ let last_dim = xs.dim(D::Minus1)?;
+ let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
+ let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
+ Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
+}
+
+impl RotaryEmbedding {
+ fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
+ let dim = cfg.hidden_size / cfg.num_attention_heads;
+ let max_seq_len = cfg.max_position_embeddings;
+ let inv_freq: Vec<_> = (0..dim)
+ .step_by(2)
+ .map(|i| 1f32 / (cfg.rope_theta as f32).powf(i as f32 / dim as f32))
+ .collect();
+ let inv_freq_len = inv_freq.len();
+ let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
+ let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
+ .to_dtype(dtype)?
+ .reshape((max_seq_len, 1))?;
+ let freqs = t.matmul(&inv_freq)?;
+ let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
+ Ok(Self {
+ sin: freqs.sin()?,
+ cos: freqs.cos()?,
+ })
+ }
+
+ fn apply_rotary_emb_qkv(
+ &self,
+ q: &Tensor,
+ k: &Tensor,
+ seqlen_offset: usize,
+ ) -> Result<(Tensor, Tensor)> {
+ let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
+ let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
+ let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
+ let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
+ let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
+ let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
+ let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
+ Ok((q_embed, k_embed))
+ }
+}
+
+#[cfg(feature = "flash-attn")]
+fn flash_attn(
+ q: &Tensor,
+ k: &Tensor,
+ v: &Tensor,
+ softmax_scale: f32,
+ causal: bool,
+) -> Result<Tensor> {
+ candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
+}
+
+#[cfg(not(feature = "flash-attn"))]
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
+ unimplemented!("compile with '--features flash-attn'")
+}
+
+#[derive(Debug, Clone)]
+struct Attention {
+ q_proj: Linear,
+ k_proj: Linear,
+ v_proj: Linear,
+ o_proj: Linear,
+ num_heads: usize,
+ num_kv_heads: usize,
+ num_kv_groups: usize,
+ head_dim: usize,
+ hidden_size: usize,
+ rotary_emb: Arc<RotaryEmbedding>,
+ kv_cache: Option<(Tensor, Tensor)>,
+ use_flash_attn: bool,
+}
+
+impl Attention {
+ fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let hidden_sz = cfg.hidden_size;
+ let num_heads = cfg.num_attention_heads;
+ let num_kv_heads = cfg.num_key_value_heads;
+ let num_kv_groups = num_heads / num_kv_heads;
+ let head_dim = hidden_sz / num_heads;
+ let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
+ let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
+ let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
+ let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
+ Ok(Self {
+ q_proj,
+ k_proj,
+ v_proj,
+ o_proj,
+ num_heads,
+ num_kv_heads,
+ num_kv_groups,
+ head_dim,
+ hidden_size: hidden_sz,
+ rotary_emb,
+ kv_cache: None,
+ use_flash_attn: cfg.use_flash_attn,
+ })
+ }
+
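+ /// Repeats each key/value head `num_kv_groups` times so that grouped-query
+ /// attention can reuse the standard attention matmul shapes.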
+ fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
+ let n_rep = self.num_kv_groups;
+ if n_rep == 1 {
+ Ok(xs)
+ } else {
+ let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
+ xs.unsqueeze(2)?
+ .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
+ .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
+ }
+ }
+
+ fn forward(
+ &mut self,
+ xs: &Tensor,
+ attention_mask: Option<&Tensor>,
+ seqlen_offset: usize,
+ ) -> Result<Tensor> {
+ let (b_sz, q_len, _) = xs.dims3()?;
+
+ let query_states = self.q_proj.forward(xs)?;
+ let key_states = self.k_proj.forward(xs)?;
+ let value_states = self.v_proj.forward(xs)?;
+
+ let query_states = query_states
+ .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
+ .transpose(1, 2)?;
+ let key_states = key_states
+ .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
+ .transpose(1, 2)?;
+ let value_states = value_states
+ .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
+ .transpose(1, 2)?;
+
+ let (query_states, key_states) =
+ self.rotary_emb
+ .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
+
+ let (key_states, value_states) = match &self.kv_cache {
+ None => (key_states, value_states),
+ Some((prev_k, prev_v)) => {
+ let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
+ let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
+ (key_states, value_states)
+ }
+ };
+ self.kv_cache = Some((key_states.clone(), value_states.clone()));
+
+ let key_states = self.repeat_kv(key_states)?;
+ let value_states = self.repeat_kv(value_states)?;
+
+ let attn_output = if self.use_flash_attn {
+ // flash-attn expects (b_sz, seq_len, nheads, head_dim)
+ let q = query_states.transpose(1, 2)?;
+ let k = key_states.transpose(1, 2)?;
+ let v = value_states.transpose(1, 2)?;
+ let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
+ flash_attn(&q, &k, &v, softmax_scale, q_len > 1)?.transpose(1, 2)?
+ } else {
+ let scale = 1f64 / f64::sqrt(self.head_dim as f64);
+ let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
+
+ let attn_weights = match attention_mask {
+ None => attn_weights,
+ Some(mask) => attn_weights.broadcast_add(mask)?,
+ };
+ let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
+ attn_weights.matmul(&value_states)?
+ };
+ attn_output
+ .transpose(1, 2)?
+ .reshape((b_sz, q_len, self.hidden_size))?
+ .apply(&self.o_proj)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct BlockSparseTop2MLP {
+ w1: Linear,
+ w2: Linear,
+ w3: Linear,
+ act_fn: Activation,
+}
+
+impl BlockSparseTop2MLP {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let hidden_sz = cfg.hidden_size;
+ let intermediate_sz = cfg.intermediate_size;
+ let w1 = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("w1"))?;
+ let w2 = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("w2"))?;
+ let w3 = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("w3"))?;
+ Ok(Self {
+ w1,
+ w2,
+ w3,
+ act_fn: cfg.hidden_act,
+ })
+ }
+}
+
+impl Module for BlockSparseTop2MLP {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let lhs = xs.apply(&self.w1)?.apply(&self.act_fn)?;
+ let rhs = xs.apply(&self.w3)?;
+ (lhs * rhs)?.apply(&self.w2)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct SparseMoeBlock {
+ gate: Linear,
+ experts: Vec<BlockSparseTop2MLP>,
+ num_experts_per_tok: usize,
+}
+
+impl SparseMoeBlock {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let gate = linear_no_bias(cfg.hidden_size, cfg.num_local_experts, vb.pp("gate"))?;
+ let mut experts = Vec::with_capacity(cfg.num_local_experts);
+ let vb = vb.pp("experts");
+ for idx in 0..cfg.num_local_experts {
+ let expert = BlockSparseTop2MLP::new(cfg, vb.pp(idx))?;
+ experts.push(expert)
+ }
+ Ok(SparseMoeBlock {
+ gate,
+ experts,
+ num_experts_per_tok: cfg.num_experts_per_tok,
+ })
+ }
+}
+
+impl Module for SparseMoeBlock {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let (b_size, seq_len, hidden_dim) = xs.dims3()?;
+ let xs = xs.reshape(((), hidden_dim))?;
+ let router_logits = xs.apply(&self.gate)?;
+ let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;
+
+ // To extract the top-k experts we pull the routing weights out of the tensor and
+ // manipulate them directly on the CPU. Maybe we will want to use some custom ops
+ // instead at some point.
+ let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
+
+ // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+ // top_x contains the row indexes to evaluate for each expert.
+ let mut top_x = vec![vec![]; self.experts.len()];
+ let mut selected_rws = vec![vec![]; self.experts.len()];
+ for (row_idx, rw) in routing_weights.iter().enumerate() {
+ let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
+ dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
+ let mut sum_routing_weights = 0f32;
+ for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
+ let expert_idx = expert_idx as usize;
+ let routing_weight = rw[expert_idx];
+ sum_routing_weights += routing_weight;
+ top_x[expert_idx].push(row_idx as u32);
+ }
+ for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
+ let expert_idx = expert_idx as usize;
+ let routing_weight = rw[expert_idx];
+ selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
+ }
+ }
+
+ // routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+ // expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+ let mut ys = xs.zeros_like()?;
+ for (expert_idx, expert_layer) in self.experts.iter().enumerate() {
+ let top_x = &top_x[expert_idx];
+ if top_x.is_empty() {
+ continue;
+ }
+ let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
+ let selected_rws =
+ Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?.reshape(((), 1))?;
+ // Index the correct hidden states and compute the expert hidden state for
+ // the current expert. We need to make sure to multiply the output hidden
+ // states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+ let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
+ // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
+ let current_hidden_states = expert_layer.forward(&current_state)?;
+ let current_hidden_states = current_hidden_states.broadcast_mul(&selected_rws)?;
+ ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
+ }
+
+ let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
+ Ok(ys)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct DecoderLayer {
+ self_attn: Attention,
+ block_sparse_moe: SparseMoeBlock,
+ input_layernorm: RmsNorm,
+ post_attention_layernorm: RmsNorm,
+}
+
+impl DecoderLayer {
+ fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
+ let block_sparse_moe = SparseMoeBlock::new(cfg, vb.pp("block_sparse_moe"))?;
+ let input_layernorm =
+ RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
+ let post_attention_layernorm = RmsNorm::new(
+ cfg.hidden_size,
+ cfg.rms_norm_eps,
+ vb.pp("post_attention_layernorm"),
+ )?;
+ Ok(Self {
+ self_attn,
+ block_sparse_moe,
+ input_layernorm,
+ post_attention_layernorm,
+ })
+ }
+
+ fn forward(
+ &mut self,
+ xs: &Tensor,
+ attention_mask: Option<&Tensor>,
+ seqlen_offset: usize,
+ ) -> Result<Tensor> {
+ let residual = xs;
+ let xs = self.input_layernorm.forward(xs)?;
+ let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
+ let xs = (xs + residual)?;
+ let residual = &xs;
+ let xs = xs
+ .apply(&self.post_attention_layernorm)?
+ .apply(&self.block_sparse_moe)?;
+ residual + xs
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct Model {
+ embed_tokens: candle_nn::Embedding,
+ layers: Vec<DecoderLayer>,
+ norm: RmsNorm,
+ lm_head: Linear,
+ sliding_window: usize,
+ device: Device,
+ dtype: DType,
+}
+
+impl Model {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let vb_m = vb.pp("model");
+ let embed_tokens =
+ candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
+ let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
+ let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
+ let vb_l = vb_m.pp("layers");
+ for layer_idx in 0..cfg.num_hidden_layers {
+ let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
+ layers.push(layer)
+ }
+ let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
+ let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
+ Ok(Self {
+ embed_tokens,
+ layers,
+ norm,
+ lm_head,
+ sliding_window: cfg.sliding_window,
+ device: vb.device().clone(),
+ dtype: vb.dtype(),
+ })
+ }
+
+ fn prepare_decoder_attention_mask(
+ &self,
+ b_size: usize,
+ tgt_len: usize,
+ seqlen_offset: usize,
+ ) -> Result<Tensor> {
+ // Causal mask restricted to the sliding window: position i may attend to j only
+ // when i - sliding_window <= j <= i.
+ let mask: Vec<_> = (0..tgt_len)
+ .flat_map(|i| {
+ (0..tgt_len).map(move |j| {
+ if i < j || j + self.sliding_window < i {
+ f32::NEG_INFINITY
+ } else {
+ 0.
+ }
+ })
+ })
+ .collect();
+ let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
+ let mask = if seqlen_offset > 0 {
+ let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
+ Tensor::cat(&[&mask0, &mask], D::Minus1)?
+ } else {
+ mask
+ };
+ mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
+ .to_dtype(self.dtype)
+ }
+
+ pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
+ let (b_size, seq_len) = input_ids.dims2()?;
+ let attention_mask = if seq_len <= 1 {
+ None
+ } else {
+ let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
+ Some(mask)
+ };
+ let mut xs = self.embed_tokens.forward(input_ids)?;
+ for layer in self.layers.iter_mut() {
+ xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
+ }
+ xs.narrow(1, seq_len - 1, 1)?
+ .apply(&self.norm)?
+ .apply(&self.lm_head)
+ }
+}
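
The expert routing in `SparseMoeBlock::forward` is plain top-k selection followed by
renormalization of the selected softmax weights. A standalone sketch of that per-token step
(values are illustrative):

    fn top_k_route(rw: &[f32], k: usize) -> Vec<(usize, f32)> {
        // Sort expert indices by routing weight, descending.
        let mut idx: Vec<usize> = (0..rw.len()).collect();
        idx.sort_by(|&i, &j| rw[j].total_cmp(&rw[i]));
        let top = &idx[..k];
        // Renormalize the selected weights so they sum to one.
        let sum: f32 = top.iter().map(|&i| rw[i]).sum();
        top.iter().map(|&i| (i, rw[i] / sum)).collect()
    }

    fn main() {
        // Softmax output over 8 experts for one token (made-up values).
        let rw = [0.02, 0.40, 0.05, 0.25, 0.08, 0.10, 0.06, 0.04];
        // Top-2 routing as in Mixtral: experts 1 and 3, renormalized to ~0.615 and ~0.385.
        println!("{:?}", top_k_route(&rw, 2));
    }
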
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index a9a56673..94a3bd5b 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -14,6 +14,7 @@ pub mod llama2_c_weights;
pub mod marian;
pub mod mistral;
pub mod mixformer;
+pub mod mixtral;
pub mod mpt;
pub mod persimmon;
pub mod quantized_blip;
diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs
index 44d89f40..1fb2d9e2 100644
--- a/candle-transformers/src/models/quantized_llama.rs
+++ b/candle-transformers/src/models/quantized_llama.rs
@@ -48,15 +48,109 @@ impl QMatMul {
}
#[derive(Debug, Clone)]
+struct Mlp {
+ feed_forward_w1: QMatMul,
+ feed_forward_w2: QMatMul,
+ feed_forward_w3: QMatMul,
+}
+
+impl Module for Mlp {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let w1 = self.feed_forward_w1.forward(xs)?;
+ let w3 = self.feed_forward_w3.forward(xs)?;
+ self.feed_forward_w2
+ .forward(&(candle_nn::ops::silu(&w1)? * w3)?)
+ }
+}
+
+#[derive(Debug, Clone)]
+enum MlpOrMoe {
+ Mlp(Mlp),
+ MoE {
+ n_expert_used: usize,
+ feed_forward_gate_inp: QMatMul,
+ experts: Vec<Mlp>,
+ },
+}
+
+impl Module for MlpOrMoe {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ match self {
+ Self::MoE {
+ feed_forward_gate_inp,
+ experts,
+ n_expert_used,
+ } => {
+ let (b_size, seq_len, hidden_dim) = xs.dims3()?;
+ let xs = xs.reshape(((), hidden_dim))?;
+ let router_logits = feed_forward_gate_inp.forward(&xs)?;
+ let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;
+
+ // To extract the top-k experts we pull the routing weights out of the tensor and
+ // manipulate them directly on the CPU. Maybe we will want to use some custom ops
+ // instead at some point.
+ let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
+
+ // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+ // top_x contains the row indexes to evaluate for each expert.
+ let mut top_x = vec![vec![]; experts.len()];
+ let mut selected_rws = vec![vec![]; experts.len()];
+ for (row_idx, rw) in routing_weights.iter().enumerate() {
+ let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
+ dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
+ let mut sum_routing_weights = 0f32;
+ for &expert_idx in dst.iter().take(*n_expert_used) {
+ let expert_idx = expert_idx as usize;
+ let routing_weight = rw[expert_idx];
+ sum_routing_weights += routing_weight;
+ top_x[expert_idx].push(row_idx as u32);
+ }
+ for &expert_idx in dst.iter().take(*n_expert_used) {
+ let expert_idx = expert_idx as usize;
+ let routing_weight = rw[expert_idx];
+ selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
+ }
+ }
+
+ // routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+ // expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+ let mut ys = xs.zeros_like()?;
+ for (expert_idx, expert_layer) in experts.iter().enumerate() {
+ let top_x = &top_x[expert_idx];
+ if top_x.is_empty() {
+ continue;
+ }
+ let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
+ let selected_rws =
+ Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?
+ .reshape(((), 1))?;
+ // Index the correct hidden states and compute the expert hidden state for
+ // the current expert. We need to make sure to multiply the output hidden
+ // states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+ let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
+ // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
+ let current_hidden_states = expert_layer.forward(&current_state)?;
+ let current_hidden_states =
+ current_hidden_states.broadcast_mul(&selected_rws)?;
+ ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
+ }
+
+ let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
+ Ok(ys)
+ }
+ Self::Mlp(mlp) => mlp.forward(xs),
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
struct LayerWeights {
attention_wq: QMatMul,
attention_wk: QMatMul,
attention_wv: QMatMul,
attention_wo: QMatMul,
attention_norm: RmsNorm,
- feed_forward_w1: QMatMul,
- feed_forward_w2: QMatMul,
- feed_forward_w3: QMatMul,
+ mlp_or_moe: MlpOrMoe,
ffn_norm: RmsNorm,
n_head: usize,
n_kv_head: usize,
@@ -212,9 +306,16 @@ impl ModelWeights {
let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
- let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
- let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
- let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
+ let mlp_or_moe = {
+ let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
+ let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
+ let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
+ MlpOrMoe::Mlp(Mlp {
+ feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
+ feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
+ feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
+ })
+ };
let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?;
let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?;
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
@@ -226,9 +327,7 @@ impl ModelWeights {
attention_wv: QMatMul::from_qtensor(attention_wv)?,
attention_wo: QMatMul::from_qtensor(attention_wo)?,
attention_norm: RmsNorm::new(attention_norm, 1e-5)?,
- feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
- feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
- feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
+ mlp_or_moe,
ffn_norm: RmsNorm::new(ffn_norm, 1e-5)?,
n_head: ct.hparams.n_head as usize,
n_kv_head: ct.hparams.n_head as usize / gqa,
@@ -265,6 +364,12 @@ impl ModelWeights {
};
// Parameter extraction from metadata.
+ let n_expert = md_get("llama.expert_count")
+ .and_then(|v| v.to_u32())
+ .unwrap_or(0) as usize;
+ let n_expert_used = md_get("llama.expert_used_count")
+ .and_then(|v| v.to_u32())
+ .unwrap_or(0) as usize;
let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
let block_count = md_get("llama.block_count")?.to_u32()? as usize;
@@ -289,9 +394,38 @@ impl ModelWeights {
let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?;
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?;
let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?;
- let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?;
- let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?;
- let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?;
+ let mlp_or_moe = if n_expert <= 1 {
+ let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?;
+ let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?;
+ let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?;
+ MlpOrMoe::Mlp(Mlp {
+ feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
+ feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
+ feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
+ })
+ } else {
+ let feed_forward_gate_inp =
+ ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"))?;
+ let mut experts = Vec::with_capacity(n_expert);
+ for i in 0..n_expert {
+ let feed_forward_w1 =
+ ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"))?;
+ let feed_forward_w2 =
+ ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"))?;
+ let feed_forward_w3 =
+ ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"))?;
+ experts.push(Mlp {
+ feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
+ feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
+ feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
+ })
+ }
+ MlpOrMoe::MoE {
+ n_expert_used,
+ feed_forward_gate_inp: QMatMul::from_qtensor(feed_forward_gate_inp)?,
+ experts,
+ }
+ };
let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?;
let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?;
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
@@ -303,9 +437,7 @@ impl ModelWeights {
attention_wv: QMatMul::from_qtensor(attention_wv)?,
attention_wo: QMatMul::from_qtensor(attention_wo)?,
attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?,
- feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
- feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
- feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
+ mlp_or_moe,
ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?,
n_head: head_count,
n_kv_head: head_count_kv,
@@ -360,12 +492,9 @@ impl ModelWeights {
let _enter = layer.span_mlp.enter();
let residual = &x;
let x = layer.ffn_norm.forward(&x)?;
- let w1 = layer.feed_forward_w1.forward(&x)?;
- let w3 = layer.feed_forward_w3.forward(&x)?;
- let mlp = layer
- .feed_forward_w2
- .forward(&(candle_nn::ops::silu(&w1)? * w3)?)?;
- layer_in = (mlp + residual)?;
+ let x = layer.mlp_or_moe.forward(&x)?;
+ let x = (x + residual)?;
+ layer_in = x
}
let x = self.norm.forward(&layer_in)?;
let x = x.i((.., seq_len - 1, ..))?;
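
Whether a GGUF checkpoint gets the dense `Mlp` or the `MoE` path is decided solely by the
`llama.expert_count` metadata key. A hedged probe sketch using candle's `gguf_file` reader
(the helper name is an assumption):

    use candle::quantized::gguf_file;

    fn expert_counts(path: &std::path::Path) -> candle::Result<(usize, usize)> {
        let mut file = std::fs::File::open(path)?;
        let content = gguf_file::Content::read(&mut file)?;
        let get = |key: &str| {
            content
                .metadata
                .get(key)
                .and_then(|v| v.to_u32().ok())
                .unwrap_or(0) as usize
        };
        // Zero or one expert means a regular dense feed-forward block.
        Ok((get("llama.expert_count"), get("llama.expert_used_count")))
    }
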
diff --git a/candle-transformers/src/models/quantized_mixformer.rs b/candle-transformers/src/models/quantized_mixformer.rs
index f11f2036..1a3cd4ac 100644
--- a/candle-transformers/src/models/quantized_mixformer.rs
+++ b/candle-transformers/src/models/quantized_mixformer.rs
@@ -287,6 +287,24 @@ pub struct MixFormerSequentialForCausalLM {
}
impl MixFormerSequentialForCausalLM {
+ pub fn new_v2(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let vb_head = vb.pp("lm_head");
+ let vb = vb.pp("transformer");
+ let embedding = Embedding::new(cfg, vb.pp("embd"))?;
+ let mut blocks = Vec::new();
+ for i in 0..cfg.n_layer {
+ let block = ParallelBlock::new(cfg, vb.pp("h").pp(i))?;
+ blocks.push(block)
+ }
+ let head = CausalLMHead::new(cfg, vb_head)?;
+ Ok(Self {
+ embedding,
+ blocks,
+ head,
+ span: tracing::span!(tracing::Level::TRACE, "mixformer"),
+ })
+ }
+
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb = vb.pp("layers");
let embedding = Embedding::new(cfg, vb.pp(0))?;
diff --git a/candle-transformers/src/models/segment_anything/mask_decoder.rs b/candle-transformers/src/models/segment_anything/mask_decoder.rs
index 2a91cd44..1703c809 100644
--- a/candle-transformers/src/models/segment_anything/mask_decoder.rs
+++ b/candle-transformers/src/models/segment_anything/mask_decoder.rs
@@ -182,7 +182,7 @@ impl MaskDecoder {
sparse_prompt_embeddings: &Tensor,
dense_prompt_embeddings: &Tensor,
) -> Result<(Tensor, Tensor)> {
- // Concatenate ouput tokens.
+ // Concatenate output tokens.
let output_tokens = Tensor::cat(
&[self.iou_token.embeddings(), self.mask_tokens.embeddings()],
0,
diff --git a/candle-transformers/src/models/segment_anything/prompt_encoder.rs b/candle-transformers/src/models/segment_anything/prompt_encoder.rs
index 9d0074b1..16e8a4e8 100644
--- a/candle-transformers/src/models/segment_anything/prompt_encoder.rs
+++ b/candle-transformers/src/models/segment_anything/prompt_encoder.rs
@@ -2,11 +2,11 @@ use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::VarBuilder;
#[derive(Debug)]
-struct PostionEmbeddingRandom {
+struct PositionEmbeddingRandom {
positional_encoding_gaussian_matrix: Tensor,
}
-impl PostionEmbeddingRandom {
+impl PositionEmbeddingRandom {
fn new(num_pos_feats: usize, vb: VarBuilder) -> Result<Self> {
let positional_encoding_gaussian_matrix =
vb.get((2, num_pos_feats), "positional_encoding_gaussian_matrix")?;
@@ -52,7 +52,7 @@ impl PostionEmbeddingRandom {
#[derive(Debug)]
pub struct PromptEncoder {
- pe_layer: PostionEmbeddingRandom,
+ pe_layer: PositionEmbeddingRandom,
point_embeddings: Vec<candle_nn::Embedding>,
not_a_point_embed: candle_nn::Embedding,
mask_downscaling_conv1: candle_nn::Conv2d,
@@ -76,7 +76,7 @@ impl PromptEncoder {
vb: VarBuilder,
) -> Result<Self> {
let num_points_embeddings = 4;
- let pe_layer = PostionEmbeddingRandom::new(embed_dim / 2, vb.pp("pe_layer"))?;
+ let pe_layer = PositionEmbeddingRandom::new(embed_dim / 2, vb.pp("pe_layer"))?;
let not_a_point_embed = candle_nn::embedding(1, embed_dim, vb.pp("not_a_point_embed"))?;
let no_mask_embed = candle_nn::embedding(1, embed_dim, vb.pp("no_mask_embed"))?;
let cfg = candle_nn::Conv2dConfig {
diff --git a/candle-transformers/src/models/stable_diffusion/ddim.rs b/candle-transformers/src/models/stable_diffusion/ddim.rs
index 916b7349..d804ed56 100644
--- a/candle-transformers/src/models/stable_diffusion/ddim.rs
+++ b/candle-transformers/src/models/stable_diffusion/ddim.rs
@@ -7,7 +7,9 @@
//!
//! Denoising Diffusion Implicit Models, J. Song et al, 2020.
//! https://arxiv.org/abs/2010.02502
-use super::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType};
+use super::schedulers::{
+ betas_for_alpha_bar, BetaSchedule, PredictionType, Scheduler, SchedulerConfig, TimestepSpacing,
+};
use candle::{Result, Tensor};
/// The configuration for the DDIM scheduler.
@@ -29,6 +31,8 @@ pub struct DDIMSchedulerConfig {
pub prediction_type: PredictionType,
/// number of diffusion steps used to train the model
pub train_timesteps: usize,
+ /// time step spacing for the diffusion process
+ pub timestep_spacing: TimestepSpacing,
}
impl Default for DDIMSchedulerConfig {
@@ -41,10 +45,17 @@ impl Default for DDIMSchedulerConfig {
steps_offset: 1,
prediction_type: PredictionType::Epsilon,
train_timesteps: 1000,
+ timestep_spacing: TimestepSpacing::Leading,
}
}
}
+impl SchedulerConfig for DDIMSchedulerConfig {
+ fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>> {
+ Ok(Box::new(DDIMScheduler::new(inference_steps, *self)?))
+ }
+}
+
/// The DDIM scheduler.
#[derive(Debug, Clone)]
pub struct DDIMScheduler {
@@ -60,12 +71,32 @@ impl DDIMScheduler {
/// Creates a new DDIM scheduler given the number of steps to be
/// used for inference as well as the number of steps that was used
/// during training.
- pub fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result<Self> {
+ fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result<Self> {
let step_ratio = config.train_timesteps / inference_steps;
- let timesteps: Vec<usize> = (0..(inference_steps))
- .map(|s| s * step_ratio + config.steps_offset)
- .rev()
- .collect();
+ let timesteps: Vec<usize> = match config.timestep_spacing {
+ TimestepSpacing::Leading => (0..(inference_steps))
+ .map(|s| s * step_ratio + config.steps_offset)
+ .rev()
+ .collect(),
+ TimestepSpacing::Trailing => std::iter::successors(Some(config.train_timesteps), |n| {
+ if *n > step_ratio {
+ Some(n - step_ratio)
+ } else {
+ None
+ }
+ })
+ .map(|n| n - 1)
+ .collect(),
+ TimestepSpacing::Linspace => {
+ super::utils::linspace(0.0, (config.train_timesteps - 1) as f64, inference_steps)?
+ .to_vec1::<f64>()?
+ .iter()
+ .map(|&f| f as usize)
+ .rev()
+ .collect()
+ }
+ };
+
let betas = match config.beta_schedule {
BetaSchedule::ScaledLinear => super::utils::linspace(
config.beta_start.sqrt(),
@@ -92,19 +123,11 @@ impl DDIMScheduler {
config,
})
}
+}
- pub fn timesteps(&self) -> &[usize] {
- self.timesteps.as_slice()
- }
-
- /// Ensures interchangeability with schedulers that need to scale the denoising model input
- /// depending on the current timestep.
- pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> {
- Ok(sample)
- }
-
+impl Scheduler for DDIMScheduler {
/// Performs a backward step during inference.
- pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
+ fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
let timestep = if timestep >= self.alphas_cumprod.len() {
timestep - 1
} else {
@@ -163,7 +186,17 @@ impl DDIMScheduler {
}
}
- pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
+ /// Ensures interchangeability with schedulers that need to scale the denoising model input
+ /// depending on the current timestep.
+ fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> {
+ Ok(sample)
+ }
+
+ fn timesteps(&self) -> &[usize] {
+ self.timesteps.as_slice()
+ }
+
+ fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
let timestep = if timestep >= self.alphas_cumprod.len() {
timestep - 1
} else {
@@ -174,7 +207,7 @@ impl DDIMScheduler {
(original * sqrt_alpha_prod)? + (noise * sqrt_one_minus_alpha_prod)?
}
- pub fn init_noise_sigma(&self) -> f64 {
+ fn init_noise_sigma(&self) -> f64 {
self.init_noise_sigma
}
}
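
The `TimestepSpacing` strategies only differ in how the training schedule is subsampled. A
plain sketch of `Leading` vs `Trailing` for 10 inference steps, mirroring the iterator logic
above:

    fn leading(train: usize, steps: usize, offset: usize) -> Vec<usize> {
        let ratio = train / steps;
        (0..steps).map(|s| s * ratio + offset).rev().collect()
    }

    fn trailing(train: usize, steps: usize) -> Vec<usize> {
        let ratio = train / steps;
        std::iter::successors(Some(train), |n| (*n > ratio).then(|| n - ratio))
            .map(|n| n - 1)
            .collect()
    }

    fn main() {
        // Leading starts low: [901, 801, ..., 101, 1].
        println!("{:?}", leading(1000, 10, 1));
        // Trailing ends at the final training step: [999, 899, ..., 199, 99].
        println!("{:?}", trailing(1000, 10));
    }
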
diff --git a/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs b/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs
new file mode 100644
index 00000000..9576c2de
--- /dev/null
+++ b/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs
@@ -0,0 +1,235 @@
+//! Ancestral sampling with Euler method steps.
+//!
+//! Reference implementation in Rust:
+//!
+//! https://github.com/pykeio/diffusers/blob/250b9ad1898af41e76a74c0d8d4292652823338a/src/schedulers/euler_ancestral_discrete.rs
+//!
+//! Based on the original [`k-diffusion` implementation by Katherine Crowson][kd].
+///
+/// [kd]: https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72
+use super::{
+ schedulers::{
+ betas_for_alpha_bar, BetaSchedule, PredictionType, Scheduler, SchedulerConfig,
+ TimestepSpacing,
+ },
+ utils::interp,
+};
+use candle::{bail, Error, Result, Tensor};
+
+/// The configuration for the EulerAncestral Discrete scheduler.
+#[derive(Debug, Clone, Copy)]
+pub struct EulerAncestralDiscreteSchedulerConfig {
+ /// The value of beta at the beginning of training.
+ pub beta_start: f64,
+ /// The value of beta at the end of training.
+ pub beta_end: f64,
+ /// How beta evolved during training.
+ pub beta_schedule: BetaSchedule,
+ /// Adjust the indexes of the inference schedule by this value.
+ pub steps_offset: usize,
+ /// prediction type of the scheduler function, one of `epsilon` (predicting
+ /// the noise of the diffusion process), `sample` (directly predicting the noisy sample)
+ /// or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf)
+ pub prediction_type: PredictionType,
+ /// number of diffusion steps used to train the model
+ pub train_timesteps: usize,
+ /// time step spacing for the diffusion process
+ pub timestep_spacing: TimestepSpacing,
+}
+
+impl Default for EulerAncestralDiscreteSchedulerConfig {
+ fn default() -> Self {
+ Self {
+ beta_start: 0.00085f64,
+ beta_end: 0.012f64,
+ beta_schedule: BetaSchedule::ScaledLinear,
+ steps_offset: 1,
+ prediction_type: PredictionType::Epsilon,
+ train_timesteps: 1000,
+ timestep_spacing: TimestepSpacing::Leading,
+ }
+ }
+}
+
+impl SchedulerConfig for EulerAncestralDiscreteSchedulerConfig {
+ fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>> {
+ Ok(Box::new(EulerAncestralDiscreteScheduler::new(
+ inference_steps,
+ *self,
+ )?))
+ }
+}
+
+/// The EulerAncestral Discrete scheduler.
+#[derive(Debug, Clone)]
+pub struct EulerAncestralDiscreteScheduler {
+ timesteps: Vec<usize>,
+ sigmas: Vec<f64>,
+ init_noise_sigma: f64,
+ pub config: EulerAncestralDiscreteSchedulerConfig,
+}
+
+// clip_sample: False, set_alpha_to_one: False
+impl EulerAncestralDiscreteScheduler {
+ /// Creates a new EulerAncestral Discrete scheduler given the number of steps to be
+ /// used for inference as well as the number of steps that was used
+ /// during training.
+ pub fn new(
+ inference_steps: usize,
+ config: EulerAncestralDiscreteSchedulerConfig,
+ ) -> Result<Self> {
+ let step_ratio = config.train_timesteps / inference_steps;
+ let timesteps: Vec<usize> = match config.timestep_spacing {
+ TimestepSpacing::Leading => (0..(inference_steps))
+ .map(|s| s * step_ratio + config.steps_offset)
+ .rev()
+ .collect(),
+ TimestepSpacing::Trailing => std::iter::successors(Some(config.train_timesteps), |n| {
+ if *n > step_ratio {
+ Some(n - step_ratio)
+ } else {
+ None
+ }
+ })
+ .map(|n| n - 1)
+ .collect(),
+ TimestepSpacing::Linspace => {
+ super::utils::linspace(0.0, (config.train_timesteps - 1) as f64, inference_steps)?
+ .to_vec1::<f64>()?
+ .iter()
+ .map(|&f| f as usize)
+ .rev()
+ .collect()
+ }
+ };
+
+ let betas = match config.beta_schedule {
+ BetaSchedule::ScaledLinear => super::utils::linspace(
+ config.beta_start.sqrt(),
+ config.beta_end.sqrt(),
+ config.train_timesteps,
+ )?
+ .sqr()?,
+ BetaSchedule::Linear => {
+ super::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)?
+ }
+ BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?,
+ };
+ let betas = betas.to_vec1::<f64>()?;
+ let mut alphas_cumprod = Vec::with_capacity(betas.len());
+ for &beta in betas.iter() {
+ let alpha = 1.0 - beta;
+ alphas_cumprod.push(alpha * *alphas_cumprod.last().unwrap_or(&1f64))
+ }
+ let sigmas: Vec<f64> = alphas_cumprod
+ .iter()
+ .map(|&f| ((1. - f) / f).sqrt())
+ .collect();
+
+ let sigmas_xa: Vec<_> = (0..sigmas.len()).map(|i| i as f64).collect();
+
+ let mut sigmas_int = interp(
+ &timesteps.iter().map(|&t| t as f64).collect::<Vec<_>>(),
+ &sigmas_xa,
+ &sigmas,
+ );
+ sigmas_int.push(0.0);
+
+ // standard deviation of the initial noise distribution
+ // f64 does not implement Ord, so iterators of floats have no `max`; we fold with a
+ // comparison as a workaround
+ let init_noise_sigma = *sigmas_int
+ .iter()
+ .chain(std::iter::once(&0.0))
+ .reduce(|a, b| if a > b { a } else { b })
+ .expect("init_noise_sigma could not be reduced from sigmas - this should never happen");
+
+ Ok(Self {
+ sigmas: sigmas_int,
+ timesteps,
+ init_noise_sigma,
+ config,
+ })
+ }
+}
+
+impl Scheduler for EulerAncestralDiscreteScheduler {
+ fn timesteps(&self) -> &[usize] {
+ self.timesteps.as_slice()
+ }
+
+ /// Ensures interchangeability with schedulers that need to scale the denoising model input
+ /// depending on the current timestep.
+ ///
+ /// Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm
+ fn scale_model_input(&self, sample: Tensor, timestep: usize) -> Result<Tensor> {
+ let step_index = match self.timesteps.iter().position(|&t| t == timestep) {
+ Some(i) => i,
+ None => bail!("timestep out of this scheduler's bounds: {timestep}"),
+ };
+
+ let sigma = self
+ .sigmas
+ .get(step_index)
+ .expect("step_index out of sigma bounds - this shouldn't happen");
+
+ sample / ((sigma.powi(2) + 1.).sqrt())
+ }
+
+ /// Performs a backward step during inference.
+ fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
+ let step_index = self
+ .timesteps
+ .iter()
+ .position(|&p| p == timestep)
+ .ok_or_else(|| Error::Msg("timestep out of this scheduler's bounds".to_string()))?;
+
+ let sigma_from = &self.sigmas[step_index];
+ let sigma_to = &self.sigmas[step_index + 1];
+
+ // 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
+ let pred_original_sample = match self.config.prediction_type {
+ PredictionType::Epsilon => (sample - (model_output * *sigma_from))?,
+ PredictionType::VPrediction => {
+ ((model_output * (-sigma_from / (sigma_from.powi(2) + 1.0).sqrt()))?
+ + (sample / (sigma_from.powi(2) + 1.0))?)?
+ }
+ PredictionType::Sample => bail!("prediction_type not implemented yet: sample"),
+ };
+
+ let sigma_up = (sigma_to.powi(2) * (sigma_from.powi(2) - sigma_to.powi(2))
+ / sigma_from.powi(2))
+ .sqrt();
+ let sigma_down = (sigma_to.powi(2) - sigma_up.powi(2)).sqrt();
+
+ // 2. convert to a ODE derivative
+ let derivative = ((sample - pred_original_sample)? / *sigma_from)?;
+ let dt = sigma_down - *sigma_from;
+ let prev_sample = (sample + derivative * dt)?;
+
+ let noise = prev_sample.randn_like(0.0, 1.0)?;
+
+ prev_sample + noise * sigma_up
+ }
+
+ fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
+ let step_index = self
+ .timesteps
+ .iter()
+ .position(|&p| p == timestep)
+ .ok_or_else(|| Error::Msg("timestep out of this scheduler's bounds".to_string()))?;
+
+ let sigma = self
+ .sigmas
+ .get(step_index)
+ .expect("step_index out of sigma bounds - this shouldn't happen");
+
+ original + (noise * *sigma)?
+ }
+
+ fn init_noise_sigma(&self) -> f64 {
+ match self.config.timestep_spacing {
+ TimestepSpacing::Trailing | TimestepSpacing::Linspace => self.init_noise_sigma,
+ TimestepSpacing::Leading => (self.init_noise_sigma.powi(2) + 1.0).sqrt(),
+ }
+ }
+}
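
In `step`, `sigma_up` and `sigma_down` are chosen so that `sigma_down^2 + sigma_up^2 =
sigma_to^2`: the deterministic Euler move down to `sigma_down` plus fresh noise at scale
`sigma_up` lands exactly on the target noise level. A quick numeric check (sigma values are
made up):

    fn main() {
        let (sigma_from, sigma_to) = (14.6f64, 9.7f64);
        let sigma_up = (sigma_to.powi(2) * (sigma_from.powi(2) - sigma_to.powi(2))
            / sigma_from.powi(2))
        .sqrt();
        let sigma_down = (sigma_to.powi(2) - sigma_up.powi(2)).sqrt();
        // The squares recombine to sigma_to^2 up to floating-point rounding.
        assert!((sigma_down.powi(2) + sigma_up.powi(2) - sigma_to.powi(2)).abs() < 1e-9);
        println!("sigma_up = {sigma_up:.4}, sigma_down = {sigma_down:.4}");
    }
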
diff --git a/candle-transformers/src/models/stable_diffusion/mod.rs b/candle-transformers/src/models/stable_diffusion/mod.rs
index 66ef7149..30f23975 100644
--- a/candle-transformers/src/models/stable_diffusion/mod.rs
+++ b/candle-transformers/src/models/stable_diffusion/mod.rs
@@ -3,6 +3,7 @@ pub mod clip;
pub mod ddim;
pub mod ddpm;
pub mod embeddings;
+pub mod euler_ancestral_discrete;
pub mod resnet;
pub mod schedulers;
pub mod unet_2d;
@@ -10,9 +11,13 @@ pub mod unet_2d_blocks;
pub mod utils;
pub mod vae;
+use std::sync::Arc;
+
use candle::{DType, Device, Result};
use candle_nn as nn;
+use self::schedulers::{Scheduler, SchedulerConfig};
+
#[derive(Clone, Debug)]
pub struct StableDiffusionConfig {
pub width: usize,
@@ -21,7 +26,7 @@ pub struct StableDiffusionConfig {
pub clip2: Option<clip::Config>,
autoencoder: vae::AutoEncoderKLConfig,
unet: unet_2d::UNet2DConditionModelConfig,
- scheduler: ddim::DDIMSchedulerConfig,
+ scheduler: Arc<dyn SchedulerConfig>,
}
impl StableDiffusionConfig {
@@ -75,13 +80,18 @@ impl StableDiffusionConfig {
512
};
- Self {
+ let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
+ prediction_type: schedulers::PredictionType::Epsilon,
+ ..Default::default()
+ });
+
+ StableDiffusionConfig {
width,
height,
clip: clip::Config::v1_5(),
clip2: None,
autoencoder,
- scheduler: Default::default(),
+ scheduler,
unet,
}
}
@@ -124,10 +134,10 @@ impl StableDiffusionConfig {
latent_channels: 4,
norm_num_groups: 32,
};
- let scheduler = ddim::DDIMSchedulerConfig {
+ let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
prediction_type,
..Default::default()
- };
+ });
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "height has to be divisible by 8");
@@ -143,7 +153,7 @@ impl StableDiffusionConfig {
768
};
- Self {
+ StableDiffusionConfig {
width,
height,
clip: clip::Config::v2_1(),
@@ -205,10 +215,10 @@ impl StableDiffusionConfig {
latent_channels: 4,
norm_num_groups: 32,
};
- let scheduler = ddim::DDIMSchedulerConfig {
+ let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
prediction_type,
..Default::default()
- };
+ });
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "height has to be divisible by 8");
@@ -224,6 +234,76 @@ impl StableDiffusionConfig {
1024
};
+ StableDiffusionConfig {
+ width,
+ height,
+ clip: clip::Config::sdxl(),
+ clip2: Some(clip::Config::sdxl2()),
+ autoencoder,
+ scheduler,
+ unet,
+ }
+ }
+
+ fn sdxl_turbo_(
+ sliced_attention_size: Option<usize>,
+ height: Option<usize>,
+ width: Option<usize>,
+ prediction_type: schedulers::PredictionType,
+ ) -> Self {
+ let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
+ out_channels,
+ use_cross_attn,
+ attention_head_dim,
+ };
+ // https://huggingface.co/stabilityai/sdxl-turbo/blob/main/unet/config.json
+ let unet = unet_2d::UNet2DConditionModelConfig {
+ blocks: vec![
+ bc(320, None, 5),
+ bc(640, Some(2), 10),
+ bc(1280, Some(10), 20),
+ ],
+ center_input_sample: false,
+ cross_attention_dim: 2048,
+ downsample_padding: 1,
+ flip_sin_to_cos: true,
+ freq_shift: 0.,
+ layers_per_block: 2,
+ mid_block_scale_factor: 1.,
+ norm_eps: 1e-5,
+ norm_num_groups: 32,
+ sliced_attention_size,
+ use_linear_projection: true,
+ };
+ // https://huggingface.co/stabilityai/sdxl-turbo/blob/main/vae/config.json
+ let autoencoder = vae::AutoEncoderKLConfig {
+ block_out_channels: vec![128, 256, 512, 512],
+ layers_per_block: 2,
+ latent_channels: 4,
+ norm_num_groups: 32,
+ };
+ let scheduler = Arc::new(
+ euler_ancestral_discrete::EulerAncestralDiscreteSchedulerConfig {
+ prediction_type,
+ timestep_spacing: schedulers::TimestepSpacing::Trailing,
+ ..Default::default()
+ },
+ );
+
+ let height = if let Some(height) = height {
+ assert_eq!(height % 8, 0, "height has to be divisible by 8");
+ height
+ } else {
+ 512
+ };
+
+ let width = if let Some(width) = width {
+ assert_eq!(width % 8, 0, "width has to be divisible by 8");
+ width
+ } else {
+ 512
+ };
+
Self {
width,
height,
@@ -249,6 +329,20 @@ impl StableDiffusionConfig {
)
}
+ pub fn sdxl_turbo(
+ sliced_attention_size: Option<usize>,
+ height: Option<usize>,
+ width: Option<usize>,
+ ) -> Self {
+ Self::sdxl_turbo_(
+ sliced_attention_size,
+ height,
+ width,
+ // https://huggingface.co/stabilityai/sdxl-turbo/blob/main/scheduler/scheduler_config.json
+ schedulers::PredictionType::Epsilon,
+ )
+ }
+
pub fn ssd1b(
sliced_attention_size: Option<usize>,
height: Option<usize>,
@@ -285,9 +379,9 @@ impl StableDiffusionConfig {
latent_channels: 4,
norm_num_groups: 32,
};
- let scheduler = ddim::DDIMSchedulerConfig {
+ let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
..Default::default()
- };
+ });
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "height has to be divisible by 8");
@@ -347,8 +441,8 @@ impl StableDiffusionConfig {
Ok(unet)
}
- pub fn build_scheduler(&self, n_steps: usize) -> Result<ddim::DDIMScheduler> {
- ddim::DDIMScheduler::new(n_steps, self.scheduler)
+ pub fn build_scheduler(&self, n_steps: usize) -> Result<Box<dyn Scheduler>> {
+ self.scheduler.build(n_steps)
}
}
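
With the scheduler hidden behind `Box<dyn Scheduler>`, callers no longer name a concrete
scheduler type, so the same sampling loop serves DDIM and Euler ancestral alike. A hedged
usage sketch:

    use candle_transformers::models::stable_diffusion::StableDiffusionConfig;

    fn main() -> candle::Result<()> {
        // sdxl-turbo defaults to 512x512 and the Euler ancestral scheduler with trailing spacing.
        let cfg = StableDiffusionConfig::sdxl_turbo(None, None, None);
        let scheduler = cfg.build_scheduler(4)?;
        for &t in scheduler.timesteps() {
            println!("denoising at timestep {t}");
        }
        Ok(())
    }
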
diff --git a/candle-transformers/src/models/stable_diffusion/schedulers.rs b/candle-transformers/src/models/stable_diffusion/schedulers.rs
index 3f6a1d72..0f0441e0 100644
--- a/candle-transformers/src/models/stable_diffusion/schedulers.rs
+++ b/candle-transformers/src/models/stable_diffusion/schedulers.rs
@@ -3,9 +3,25 @@
//!
//! Noise schedulers can be used to set the trade-off between
//! inference speed and quality.
-
use candle::{Result, Tensor};
+pub trait SchedulerConfig: std::fmt::Debug {
+ fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>>;
+}
+
+/// This trait represents a scheduler for the diffusion process.
+pub trait Scheduler {
+ fn timesteps(&self) -> &[usize];
+
+ fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor>;
+
+ fn init_noise_sigma(&self) -> f64;
+
+ fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor>;
+
+ fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor>;
+}
+
/// This represents how beta ranges from its minimum value to the maximum
/// during training.
#[derive(Debug, Clone, Copy)]
@@ -25,6 +41,22 @@ pub enum PredictionType {
Sample,
}
+/// Time step spacing for the diffusion process.
+///
+/// "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+#[derive(Debug, Clone, Copy)]
+pub enum TimestepSpacing {
+ Leading,
+ Linspace,
+ Trailing,
+}
+
+impl Default for TimestepSpacing {
+ fn default() -> Self {
+ Self::Leading
+ }
+}
+
/// Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
/// `(1-beta)` over time from `t = [0,1]`.
///
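
The config/scheduler split is a builder pattern: a `SchedulerConfig` is a small, copyable
description, and `build` turns it into a stateful `Scheduler` for a fixed number of inference
steps. A toy implementation sketch, assuming the two traits above are in scope
(`IdentityConfig`/`IdentityScheduler` are purely illustrative, not part of the crate):

    use candle::{Result, Tensor};

    #[derive(Debug, Clone, Copy)]
    struct IdentityConfig {
        train_timesteps: usize,
    }

    struct IdentityScheduler {
        timesteps: Vec<usize>,
    }

    impl SchedulerConfig for IdentityConfig {
        fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>> {
            // Subsample the training schedule, most-noisy timestep first.
            let ratio = self.train_timesteps / inference_steps;
            let timesteps = (0..inference_steps).map(|s| s * ratio).rev().collect();
            Ok(Box::new(IdentityScheduler { timesteps }))
        }
    }

    impl Scheduler for IdentityScheduler {
        fn timesteps(&self) -> &[usize] {
            &self.timesteps
        }
        fn add_noise(&self, original: &Tensor, noise: Tensor, _t: usize) -> Result<Tensor> {
            original + noise
        }
        fn init_noise_sigma(&self) -> f64 {
            1.0
        }
        fn scale_model_input(&self, sample: Tensor, _t: usize) -> Result<Tensor> {
            Ok(sample)
        }
        fn step(&self, model_output: &Tensor, _t: usize, sample: &Tensor) -> Result<Tensor> {
            sample - model_output
        }
    }
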
diff --git a/candle-transformers/src/models/stable_diffusion/utils.rs b/candle-transformers/src/models/stable_diffusion/utils.rs
index cef06f1c..5b5fa0f7 100644
--- a/candle-transformers/src/models/stable_diffusion/utils.rs
+++ b/candle-transformers/src/models/stable_diffusion/utils.rs
@@ -13,3 +13,49 @@ pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
Tensor::from_vec(vs, steps, &Device::Cpu)
}
}
+
+/// A linear interpolator for a sorted array of x and y values.
+struct LinearInterpolator<'x, 'y> {
+ xp: &'x [f64],
+ fp: &'y [f64],
+ cache: usize,
+}
+
+impl<'x, 'y> LinearInterpolator<'x, 'y> {
+ fn accel_find(&mut self, x: f64) -> usize {
+ let xidx = self.cache;
+ if x < self.xp[xidx] {
+ self.cache = self.xp[0..xidx].partition_point(|o| *o < x);
+ self.cache = self.cache.saturating_sub(1);
+ } else if x >= self.xp[xidx + 1] {
+ self.cache = self.xp[xidx..self.xp.len()].partition_point(|o| *o < x) + xidx;
+ self.cache = self.cache.saturating_sub(1);
+ }
+
+ self.cache
+ }
+
+ fn eval(&mut self, x: f64) -> f64 {
+ if x < self.xp[0] || x > self.xp[self.xp.len() - 1] {
+ return f64::NAN;
+ }
+
+ let idx = self.accel_find(x);
+
+ let x_l = self.xp[idx];
+ let x_h = self.xp[idx + 1];
+ let y_l = self.fp[idx];
+ let y_h = self.fp[idx + 1];
+ let dx = x_h - x_l;
+ if dx > 0.0 {
+ y_l + (x - x_l) / dx * (y_h - y_l)
+ } else {
+ f64::NAN
+ }
+ }
+}
+
+pub fn interp(x: &[f64], xp: &[f64], fp: &[f64]) -> Vec<f64> {
+ let mut interpolator = LinearInterpolator { xp, fp, cache: 0 };
+ x.iter().map(|&x| interpolator.eval(x)).collect()
+}
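
`interp` mirrors `numpy.interp` for a sorted `xp`, except that queries outside the range
return NaN instead of clamping to the edge values. A quick usage check:

    use candle_transformers::models::stable_diffusion::utils::interp;

    fn main() {
        let xp = [0.0, 1.0, 2.0];
        let fp = [0.0, 10.0, 20.0];
        // Linear interpolation between the surrounding knots: [5.0, 15.0].
        println!("{:?}", interp(&[0.5, 1.5], &xp, &fp));
        // Out-of-range queries yield NaN rather than the edge value.
        println!("{:?}", interp(&[-1.0], &xp, &fp));
    }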