diff options
author | cdoko <190060110+cdoko@users.noreply.github.com> | 2024-12-03 05:56:01 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-12-03 10:56:01 +0100 |
commit | 145aa7193c4e658b184f52706574cc9f115e4674 (patch) | |
tree | 2c5d0a903894b3e854090cdb97f36bdd5ac3a1c0 /candle-transformers/src/models/nvembed_v2/model.rs | |
parent | 6f715f92564c10426c5565cd30ece25aee8d72ac (diff) | |
download | candle-145aa7193c4e658b184f52706574cc9f115e4674.tar.gz candle-145aa7193c4e658b184f52706574cc9f115e4674.tar.bz2 candle-145aa7193c4e658b184f52706574cc9f115e4674.zip |
Add Nvembed v2 model (#2649)
* Update mod.rs
* Create mod.rs
* Create decoder.rs
* Create model.rs
* Create main.rs
* Create README.md
* Update README.md
* Update main.rs
* Update and rename decoder.rs to embedding.rs
* Update mod.rs
* Update model.rs
Diffstat (limited to 'candle-transformers/src/models/nvembed_v2/model.rs')
-rw-r--r-- | candle-transformers/src/models/nvembed_v2/model.rs | 233 |
1 files changed, 233 insertions, 0 deletions
diff --git a/candle-transformers/src/models/nvembed_v2/model.rs b/candle-transformers/src/models/nvembed_v2/model.rs new file mode 100644 index 00000000..73ef776e --- /dev/null +++ b/candle-transformers/src/models/nvembed_v2/model.rs @@ -0,0 +1,233 @@ +use super::embedding::Model as EmbeddingModel; +use crate::models::{ + mistral::Config, + with_tracing::{layer_norm, linear, linear_no_bias, LayerNorm, Linear}, +}; +use candle::{DType, Device, Result, Tensor, D}; +use candle_nn::{ops::softmax_last_dim, LayerNormConfig, Module, VarBuilder}; + +// Geglu and feedforward from candle-transformers/src/models/stable_diffusion/attention.rs +#[derive(Debug)] +struct GeGlu { + proj: Linear, + span: tracing::Span, +} + +impl GeGlu { + fn new(vs: VarBuilder, dim_in: usize, dim_out: usize) -> Result<Self> { + let proj = linear(dim_in, dim_out * 2, vs)?; + let span = tracing::span!(tracing::Level::TRACE, "geglu"); + Ok(Self { proj, span }) + } +} + +impl Module for GeGlu { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?; + &hidden_states_and_gate[0] * hidden_states_and_gate[1].gelu()? + } +} + +#[derive(Debug)] +struct FeedForward { + project_in: GeGlu, + linear: Linear, + span: tracing::Span, +} + +impl FeedForward { + fn new(vs: VarBuilder, dim: usize, dim_out: Option<usize>, mult: usize) -> Result<Self> { + let inner_dim = dim * mult; + let dim_out = dim_out.unwrap_or(dim); + let vs = vs.pp("net"); + let project_in = GeGlu::new(vs.pp("0"), dim, inner_dim)?; + let linear = linear(inner_dim, dim_out, vs.pp("2"))?; + let span = tracing::span!(tracing::Level::TRACE, "ff"); + Ok(Self { + project_in, + linear, + span, + }) + } +} + +impl Module for FeedForward { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let xs = self.project_in.forward(xs)?; + self.linear.forward(&xs) + } +} + +// CrossAttention from candle-transformers/src/models/stable_diffusion/attention.rs +#[derive(Debug)] +struct CrossAttention { + to_q: Linear, + to_kv: Linear, + to_out: Linear, + heads: usize, + scale: f64, + span: tracing::Span, + span_attn: tracing::Span, + span_softmax: tracing::Span, +} + +impl CrossAttention { + fn new( + vs: VarBuilder, + query_dim: usize, + context_dim: Option<usize>, + heads: usize, + dim_head: usize, + ) -> Result<Self> { + let inner_dim = dim_head * heads; + let context_dim = context_dim.unwrap_or(query_dim); + let scale = 1.0 / f64::sqrt(dim_head as f64); + let to_q = linear_no_bias(query_dim, inner_dim, vs.pp("to_q"))?; + let to_kv = linear_no_bias(context_dim, inner_dim * 2, vs.pp("to_kv"))?; + let to_out = linear_no_bias(inner_dim, query_dim, vs.pp("to_out"))?; + let span = tracing::span!(tracing::Level::TRACE, "xa"); + let span_attn = tracing::span!(tracing::Level::TRACE, "xa-attn"); + let span_softmax = tracing::span!(tracing::Level::TRACE, "xa-softmax"); + Ok(Self { + to_q, + to_kv, + to_out, + heads, + scale, + span, + span_attn, + span_softmax, + }) + } + + fn reshape_heads_to_batch_dim(&self, xs: &Tensor) -> Result<Tensor> { + let (batch_size, seq_len, dim) = xs.dims3()?; + xs.reshape((batch_size, seq_len, self.heads, dim / self.heads))? + .transpose(1, 2)? + .reshape((batch_size * self.heads, seq_len, dim / self.heads)) + } + + fn reshape_batch_dim_to_heads(&self, xs: &Tensor) -> Result<Tensor> { + let (batch_size, seq_len, dim) = xs.dims3()?; + xs.reshape((batch_size / self.heads, self.heads, seq_len, dim))? + .transpose(1, 2)? + .reshape((batch_size / self.heads, seq_len, dim * self.heads)) + } + + fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> { + let _enter = self.span_attn.enter(); + + let in_dtype = query.dtype(); + let query = query.to_dtype(DType::F32)?; + let key = key.to_dtype(DType::F32)?; + let value = value.to_dtype(DType::F32)?; + let xs = query.matmul(&(key.t()? * self.scale)?)?; + let xs = { + let _enter = self.span_softmax.enter(); + softmax_last_dim(&xs)? + }; + let xs = xs.matmul(&value)?.to_dtype(in_dtype)?; + + self.reshape_batch_dim_to_heads(&xs) + } + + fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> { + let _enter = self.span.enter(); + let query = self.to_q.forward(xs)?; + let context = context.unwrap_or(xs).contiguous()?; + let kv_chunks = self + .to_kv + .forward(&context)? + .chunk(2, context.shape().dims().len() - 1)?; + let (key, value) = (kv_chunks[0].clone(), kv_chunks[1].clone()); + let query = self.reshape_heads_to_batch_dim(&query)?; + let key = self.reshape_heads_to_batch_dim(&key)?; + let value = self.reshape_heads_to_batch_dim(&value)?; + + let xs = self.attention(&query, &key, &value)?; + self.to_out.forward(&xs) + } +} + +#[derive(Debug)] +pub struct Model { + embedding_model: EmbeddingModel, + cross_attn: CrossAttention, + cross_attn_norm: LayerNorm, + cross_attn_context_norm: LayerNorm, + ff: FeedForward, + ff_norm: LayerNorm, + latents: Tensor, + pub device: Device, + pub dtype: DType, +} + +impl Model { + pub fn new(vb: VarBuilder) -> Result<Self> { + // Embedding model + let cfg = Config::config_7b_v0_1(false); + let embedding_model = EmbeddingModel::new(&cfg, vb.pp("embedding_model"))?; + + // Latent attention + let dim = 4096; + let vb = vb.pp("latent_attention_model"); + let latents = vb.get((512, dim), "latents")?; + + // Cross attend blocks + let vb = vb.pp("cross_attend_blocks"); + let cross_attn_norm = layer_norm(dim, LayerNormConfig::default(), vb.pp("0.norm"))?; + let cross_attn_context_norm = layer_norm( + dim, + candle_nn::LayerNormConfig::default(), + vb.pp("0.norm_context"), + )?; + let cross_attn = CrossAttention::new(vb.pp("0.fn"), dim, None, 8, 4096)?; + + let ff_norm = layer_norm(dim, LayerNormConfig::default(), vb.pp("1.norm"))?; + let ff = FeedForward::new(vb.pp("1.fn"), dim, None, 4)?; + + Ok(Self { + embedding_model, + cross_attn, + cross_attn_norm, + cross_attn_context_norm, + ff, + ff_norm, + latents, + device: vb.device().clone(), + dtype: vb.dtype(), + }) + } + + pub fn forward( + &mut self, + input_ids: &Tensor, + attn_mask: &Tensor, + pool_mask: &Tensor, + ) -> Result<Tensor> { + // Embedding model + let hiddens = self + .embedding_model + .forward(attn_mask, input_ids, self.dtype)?; + + // Latent attention + let b = hiddens.dims()[0]; + let x = self.latents.unsqueeze(0)?.repeat((b, 1, 1))?; + let original_hiddens = &hiddens; + + let hiddens = self.cross_attn_norm.forward(original_hiddens)?; + let x = self.cross_attn_context_norm.forward(&x)?; + let cross_hiddens = (self.cross_attn.forward(&hiddens, Some(&x))? + original_hiddens)?; + + let hiddens = self.ff_norm.forward(&cross_hiddens)?; + let hiddens = (self.ff.forward(&hiddens)? + cross_hiddens)?; + + // Mean pooling + let hiddens_masked = hiddens.broadcast_mul(&pool_mask.unsqueeze(D::Minus1)?)?; + let s = hiddens_masked.sum(1)?; + let d = pool_mask.sum_keepdim(1)?; + s.broadcast_div(&d) + } +} |