summaryrefslogtreecommitdiff
path: root/candle-transformers
diff options
context:
space:
mode:
authorcdoko <190060110+cdoko@users.noreply.github.com>2024-12-03 05:56:01 -0400
committerGitHub <noreply@github.com>2024-12-03 10:56:01 +0100
commit145aa7193c4e658b184f52706574cc9f115e4674 (patch)
tree2c5d0a903894b3e854090cdb97f36bdd5ac3a1c0 /candle-transformers
parent6f715f92564c10426c5565cd30ece25aee8d72ac (diff)
downloadcandle-145aa7193c4e658b184f52706574cc9f115e4674.tar.gz
candle-145aa7193c4e658b184f52706574cc9f115e4674.tar.bz2
candle-145aa7193c4e658b184f52706574cc9f115e4674.zip
Add Nvembed v2 model (#2649)
* Update mod.rs * Create mod.rs * Create decoder.rs * Create model.rs * Create main.rs * Create README.md * Update README.md * Update main.rs * Update and rename decoder.rs to embedding.rs * Update mod.rs * Update model.rs
Diffstat (limited to 'candle-transformers')
-rw-r--r--candle-transformers/src/models/mod.rs1
-rw-r--r--candle-transformers/src/models/nvembed_v2/embedding.rs294
-rw-r--r--candle-transformers/src/models/nvembed_v2/mod.rs18
-rw-r--r--candle-transformers/src/models/nvembed_v2/model.rs233
4 files changed, 546 insertions, 0 deletions
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 571a8861..be1f15c4 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -62,6 +62,7 @@ pub mod mobilenetv4;
pub mod mobileone;
pub mod moondream;
pub mod mpt;
+pub mod nvembed_v2;
pub mod olmo;
pub mod openclip;
pub mod paligemma;
diff --git a/candle-transformers/src/models/nvembed_v2/embedding.rs b/candle-transformers/src/models/nvembed_v2/embedding.rs
new file mode 100644
index 00000000..a52192af
--- /dev/null
+++ b/candle-transformers/src/models/nvembed_v2/embedding.rs
@@ -0,0 +1,294 @@
+/// Mistral LLM, https://github.com/mistralai/mistral-src
+use crate::models::{
+ mistral::Config,
+ with_tracing::{linear_no_bias, Linear, RmsNorm},
+};
+use crate::utils::repeat_kv;
+use candle::{DType, Device, Module, Result, Tensor};
+use candle_nn::{Activation, VarBuilder};
+use std::sync::Arc;
+
+#[derive(Debug, Clone)]
+struct RotaryEmbedding {
+ sin: Tensor,
+ cos: Tensor,
+}
+
+impl RotaryEmbedding {
+ fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
+ let rope_theta = cfg.rope_theta as f32;
+ let dim = cfg.hidden_size / cfg.num_attention_heads;
+ let max_seq_len = cfg.max_position_embeddings;
+ let inv_freq: Vec<_> = (0..dim)
+ .step_by(2)
+ .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
+ .collect();
+ let inv_freq_len = inv_freq.len();
+ let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
+ let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
+ .to_dtype(dtype)?
+ .reshape((max_seq_len, 1))?;
+ let freqs = t.matmul(&inv_freq)?;
+ Ok(Self {
+ sin: freqs.sin()?,
+ cos: freqs.cos()?,
+ })
+ }
+
+ fn apply_rotary_emb_qkv(
+ &self,
+ q: &Tensor,
+ k: &Tensor,
+ seqlen_offset: usize,
+ ) -> Result<(Tensor, Tensor)> {
+ let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
+ let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
+ let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
+ let q_embed = candle_nn::rotary_emb::rope(q, &cos, &sin)?;
+ let k_embed = candle_nn::rotary_emb::rope(k, &cos, &sin)?;
+ Ok((q_embed, k_embed))
+ }
+}
+
+#[derive(Debug, Clone)]
+#[allow(clippy::upper_case_acronyms)]
+struct MLP {
+ gate_proj: Linear,
+ up_proj: Linear,
+ down_proj: Linear,
+ act_fn: Activation,
+}
+
+impl MLP {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let hidden_sz = cfg.hidden_size;
+ let intermediate_sz = cfg.intermediate_size;
+ let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
+ let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
+ let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
+ Ok(Self {
+ gate_proj,
+ up_proj,
+ down_proj,
+ act_fn: cfg.hidden_act,
+ })
+ }
+}
+
+impl Module for MLP {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
+ let rhs = xs.apply(&self.up_proj)?;
+ (lhs * rhs)?.apply(&self.down_proj)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct Attention {
+ q_proj: Linear,
+ k_proj: Linear,
+ v_proj: Linear,
+ o_proj: Linear,
+ num_heads: usize,
+ num_kv_heads: usize,
+ num_kv_groups: usize,
+ head_dim: usize,
+ hidden_size: usize,
+ rotary_emb: Arc<RotaryEmbedding>,
+}
+
+impl Attention {
+ fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let hidden_sz = cfg.hidden_size;
+ let num_heads = cfg.num_attention_heads;
+ let num_kv_heads = cfg.num_key_value_heads;
+ let num_kv_groups = num_heads / num_kv_heads;
+ let head_dim = hidden_sz / num_heads;
+ let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
+ let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
+ let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
+ let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
+ Ok(Self {
+ q_proj,
+ k_proj,
+ v_proj,
+ o_proj,
+ num_heads,
+ num_kv_heads,
+ num_kv_groups,
+ head_dim,
+ hidden_size: hidden_sz,
+ rotary_emb,
+ })
+ }
+
+ fn forward(
+ &mut self,
+ xs: &Tensor,
+ attention_mask: Option<&Tensor>,
+ seqlen_offset: usize,
+ ) -> Result<Tensor> {
+ let (b_sz, q_len, _) = xs.dims3()?;
+
+ let query_states = self.q_proj.forward(xs)?;
+ let key_states = self.k_proj.forward(xs)?;
+ let value_states = self.v_proj.forward(xs)?;
+
+ let query_states = query_states
+ .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
+ .transpose(1, 2)?
+ .contiguous()?;
+
+ let key_states = key_states
+ .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
+ .transpose(1, 2)?
+ .contiguous()?;
+ let value_states = value_states
+ .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
+ .transpose(1, 2)?;
+
+ let (query_states, key_states) =
+ self.rotary_emb
+ .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
+
+ let key_states = repeat_kv(key_states, self.num_kv_groups)?;
+ let value_states = repeat_kv(value_states, self.num_kv_groups)?;
+
+ let scale = 1f64 / f64::sqrt(self.head_dim as f64);
+ let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
+
+ let attn_weights = match attention_mask {
+ None => attn_weights,
+ Some(mask) => attn_weights.broadcast_add(mask)?,
+ };
+ let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
+ let attn_output = attn_weights.matmul(&value_states)?;
+
+ attn_output
+ .transpose(1, 2)?
+ .reshape((b_sz, q_len, self.hidden_size))?
+ .apply(&self.o_proj)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct DecoderLayer {
+ self_attn: Attention,
+ mlp: MLP,
+ input_layernorm: RmsNorm,
+ post_attention_layernorm: RmsNorm,
+}
+
+impl DecoderLayer {
+ fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
+ let mlp = MLP::new(cfg, vb.pp("mlp"))?;
+ let input_layernorm =
+ RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
+ let post_attention_layernorm = RmsNorm::new(
+ cfg.hidden_size,
+ cfg.rms_norm_eps,
+ vb.pp("post_attention_layernorm"),
+ )?;
+ Ok(Self {
+ self_attn,
+ mlp,
+ input_layernorm,
+ post_attention_layernorm,
+ })
+ }
+
+ fn forward(
+ &mut self,
+ xs: &Tensor,
+ attention_mask: Option<&Tensor>,
+ seqlen_offset: usize,
+ ) -> Result<Tensor> {
+ let residual = xs;
+ let xs = self.input_layernorm.forward(xs)?;
+
+ let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
+
+ let xs = (xs + residual)?;
+ let residual = &xs;
+ let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
+ residual + xs
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct Model {
+ embed_tokens: candle_nn::Embedding,
+ layers: Vec<DecoderLayer>,
+ norm: RmsNorm,
+ pub cfg: Config,
+}
+
+impl Model {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let embed_tokens =
+ candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?;
+ let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?);
+ let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
+ let vb_l = vb.pp("layers");
+ for layer_idx in 0..cfg.num_hidden_layers {
+ let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
+ layers.push(layer)
+ }
+ let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("norm"))?;
+ Ok(Self {
+ embed_tokens,
+ layers,
+ norm,
+ cfg: cfg.clone(),
+ })
+ }
+
+ // Attn mask used to mask out padding tokens
+ pub fn forward(
+ &mut self,
+ attn_mask: &Tensor,
+ input_ids: &Tensor,
+ dtype: DType,
+ ) -> Result<Tensor> {
+ let mut xs = self.embed_tokens.forward(input_ids)?;
+
+ // Expand to 4d mask for sdpa
+ let attn_mask = prepare_4d_attention_mask(attn_mask, dtype, None)?;
+
+ for layer in self.layers.iter_mut() {
+ xs = layer.forward(&xs, Some(&attn_mask), 0)?;
+ }
+
+ // Return hiddens instead of logits
+ xs.apply(&self.norm)
+ }
+}
+
+fn prepare_4d_attention_mask(
+ mask: &Tensor,
+ dtype: DType,
+ tgt_len: Option<usize>,
+) -> Result<Tensor> {
+ let bsz = mask.dims()[0];
+ let src_len = mask.dims()[1];
+ let tgt_len = tgt_len.unwrap_or(src_len);
+
+ let expanded_mask = mask
+ .unsqueeze(1)?
+ .unsqueeze(2)?
+ .expand((bsz, 1, tgt_len, src_len))?
+ .to_dtype(dtype)?;
+
+ let inverted_mask = (1.0 - expanded_mask)?;
+
+ (inverted_mask * get_dtype_min_val(dtype))?.to_dtype(dtype)
+}
+
+fn get_dtype_min_val(dtype: DType) -> f64 {
+ match dtype {
+ DType::F32 => f32::MIN as f64,
+ DType::F64 => f64::MIN,
+ _ => panic!("Unsupported data type"),
+ }
+}
diff --git a/candle-transformers/src/models/nvembed_v2/mod.rs b/candle-transformers/src/models/nvembed_v2/mod.rs
new file mode 100644
index 00000000..8a8f7007
--- /dev/null
+++ b/candle-transformers/src/models/nvembed_v2/mod.rs
@@ -0,0 +1,18 @@
+//! NV-Embed-v2
+//!
+//! NV-Embed-v2 is a text embedding model that combines a Mistral decoder with a latent attention mechanism to produce high-quality text embeddings.
+//!
+//! This implementation is based on the [paper](https://arxiv.org/pdf/2405.17428) and [weights](https://huggingface.co/nvidia/NV-Embed-v2)
+//!
+//! # Query-Passage Retrieval Example
+//! ```bash
+//! cargo run --example nvembed_v2 --release
+//! ```
+//!
+//! # Sentence Embedding Example
+//! ```bash
+//! cargo run --example nvembed_v2 --release -- --prompt "Here is a test sentence"
+//! ```
+
+pub mod embedding;
+pub mod model;
diff --git a/candle-transformers/src/models/nvembed_v2/model.rs b/candle-transformers/src/models/nvembed_v2/model.rs
new file mode 100644
index 00000000..73ef776e
--- /dev/null
+++ b/candle-transformers/src/models/nvembed_v2/model.rs
@@ -0,0 +1,233 @@
+use super::embedding::Model as EmbeddingModel;
+use crate::models::{
+ mistral::Config,
+ with_tracing::{layer_norm, linear, linear_no_bias, LayerNorm, Linear},
+};
+use candle::{DType, Device, Result, Tensor, D};
+use candle_nn::{ops::softmax_last_dim, LayerNormConfig, Module, VarBuilder};
+
+// Geglu and feedforward from candle-transformers/src/models/stable_diffusion/attention.rs
+#[derive(Debug)]
+struct GeGlu {
+ proj: Linear,
+ span: tracing::Span,
+}
+
+impl GeGlu {
+ fn new(vs: VarBuilder, dim_in: usize, dim_out: usize) -> Result<Self> {
+ let proj = linear(dim_in, dim_out * 2, vs)?;
+ let span = tracing::span!(tracing::Level::TRACE, "geglu");
+ Ok(Self { proj, span })
+ }
+}
+
+impl Module for GeGlu {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?;
+ &hidden_states_and_gate[0] * hidden_states_and_gate[1].gelu()?
+ }
+}
+
+#[derive(Debug)]
+struct FeedForward {
+ project_in: GeGlu,
+ linear: Linear,
+ span: tracing::Span,
+}
+
+impl FeedForward {
+ fn new(vs: VarBuilder, dim: usize, dim_out: Option<usize>, mult: usize) -> Result<Self> {
+ let inner_dim = dim * mult;
+ let dim_out = dim_out.unwrap_or(dim);
+ let vs = vs.pp("net");
+ let project_in = GeGlu::new(vs.pp("0"), dim, inner_dim)?;
+ let linear = linear(inner_dim, dim_out, vs.pp("2"))?;
+ let span = tracing::span!(tracing::Level::TRACE, "ff");
+ Ok(Self {
+ project_in,
+ linear,
+ span,
+ })
+ }
+}
+
+impl Module for FeedForward {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let xs = self.project_in.forward(xs)?;
+ self.linear.forward(&xs)
+ }
+}
+
+// CrossAttention from candle-transformers/src/models/stable_diffusion/attention.rs
+#[derive(Debug)]
+struct CrossAttention {
+ to_q: Linear,
+ to_kv: Linear,
+ to_out: Linear,
+ heads: usize,
+ scale: f64,
+ span: tracing::Span,
+ span_attn: tracing::Span,
+ span_softmax: tracing::Span,
+}
+
+impl CrossAttention {
+ fn new(
+ vs: VarBuilder,
+ query_dim: usize,
+ context_dim: Option<usize>,
+ heads: usize,
+ dim_head: usize,
+ ) -> Result<Self> {
+ let inner_dim = dim_head * heads;
+ let context_dim = context_dim.unwrap_or(query_dim);
+ let scale = 1.0 / f64::sqrt(dim_head as f64);
+ let to_q = linear_no_bias(query_dim, inner_dim, vs.pp("to_q"))?;
+ let to_kv = linear_no_bias(context_dim, inner_dim * 2, vs.pp("to_kv"))?;
+ let to_out = linear_no_bias(inner_dim, query_dim, vs.pp("to_out"))?;
+ let span = tracing::span!(tracing::Level::TRACE, "xa");
+ let span_attn = tracing::span!(tracing::Level::TRACE, "xa-attn");
+ let span_softmax = tracing::span!(tracing::Level::TRACE, "xa-softmax");
+ Ok(Self {
+ to_q,
+ to_kv,
+ to_out,
+ heads,
+ scale,
+ span,
+ span_attn,
+ span_softmax,
+ })
+ }
+
+ fn reshape_heads_to_batch_dim(&self, xs: &Tensor) -> Result<Tensor> {
+ let (batch_size, seq_len, dim) = xs.dims3()?;
+ xs.reshape((batch_size, seq_len, self.heads, dim / self.heads))?
+ .transpose(1, 2)?
+ .reshape((batch_size * self.heads, seq_len, dim / self.heads))
+ }
+
+ fn reshape_batch_dim_to_heads(&self, xs: &Tensor) -> Result<Tensor> {
+ let (batch_size, seq_len, dim) = xs.dims3()?;
+ xs.reshape((batch_size / self.heads, self.heads, seq_len, dim))?
+ .transpose(1, 2)?
+ .reshape((batch_size / self.heads, seq_len, dim * self.heads))
+ }
+
+ fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> {
+ let _enter = self.span_attn.enter();
+
+ let in_dtype = query.dtype();
+ let query = query.to_dtype(DType::F32)?;
+ let key = key.to_dtype(DType::F32)?;
+ let value = value.to_dtype(DType::F32)?;
+ let xs = query.matmul(&(key.t()? * self.scale)?)?;
+ let xs = {
+ let _enter = self.span_softmax.enter();
+ softmax_last_dim(&xs)?
+ };
+ let xs = xs.matmul(&value)?.to_dtype(in_dtype)?;
+
+ self.reshape_batch_dim_to_heads(&xs)
+ }
+
+ fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let query = self.to_q.forward(xs)?;
+ let context = context.unwrap_or(xs).contiguous()?;
+ let kv_chunks = self
+ .to_kv
+ .forward(&context)?
+ .chunk(2, context.shape().dims().len() - 1)?;
+ let (key, value) = (kv_chunks[0].clone(), kv_chunks[1].clone());
+ let query = self.reshape_heads_to_batch_dim(&query)?;
+ let key = self.reshape_heads_to_batch_dim(&key)?;
+ let value = self.reshape_heads_to_batch_dim(&value)?;
+
+ let xs = self.attention(&query, &key, &value)?;
+ self.to_out.forward(&xs)
+ }
+}
+
+#[derive(Debug)]
+pub struct Model {
+ embedding_model: EmbeddingModel,
+ cross_attn: CrossAttention,
+ cross_attn_norm: LayerNorm,
+ cross_attn_context_norm: LayerNorm,
+ ff: FeedForward,
+ ff_norm: LayerNorm,
+ latents: Tensor,
+ pub device: Device,
+ pub dtype: DType,
+}
+
+impl Model {
+ pub fn new(vb: VarBuilder) -> Result<Self> {
+ // Embedding model
+ let cfg = Config::config_7b_v0_1(false);
+ let embedding_model = EmbeddingModel::new(&cfg, vb.pp("embedding_model"))?;
+
+ // Latent attention
+ let dim = 4096;
+ let vb = vb.pp("latent_attention_model");
+ let latents = vb.get((512, dim), "latents")?;
+
+ // Cross attend blocks
+ let vb = vb.pp("cross_attend_blocks");
+ let cross_attn_norm = layer_norm(dim, LayerNormConfig::default(), vb.pp("0.norm"))?;
+ let cross_attn_context_norm = layer_norm(
+ dim,
+ candle_nn::LayerNormConfig::default(),
+ vb.pp("0.norm_context"),
+ )?;
+ let cross_attn = CrossAttention::new(vb.pp("0.fn"), dim, None, 8, 4096)?;
+
+ let ff_norm = layer_norm(dim, LayerNormConfig::default(), vb.pp("1.norm"))?;
+ let ff = FeedForward::new(vb.pp("1.fn"), dim, None, 4)?;
+
+ Ok(Self {
+ embedding_model,
+ cross_attn,
+ cross_attn_norm,
+ cross_attn_context_norm,
+ ff,
+ ff_norm,
+ latents,
+ device: vb.device().clone(),
+ dtype: vb.dtype(),
+ })
+ }
+
+ pub fn forward(
+ &mut self,
+ input_ids: &Tensor,
+ attn_mask: &Tensor,
+ pool_mask: &Tensor,
+ ) -> Result<Tensor> {
+ // Embedding model
+ let hiddens = self
+ .embedding_model
+ .forward(attn_mask, input_ids, self.dtype)?;
+
+ // Latent attention
+ let b = hiddens.dims()[0];
+ let x = self.latents.unsqueeze(0)?.repeat((b, 1, 1))?;
+ let original_hiddens = &hiddens;
+
+ let hiddens = self.cross_attn_norm.forward(original_hiddens)?;
+ let x = self.cross_attn_context_norm.forward(&x)?;
+ let cross_hiddens = (self.cross_attn.forward(&hiddens, Some(&x))? + original_hiddens)?;
+
+ let hiddens = self.ff_norm.forward(&cross_hiddens)?;
+ let hiddens = (self.ff.forward(&hiddens)? + cross_hiddens)?;
+
+ // Mean pooling
+ let hiddens_masked = hiddens.broadcast_mul(&pool_mask.unsqueeze(D::Minus1)?)?;
+ let s = hiddens_masked.sum(1)?;
+ let d = pool_mask.sum_keepdim(1)?;
+ s.broadcast_div(&d)
+ }
+}