summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authoriskng <147113485+iskng@users.noreply.github.com>2024-11-29 00:01:08 -0800
committerGitHub <noreply@github.com>2024-11-29 09:01:08 +0100
commit4f59ed38b08b84ed9c52e53f2692a2fc1888f30b (patch)
tree634823b41d3a96309691177788a6fac29d19c7a3
parent54e7fc3c97a6d40e459cee4d4bf2eff5c82390da (diff)
downloadcandle-4f59ed38b08b84ed9c52e53f2692a2fc1888f30b.tar.gz
candle-4f59ed38b08b84ed9c52e53f2692a2fc1888f30b.tar.bz2
candle-4f59ed38b08b84ed9c52e53f2692a2fc1888f30b.zip
Adds support for stella_en_v5 embedding model -400M variant (#2608)
* Adds support for stella_en_v5 embedding model -400M variant * Unified stella * WIP: Unified Stella * Combined stella for both 1.5B and 400M variants * Cargo fmt for the CI * removed redundant stella-400m model and example after merge into stella-en-v5 * cargo fmt --all --------- Co-authored-by: Anubhab Bandyopadhyay <4890833+AnubhabB@users.noreply.github.com> Co-authored-by: laurent <laurent.mazare@gmail.com>
-rw-r--r--candle-examples/examples/stella-en-v5/README.md24
-rw-r--r--candle-examples/examples/stella-en-v5/main.rs74
-rw-r--r--candle-transformers/src/models/stella_en_v5.rs569
3 files changed, 555 insertions, 112 deletions
diff --git a/candle-examples/examples/stella-en-v5/README.md b/candle-examples/examples/stella-en-v5/README.md
index 5fcc67c3..3a87b295 100644
--- a/candle-examples/examples/stella-en-v5/README.md
+++ b/candle-examples/examples/stella-en-v5/README.md
@@ -21,7 +21,7 @@ Stella_en_1.5B_v5 is trained by [MRL](https://arxiv.org/abs/2205.13147) enabling
The following reproduces the example in the [model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5) for a retrieval task (s2p). The sample queries and docs are hardcoded in the example.
```bash
-$ cargo run --example stella-en-v5 --release --features <metal | cuda>
+$ cargo run --example stella-en-v5 --release --features <metal | cuda> -- --which 1.5b
>
> Score: 0.8178786
@@ -37,9 +37,29 @@ $ cargo run --example stella-en-v5 --release --features <metal | cuda>
> caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types >
> of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.
>
+
+$ cargo run --example stella-en-v5 --release --features <metal | cuda> -- --which 400m
+
+>
+> Score: 0.8397539
+> Query: What are some ways to reduce stress?
+> Answer: There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending
+> time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent
+> stress from building up.
+>
+>
+>
+> Score: 0.809545
+> Query: What are the benefits of drinking green tea?
+> Answer: Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage
+> caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types
+> of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.
+>
```
## Supported options:
-- `Stella_en_15B_v5` supports 256, 768, 1024, 2048, 4096, 6144 and 8192 embedding dimensions (though the model card mentions 512, I couldn't find weights for the same). In the example run this is supported with `--embed-dim` option. E.g. `... --embed-dim 4096`. Defaults to `1024`.
+- `Stella_en_v5` has 2 model variants published - a 1.5B variant and 400M variant. This is enabled through the flag `--which`. E.g. `--which 400m` or `--which 1.5b`.
+
+- `Stella_en_v5` supports 256, 768, 1024, 2048, 4096, 6144 and 8192 embedding dimensions (though the model card mentions 512, I couldn't find weights for the same). In the example run this is supported with `--embed-dim` option. E.g. `... --embed-dim 4096`. Defaults to `1024`.
- As per the [model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5), the model has been primarily trained on `s2s` (similarity) and `s2p` (retrieval) tasks. These require a slightly different `query` preprocessing (a different prompt template for each). In this example this is enabled though `--task` option. \ No newline at end of file
diff --git a/candle-examples/examples/stella-en-v5/main.rs b/candle-examples/examples/stella-en-v5/main.rs
index 2408262b..68ed7e70 100644
--- a/candle-examples/examples/stella-en-v5/main.rs
+++ b/candle-examples/examples/stella-en-v5/main.rs
@@ -212,6 +212,14 @@ impl EncodeTask {
}
}
+#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
+enum Which {
+ #[value(name = "1.5b")]
+ Large,
+ #[value(name = "400m")]
+ Small,
+}
+
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
@@ -219,6 +227,9 @@ struct Args {
#[arg(long)]
cpu: bool,
+ #[arg(long)]
+ which: Which,
+
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
@@ -250,24 +261,33 @@ struct Args {
// Tokenizer creation is super critical in our case.
// We are going to be `padding: Left` for each batch
-fn create_tokenizer(tokenizer_file: &Path) -> Result<Tokenizer> {
+fn create_tokenizer(tokenizer_file: &Path, which: Which) -> Result<Tokenizer> {
let mut tokenizer = Tokenizer::from_file(tokenizer_file).map_err(E::msg)?;
- let pad_id = if let Some(pad_id) = tokenizer.token_to_id("<|endoftext|>") {
- pad_id
- } else {
- return Err(anyhow!(
- "Tokenizer doesn't contain expected `<|endoftext|>` token"
- ));
- };
- // This part is super important, we are padding the tokens to the *`left`* and not the usual *`right`* padding
- tokenizer.with_padding(Some(PaddingParams {
- strategy: PaddingStrategy::BatchLongest,
- direction: PaddingDirection::Left,
- pad_id,
- pad_token: "<|endoftext|>".to_string(),
- ..Default::default()
- }));
+ if which == Which::Large {
+ let pad_id = if let Some(pad_id) = tokenizer.token_to_id("<|endoftext|>") {
+ pad_id
+ } else {
+ return Err(anyhow!(
+ "Tokenizer doesn't contain expected `<|endoftext|>` token"
+ ));
+ };
+
+ // This part is super important, we are padding the tokens to the *`left`* and not the usual *`right`* padding
+ tokenizer.with_padding(Some(PaddingParams {
+ strategy: PaddingStrategy::BatchLongest,
+ direction: PaddingDirection::Left,
+ pad_id,
+ pad_token: "<|endoftext|>".to_string(),
+ ..Default::default()
+ }));
+ } else {
+ tokenizer.with_padding(Some(PaddingParams {
+ strategy: PaddingStrategy::BatchLongest,
+ direction: PaddingDirection::Right,
+ ..Default::default()
+ }));
+ }
Ok(tokenizer)
}
@@ -298,7 +318,19 @@ fn main() -> Result<()> {
Some(d) => d,
None => EmbedDim::Dim1024,
};
- let repo = api.repo(Repo::model("dunzhang/stella_en_1.5B_v5".to_string()));
+
+ let (repo, cfg) = match args.which {
+ Which::Large => (
+ "dunzhang/stella_en_1.5B_v5",
+ Config::new_1_5_b_v5(embed_dim.embed_dim()),
+ ),
+ Which::Small => (
+ "dunzhang/stella_en_400M_v5",
+ Config::new_400_m_v5(embed_dim.embed_dim()),
+ ),
+ };
+
+ let repo = api.repo(Repo::model(repo.to_string()));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
@@ -330,7 +362,7 @@ fn main() -> Result<()> {
println!("retrieved the files in {:?}", start.elapsed());
// Initializing the tokenizer which would require us to add padding to the `left` for batch encoding
- let tokenizer = create_tokenizer(tokenizer_filename.as_path())?;
+ let tokenizer = create_tokenizer(tokenizer_filename.as_path(), args.which)?;
let start = std::time::Instant::now();
@@ -343,11 +375,7 @@ fn main() -> Result<()> {
let embed_vb =
unsafe { VarBuilder::from_mmaped_safetensors(&embed_weight_files, DType::F32, &device)? };
- let model = EmbeddingModel::new(
- &Config::new_1_5_b_v5(embed_dim.embed_dim()),
- base_vb,
- embed_vb,
- )?;
+ let model = EmbeddingModel::new(&cfg, base_vb, embed_vb)?;
println!("loaded the model in {:?}", start.elapsed());
diff --git a/candle-transformers/src/models/stella_en_v5.rs b/candle-transformers/src/models/stella_en_v5.rs
index 7c1d2b5a..761e44a9 100644
--- a/candle-transformers/src/models/stella_en_v5.rs
+++ b/candle-transformers/src/models/stella_en_v5.rs
@@ -16,33 +16,49 @@
//!
use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm};
-use candle::{DType, Device, IndexOp, Module, Result, Tensor};
-use candle_nn::{Activation, VarBuilder};
+use candle::{DType, Device, Error, IndexOp, Module, Result, Tensor, D};
+use candle_nn::{layer_norm, Activation, LayerNorm, VarBuilder};
use std::sync::Arc;
+// internal representation for identifying which model is being used
+#[derive(Debug, Copy, Clone, PartialEq, serde::Deserialize)]
+pub enum ModelVariant {
+ Large, // 1.5B
+ Small, // 400M
+}
+
+impl Default for ModelVariant {
+ fn default() -> Self {
+ Self::Large
+ }
+}
+
// Same as `qwen2` family of models with the exception being the `embed_head`
// The final `output` causal modelling head is swapped with a learned `dense` layer, `embed_head`
-#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
+#[derive(Debug, Default, Clone, PartialEq, serde::Deserialize)]
pub struct Config {
+ pub variant: ModelVariant,
pub vocab_size: usize,
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
- pub num_key_value_heads: usize,
pub max_position_embeddings: usize,
- pub max_window_layers: usize,
- pub tie_word_embeddings: bool,
pub rope_theta: f64,
- pub rms_norm_eps: f64,
- pub hidden_act: Activation,
pub embed_head: EmbedHead,
+ pub norm_eps: f64, // RMSNorm for 1.5B || LayerNorm for 400M
+ pub activation_fn: Activation, // Silu for 1.5B || Gelu for 400M
+ // Unique to 1.5B
+ pub num_key_value_heads: usize,
+ // Unique to 400M
+ pub type_vocab_size: usize,
+ pub scaling_factor: f64,
}
// Excerpt from `stella` model card:
// `Stella_en_1.5B_v5` models have been trained on [MRL](https://arxiv.org/abs/2205.13147) enabling multiple output dimensions
// Embed head represents the config for various embedding dims supported
-#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
+#[derive(Debug, Default, Clone, PartialEq, serde::Deserialize)]
pub struct EmbedHead {
pub in_features: usize,
pub out_features: usize,
@@ -68,9 +84,9 @@ impl Default for EmbedDim {
}
impl EmbedDim {
- pub fn config(&self) -> EmbedHead {
+ pub fn config(&self, in_features: usize) -> EmbedHead {
EmbedHead {
- in_features: 1536,
+ in_features,
out_features: match &self {
Self::Dim256 => 256,
Self::Dim768 => 768,
@@ -91,7 +107,8 @@ impl Config {
// Representing config.json at https://huggingface.co/dunzhang/stella_en_1.5B_v5/blob/main/config.json
// Removed `sliding_window` related config which is basically being carried forward from `qwen2` but not used here
Self {
- hidden_act: candle_nn::Activation::Silu,
+ variant: ModelVariant::Large,
+ activation_fn: candle_nn::Activation::Silu,
vocab_size: 151646,
hidden_size: 1536,
intermediate_size: 8960,
@@ -99,11 +116,30 @@ impl Config {
num_attention_heads: 12,
num_key_value_heads: 2,
max_position_embeddings: 131072,
- max_window_layers: 21,
- tie_word_embeddings: false,
rope_theta: 1000000.,
- rms_norm_eps: 1e-06,
- embed_head: embed_dim.config(),
+ norm_eps: 1e-06,
+ embed_head: embed_dim.config(1536),
+ ..Default::default()
+ }
+ }
+
+ /// Initialize new `stella_en_400M_v5`
+ pub fn new_400_m_v5(embed_dim: EmbedDim) -> Self {
+ Self {
+ variant: ModelVariant::Small,
+ vocab_size: 30528,
+ hidden_size: 1024,
+ intermediate_size: 4096,
+ num_hidden_layers: 24,
+ num_attention_heads: 16,
+ max_position_embeddings: 8192,
+ type_vocab_size: 2,
+ norm_eps: 1e-12,
+ scaling_factor: 2.0,
+ rope_theta: 160000.0,
+ activation_fn: Activation::Gelu,
+ embed_head: embed_dim.config(1024),
+ ..Default::default()
}
}
}
@@ -117,27 +153,57 @@ struct RotaryEmbedding {
impl RotaryEmbedding {
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let dim = cfg.hidden_size / cfg.num_attention_heads;
- let max_seq_len = cfg.max_position_embeddings;
+ // Factoring in `scaling factor` for `400M` variant
+ let max_seq_len = if cfg.scaling_factor == 0. {
+ cfg.max_position_embeddings
+ } else {
+ ((cfg.max_position_embeddings as f64) * cfg.scaling_factor) as usize
+ };
+
+ // let rot_dim = if cfg.variant == ModelVariant::Small { dim / 2 } else { dim };
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
- .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
+ .map(|i| {
+ // Scaled rope_theta for 400M variant
+ let rope_theta = if cfg.scaling_factor == 0. {
+ cfg.rope_theta
+ } else {
+ cfg.rope_theta * cfg.scaling_factor
+ };
+ let mut freq = 1. / rope_theta.powf(i as f64 / dim as f64);
+
+ if cfg.scaling_factor != 0. {
+ freq /= cfg.scaling_factor.powf(2.0 / (dim as f64))
+ }
+
+ freq as f32
+ })
.collect();
+
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
+
+ // Calculate position embeddings with scaled sequence length
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
+ // if cfg.variant == ModelVariant::Small {
+ // freqs = Tensor::cat(&[&freqs, &freqs], 1)?
+ // }
+
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
})
}
+ // TODO: re-visit this
fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, 0, seq_len)?;
let sin = self.sin.narrow(0, 0, seq_len)?;
+
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
@@ -147,8 +213,9 @@ impl RotaryEmbedding {
#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
+ variant: ModelVariant,
gate_proj: Linear,
- up_proj: Linear,
+ up_proj: Option<Linear>, // `up_proj` only for 1.5B variant
down_proj: Linear,
act_fn: Activation,
}
@@ -157,31 +224,65 @@ impl MLP {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
let intermediate_sz = cfg.intermediate_size;
- let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
- let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
- let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
+
+ let (gate_proj, up_proj, down_proj) = match cfg.variant {
+ ModelVariant::Large => (
+ linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?,
+ Some(linear_no_bias(
+ hidden_sz,
+ intermediate_sz,
+ vb.pp("up_proj"),
+ )?),
+ linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?,
+ ),
+ ModelVariant::Small => (
+ linear_no_bias(hidden_sz, intermediate_sz * 2, vb.pp("up_gate_proj"))?,
+ None,
+ linear(intermediate_sz, hidden_sz, vb.pp("down_proj"))?,
+ ),
+ };
+
Ok(Self {
+ variant: cfg.variant,
gate_proj,
up_proj,
down_proj,
- act_fn: cfg.hidden_act,
+ act_fn: cfg.activation_fn,
})
}
}
impl Module for MLP {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
- let rhs = xs.apply(&self.up_proj)?;
+ let up = self.gate_proj.forward(xs)?;
+
+ let (lhs, rhs) = match self.variant {
+ ModelVariant::Large => {
+ let lhs = up.apply(&self.act_fn)?;
+ let rhs = xs.apply(self.up_proj.as_ref().unwrap())?;
+
+ (lhs, rhs)
+ }
+ ModelVariant::Small => {
+ // Get the dimensions
+ let (_batch_size, _seq_len, hidden_dim) = up.dims3()?;
+ let split_size = hidden_dim / 2;
+
+ // Split along the last dimension (hidden_dim)
+ let up_states = up.narrow(2, 0, split_size)?;
+ let gate = up.narrow(2, split_size, split_size)?.apply(&self.act_fn)?;
+
+ (up_states, gate)
+ }
+ };
+
(lhs * rhs)?.apply(&self.down_proj)
}
}
#[derive(Debug, Clone)]
struct Attention {
- q_proj: Linear,
- k_proj: Linear,
- v_proj: Linear,
+ qkv_proj: Linear,
o_proj: Linear,
num_heads: usize,
num_kv_heads: usize,
@@ -189,6 +290,7 @@ struct Attention {
head_dim: usize,
hidden_size: usize,
rotary_emb: Arc<RotaryEmbedding>,
+ variant: ModelVariant,
}
impl Attention {
@@ -196,16 +298,47 @@ impl Attention {
let hidden_sz = cfg.hidden_size;
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads;
- let num_kv_groups = num_heads / num_kv_heads;
+ let num_kv_groups = if num_kv_heads > 0 {
+ num_heads / num_kv_heads
+ } else {
+ 0
+ };
let head_dim = hidden_sz / num_heads;
- let q_proj = linear(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
- let k_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
- let v_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
- let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
+
+ let (qkv_proj, o_proj) = match cfg.variant {
+ ModelVariant::Large => {
+ // The 1.5B variant comes with separate `q, k, v` layers, let's merge it and standardize
+ // Weights
+ let q_w = vb
+ .pp("q_proj")
+ .get((num_heads * head_dim, hidden_sz), "weight")?;
+ let k_w = vb
+ .pp("k_proj")
+ .get((num_kv_heads * head_dim, hidden_sz), "weight")?;
+ let v_w = vb
+ .pp("v_proj")
+ .get((num_kv_heads * head_dim, hidden_sz), "weight")?;
+ // Biases
+ let q_b = vb.pp("q_proj").get(num_heads * head_dim, "bias")?;
+ let k_b = vb.pp("k_proj").get(num_kv_heads * head_dim, "bias")?;
+ let v_b = vb.pp("v_proj").get(num_kv_heads * head_dim, "bias")?;
+
+ let qkv_w = Tensor::cat(&[&q_w, &k_w, &v_w], 0)?;
+ let qkv_b = Tensor::cat(&[&q_b, &k_b, &v_b], 0)?;
+
+ (
+ Linear::from_weights(qkv_w, Some(qkv_b)),
+ linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?,
+ )
+ }
+ ModelVariant::Small => (
+ linear(hidden_sz, 3 * num_heads * head_dim, vb.pp("qkv_proj"))?,
+ linear(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?,
+ ),
+ };
+
Ok(Self {
- q_proj,
- k_proj,
- v_proj,
+ qkv_proj,
o_proj,
num_heads,
num_kv_heads,
@@ -213,45 +346,90 @@ impl Attention {
head_dim,
hidden_size: hidden_sz,
rotary_emb,
+ variant: cfg.variant,
})
}
fn forward(&mut self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
let (b_sz, q_len, _) = xs.dims3()?;
- let query_states = self.q_proj.forward(xs)?;
- let key_states = self.k_proj.forward(xs)?;
- let value_states = self.v_proj.forward(xs)?;
+ let qkv = self.qkv_proj.forward(xs)?;
- let query_states = query_states
- .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
- .transpose(1, 2)?;
- let key_states = key_states
- .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
- .transpose(1, 2)?;
- let value_states = value_states
- .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
- .transpose(1, 2)?;
+ let n_kv_heads = match self.variant {
+ ModelVariant::Large => self.num_kv_heads,
+ ModelVariant::Small => self.num_heads,
+ };
+
+ let (query_states, key_states, value_states) = match self.variant {
+ ModelVariant::Large => {
+ let q_sz = self.num_heads * self.head_dim;
+ let kv_sz = n_kv_heads * self.head_dim;
+
+ let q = qkv.narrow(D::Minus1, 0, q_sz)?.reshape((
+ b_sz,
+ q_len,
+ self.num_heads,
+ self.head_dim,
+ ))?;
+ let k = qkv.narrow(D::Minus1, q_sz, kv_sz)?.reshape((
+ b_sz,
+ q_len,
+ n_kv_heads,
+ self.head_dim,
+ ))?;
+ let v = qkv.narrow(D::Minus1, q_sz + kv_sz, kv_sz)?.reshape((
+ b_sz,
+ q_len,
+ n_kv_heads,
+ self.head_dim,
+ ))?;
+
+ (q, k, v)
+ }
+ ModelVariant::Small => {
+ // Split into Q, K, V and reshape to match PyTorch shapes
+ let qkv = qkv.reshape((b_sz, q_len, 3, self.num_heads, self.head_dim))?;
+
+ (
+ qkv.i((.., .., 0, .., ..))?,
+ qkv.i((.., .., 1, .., ..))?,
+ qkv.i((.., .., 2, .., ..))?,
+ )
+ }
+ };
+
+ let query_states = query_states.transpose(1, 2)?.contiguous()?;
+ let key_states = key_states.transpose(1, 2)?.contiguous()?;
+ let value_states = value_states.transpose(1, 2)?.contiguous()?;
let (query_states, key_states) = self
.rotary_emb
.apply_rotary_emb_qkv(&query_states, &key_states)?;
- let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
- let value_states =
- crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
+ // The 1.5B is expected to have grouped query attention
+ let (key_states, value_states) = if self.variant == ModelVariant::Large {
+ (
+ crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?,
+ crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?,
+ )
+ } else {
+ (key_states, value_states)
+ };
let attn_output = {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
- let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
+ let attn_weights = query_states.matmul(&key_states.transpose(2, 3)?)?;
+ let attn_weights = (attn_weights * scale)?;
let attn_weights = match attention_mask {
None => attn_weights,
Some(mask) => attn_weights.broadcast_add(mask)?,
};
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
+
attn_weights.matmul(&value_states)?
};
+
attn_output
.transpose(1, 2)?
.reshape((b_sz, q_len, self.hidden_size))?
@@ -260,70 +438,282 @@ impl Attention {
}
#[derive(Debug, Clone)]
-struct DecoderLayer {
- self_attn: Attention,
+enum NormType {
+ Layer(LayerNorm),
+ Rms(RmsNorm),
+}
+
+#[derive(Debug, Clone)]
+struct Layer {
+ variant: ModelVariant,
+ attention: Attention,
mlp: MLP,
- input_layernorm: RmsNorm,
- post_attention_layernorm: RmsNorm,
+ // For 1.5B: this is `input_layernorm`
+ // For 400M: this is `output_layernorm`
+ layernorm: NormType,
+ post_attention_layernorm: NormType,
}
-impl DecoderLayer {
+impl Layer {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
- let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
- let mlp = MLP::new(cfg, vb.pp("mlp"))?;
- let input_layernorm =
- RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
- let post_attention_layernorm = RmsNorm::new(
- cfg.hidden_size,
- cfg.rms_norm_eps,
- vb.pp("post_attention_layernorm"),
+ let attention = Attention::new(
+ rotary_emb,
+ cfg,
+ vb.pp(if cfg.variant == ModelVariant::Large {
+ "self_attn"
+ } else {
+ "attention"
+ }),
)?;
+ let mlp = MLP::new(cfg, vb.pp("mlp"))?;
+ let (layernorm, post_attention_layernorm) = match cfg.variant {
+ ModelVariant::Large => (
+ NormType::Rms(RmsNorm::new(
+ cfg.hidden_size,
+ cfg.norm_eps,
+ vb.pp("input_layernorm"),
+ )?),
+ NormType::Rms(RmsNorm::new(
+ cfg.hidden_size,
+ cfg.norm_eps,
+ vb.pp("post_attention_layernorm"),
+ )?),
+ ),
+ ModelVariant::Small => (
+ NormType::Layer(layer_norm(
+ cfg.hidden_size,
+ candle_nn::LayerNormConfig {
+ eps: cfg.norm_eps,
+ ..Default::default()
+ },
+ vb.pp("mlp_ln"),
+ )?),
+ NormType::Layer(layer_norm(
+ cfg.hidden_size,
+ candle_nn::LayerNormConfig {
+ eps: cfg.norm_eps,
+ ..Default::default()
+ },
+ vb.pp("attn_ln"),
+ )?),
+ ),
+ };
+
Ok(Self {
- self_attn,
+ variant: cfg.variant,
+ attention,
mlp,
- input_layernorm,
+ layernorm,
post_attention_layernorm,
})
}
fn forward(&mut self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
+ // Here, the application of normalizations and activation calculations differ
+ // For Large [1.5B]:
+ // residual = x
+ // state = other_layernorm(xs)
+ // state = attention(state)
+ // state += residual
+ // residual = state
+ // state = mlp(attention_layernorm(state))
+ // -> residual + state
+ // For Small [400M]:
+ // residual = x;
+ // state = attention(x)
+ // state += residual
+ // state = attention_layernorm(state)
+ // residual = state
+ // state = mlp(state)
+ // state += residual
+ // -> other_layernorm(state)
let residual = xs;
- let xs = self.input_layernorm.forward(xs)?;
- let xs = self.self_attn.forward(&xs, attention_mask)?;
- let xs = (xs + residual)?;
- let residual = &xs;
- let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
- residual + xs
+
+ match self.variant {
+ ModelVariant::Large => {
+ let (attn_ln, input_ln) = if let (NormType::Rms(attn_ln), NormType::Rms(input_ln)) =
+ (&self.post_attention_layernorm, &self.layernorm)
+ {
+ (attn_ln, input_ln)
+ } else {
+ return Err(candle::error::Error::Msg(
+ "Stella 1.5B expects RMSNorm".to_string(),
+ ));
+ };
+
+ let xs = input_ln.forward(xs)?;
+ let xs = (self.attention.forward(&xs, attention_mask)? + residual)?;
+
+ let residual = &xs;
+ let xs = xs.apply(attn_ln)?.apply(&self.mlp)?;
+
+ residual + xs
+ }
+ ModelVariant::Small => {
+ let (attn_ln, output_ln) =
+ if let (NormType::Layer(attn_ln), NormType::Layer(input_ln)) =
+ (&self.post_attention_layernorm, &self.layernorm)
+ {
+ (attn_ln, input_ln)
+ } else {
+ return Err(candle::error::Error::Msg(
+ "Stella 400M expects RMSNorm".to_string(),
+ ));
+ };
+
+ let xs = (self.attention.forward(xs, attention_mask)? + residual)?;
+ let xs = attn_ln.forward(&xs)?;
+
+ let residual = &xs;
+ let xs = (self.mlp.forward(&xs)? + residual)?;
+
+ output_ln.forward(&xs)
+ }
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct Embeddings {
+ variant: ModelVariant,
+ // For 1.5B: this is the `embed_tokens`
+ // For 400M: this is the `word_embeddings`
+ embeddings: candle_nn::Embedding,
+ // folloing are specifically for 400M
+ token_type_embeddings: Option<candle_nn::Embedding>,
+ layer_norm: Option<LayerNorm>,
+ position_ids: Option<Tensor>,
+}
+
+impl Embeddings {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let (embeddings, token_type_embeddings, layer_norm, position_ids) = match cfg.variant {
+ ModelVariant::Large => (
+ candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?,
+ None,
+ None,
+ None,
+ ),
+ ModelVariant::Small => {
+ let vb = vb.pp("embeddings");
+ let weight = vb.pp("LayerNorm").get_with_hints(
+ cfg.hidden_size,
+ "weight",
+ candle_nn::Init::Const(1.0),
+ )?;
+ let bias = vb.pp("LayerNorm").get_with_hints(
+ cfg.hidden_size,
+ "bias",
+ candle_nn::Init::Const(0.0),
+ )?;
+ let dev = bias.device().clone();
+
+ let layer_norm = candle_nn::LayerNorm::new(weight, bias, cfg.norm_eps);
+
+ (
+ candle_nn::embedding(
+ cfg.vocab_size,
+ cfg.hidden_size,
+ vb.pp("word_embeddings"),
+ )?,
+ Some(candle_nn::embedding(
+ cfg.type_vocab_size,
+ cfg.hidden_size,
+ vb.pp("token_type_embeddings"),
+ )?),
+ Some(layer_norm),
+ Some(Tensor::arange(
+ 0u32,
+ cfg.max_position_embeddings as u32,
+ &dev,
+ )?),
+ )
+ }
+ };
+
+ Ok(Self {
+ variant: cfg.variant,
+ embeddings,
+ token_type_embeddings,
+ layer_norm,
+ position_ids,
+ })
+ }
+}
+
+impl Module for Embeddings {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let embd = self.embeddings.forward(xs)?;
+ // For 1.5B just forward the embeddings
+ if self.variant == ModelVariant::Large {
+ return Ok(embd);
+ }
+
+ let (token_type_embed, layer_norm, pos_ids) =
+ if let (Some(token_type_embd), Some(layer_norm), Some(position_ids)) = (
+ &self.token_type_embeddings,
+ &self.layer_norm,
+ &self.position_ids,
+ ) {
+ (token_type_embd, layer_norm, position_ids)
+ } else {
+ return Err(Error::Msg(
+ "Stella 400M requires `token_type_embeddings`, `layer_norm` and `position_ids`"
+ .to_string(),
+ ));
+ };
+
+ let (batch_size, seq_length) = xs.dims2()?;
+
+ let pos_ids = pos_ids
+ .as_ref()
+ .narrow(0, 0, seq_length)?
+ .expand((batch_size, seq_length))?;
+
+ layer_norm.forward(&embd.add(&token_type_embed.forward(&pos_ids.zeros_like()?)?)?)
}
}
#[derive(Debug, Clone)]
pub struct Model {
- embed_tokens: candle_nn::Embedding,
- layers: Vec<DecoderLayer>,
- norm: RmsNorm,
+ embeddings: Embeddings,
+ layers: Vec<Layer>,
+ norm: Option<RmsNorm>,
device: Device,
dtype: DType,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
- let vb_m = vb.pp("model");
- let embed_tokens =
- candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
+ let vb_m = match cfg.variant {
+ ModelVariant::Large => vb.pp("model"),
+ ModelVariant::Small => vb.pp("new"),
+ };
+ // let embed_tokens =
+ // candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
+ let embeddings = Embeddings::new(cfg, vb_m.clone())?;
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
- let vb_l = vb_m.pp("layers");
+ let vb_l = match cfg.variant {
+ ModelVariant::Large => vb_m.pp("layers"),
+ ModelVariant::Small => vb_m.pp("encoder").pp("layer"),
+ };
for layer_idx in 0..cfg.num_hidden_layers {
- let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
+ let layer = Layer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
layers.push(layer)
}
- let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
+ let norm = match cfg.variant {
+ ModelVariant::Large => Some(RmsNorm::new(
+ cfg.hidden_size,
+ cfg.norm_eps,
+ vb_m.pp("norm"),
+ )?),
+ ModelVariant::Small => None,
+ };
Ok(Self {
- embed_tokens,
+ embeddings,
layers,
norm,
- // sliding_window: 0,
device: vb.device().clone(),
dtype: vb.dtype(),
})
@@ -352,15 +742,20 @@ impl Model {
Some(self.prepare_attention_mask(mask)?)
};
- let mut xs = self.embed_tokens.forward(input_ids)?;
+ let mut xs = self.embeddings.forward(input_ids)?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attention_mask.as_ref())?
}
- xs.apply(&self.norm)
+
+ if let Some(n) = &self.norm {
+ xs.apply(n)
+ } else {
+ Ok(xs)
+ }
}
}
-#[derive(Debug, Clone)]
+#[derive(Debug)]
pub struct EmbeddingModel {
base_model: Model,
lm_head: Linear,