diff options
author | iskng <147113485+iskng@users.noreply.github.com> | 2024-11-29 00:01:08 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-11-29 09:01:08 +0100 |
commit | 4f59ed38b08b84ed9c52e53f2692a2fc1888f30b (patch) | |
tree | 634823b41d3a96309691177788a6fac29d19c7a3 /candle-examples/examples | |
parent | 54e7fc3c97a6d40e459cee4d4bf2eff5c82390da (diff) | |
download | candle-4f59ed38b08b84ed9c52e53f2692a2fc1888f30b.tar.gz candle-4f59ed38b08b84ed9c52e53f2692a2fc1888f30b.tar.bz2 candle-4f59ed38b08b84ed9c52e53f2692a2fc1888f30b.zip |
Adds support for stella_en_v5 embedding model -400M variant (#2608)
* Adds support for stella_en_v5 embedding model -400M variant
* Unified stella
* WIP: Unified Stella
* Combined stella for both 1.5B and 400M variants
* Cargo fmt for the CI
* removed redundant stella-400m model and example after merge into stella-en-v5
* cargo fmt --all
---------
Co-authored-by: Anubhab Bandyopadhyay <4890833+AnubhabB@users.noreply.github.com>
Co-authored-by: laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-examples/examples')
-rw-r--r-- | candle-examples/examples/stella-en-v5/README.md | 24 | ||||
-rw-r--r-- | candle-examples/examples/stella-en-v5/main.rs | 74 |
2 files changed, 73 insertions, 25 deletions
diff --git a/candle-examples/examples/stella-en-v5/README.md b/candle-examples/examples/stella-en-v5/README.md index 5fcc67c3..3a87b295 100644 --- a/candle-examples/examples/stella-en-v5/README.md +++ b/candle-examples/examples/stella-en-v5/README.md @@ -21,7 +21,7 @@ Stella_en_1.5B_v5 is trained by [MRL](https://arxiv.org/abs/2205.13147) enabling The following reproduces the example in the [model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5) for a retrieval task (s2p). The sample queries and docs are hardcoded in the example. ```bash -$ cargo run --example stella-en-v5 --release --features <metal | cuda> +$ cargo run --example stella-en-v5 --release --features <metal | cuda> -- --which 1.5b > > Score: 0.8178786 @@ -37,9 +37,29 @@ $ cargo run --example stella-en-v5 --release --features <metal | cuda> > caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types > > of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties. > + +$ cargo run --example stella-en-v5 --release --features <metal | cuda> -- --which 400m + +> +> Score: 0.8397539 +> Query: What are some ways to reduce stress? +> Answer: There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending +> time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent +> stress from building up. +> +> +> +> Score: 0.809545 +> Query: What are the benefits of drinking green tea? +> Answer: Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage +> caused by free radicals.
Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types +> of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties. +> ``` ## Supported options: -- `Stella_en_15B_v5` supports 256, 768, 1024, 2048, 4096, 6144 and 8192 embedding dimensions (though the model card mentions 512, I couldn't find weights for the same). In the example run this is supported with `--embed-dim` option. E.g. `... --embed-dim 4096`. Defaults to `1024`. +- `Stella_en_v5` has 2 model variants published - a 1.5B variant and 400M variant. This is enabled through the flag `--which`. E.g. `--which 400m` or `--which 1.5b`. + +- `Stella_en_v5` supports 256, 768, 1024, 2048, 4096, 6144 and 8192 embedding dimensions (though the model card mentions 512, I couldn't find weights for the same). In the example run this is supported with `--embed-dim` option. E.g. `... --embed-dim 4096`. Defaults to `1024`. - As per the [model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5), the model has been primarily trained on `s2s` (similarity) and `s2p` (retrieval) tasks. These require a slightly different `query` preprocessing (a different prompt template for each). In this example this is enabled though `--task` option.
\ No newline at end of file diff --git a/candle-examples/examples/stella-en-v5/main.rs b/candle-examples/examples/stella-en-v5/main.rs index 2408262b..68ed7e70 100644 --- a/candle-examples/examples/stella-en-v5/main.rs +++ b/candle-examples/examples/stella-en-v5/main.rs @@ -212,6 +212,14 @@ impl EncodeTask { } } +#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)] +enum Which { + #[value(name = "1.5b")] + Large, + #[value(name = "400m")] + Small, +} + #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Args { @@ -219,6 +227,9 @@ struct Args { #[arg(long)] cpu: bool, + #[arg(long)] + which: Which, + /// Enable tracing (generates a trace-timestamp.json file). #[arg(long)] tracing: bool, @@ -250,24 +261,33 @@ struct Args { // Tokenizer creation is super critical in our case. // We are going to be `padding: Left` for each batch -fn create_tokenizer(tokenizer_file: &Path) -> Result<Tokenizer> { +fn create_tokenizer(tokenizer_file: &Path, which: Which) -> Result<Tokenizer> { let mut tokenizer = Tokenizer::from_file(tokenizer_file).map_err(E::msg)?; - let pad_id = if let Some(pad_id) = tokenizer.token_to_id("<|endoftext|>") { - pad_id - } else { - return Err(anyhow!( - "Tokenizer doesn't contain expected `<|endoftext|>` token" - )); - }; - // This part is super important, we are padding the tokens to the *`left`* and not the usual *`right`* padding - tokenizer.with_padding(Some(PaddingParams { - strategy: PaddingStrategy::BatchLongest, - direction: PaddingDirection::Left, - pad_id, - pad_token: "<|endoftext|>".to_string(), - ..Default::default() - })); + if which == Which::Large { + let pad_id = if let Some(pad_id) = tokenizer.token_to_id("<|endoftext|>") { + pad_id + } else { + return Err(anyhow!( + "Tokenizer doesn't contain expected `<|endoftext|>` token" + )); + }; + + // This part is super important, we are padding the tokens to the *`left`* and not the usual *`right`* padding + tokenizer.with_padding(Some(PaddingParams { +
strategy: PaddingStrategy::BatchLongest, + direction: PaddingDirection::Left, + pad_id, + pad_token: "<|endoftext|>".to_string(), + ..Default::default() + })); + } else { + tokenizer.with_padding(Some(PaddingParams { + strategy: PaddingStrategy::BatchLongest, + direction: PaddingDirection::Right, + ..Default::default() + })); + } Ok(tokenizer) } @@ -298,7 +318,19 @@ fn main() -> Result<()> { Some(d) => d, None => EmbedDim::Dim1024, }; - let repo = api.repo(Repo::model("dunzhang/stella_en_1.5B_v5".to_string())); + + let (repo, cfg) = match args.which { + Which::Large => ( + "dunzhang/stella_en_1.5B_v5", + Config::new_1_5_b_v5(embed_dim.embed_dim()), + ), + Which::Small => ( + "dunzhang/stella_en_400M_v5", + Config::new_400_m_v5(embed_dim.embed_dim()), + ), + }; + + let repo = api.repo(Repo::model(repo.to_string())); let tokenizer_filename = match args.tokenizer_file { Some(file) => std::path::PathBuf::from(file), None => repo.get("tokenizer.json")?, @@ -330,7 +362,7 @@ fn main() -> Result<()> { println!("retrieved the files in {:?}", start.elapsed()); // Initializing the tokenizer which would require us to add padding to the `left` for batch encoding - let tokenizer = create_tokenizer(tokenizer_filename.as_path())?; + let tokenizer = create_tokenizer(tokenizer_filename.as_path(), args.which)?; let start = std::time::Instant::now(); @@ -343,11 +375,7 @@ fn main() -> Result<()> { let embed_vb = unsafe { VarBuilder::from_mmaped_safetensors(&embed_weight_files, DType::F32, &device)? }; - let model = EmbeddingModel::new( - &Config::new_1_5_b_v5(embed_dim.embed_dim()), - base_vb, - embed_vb, - )?; + let model = EmbeddingModel::new(&cfg, base_vb, embed_vb)?; println!("loaded the model in {:?}", start.elapsed()); |