Diffstat (limited to 'candle-examples/examples/stella-en-v5/main.rs')
-rw-r--r--  candle-examples/examples/stella-en-v5/main.rs | 74
1 file changed, 51 insertions(+), 23 deletions(-)
diff --git a/candle-examples/examples/stella-en-v5/main.rs b/candle-examples/examples/stella-en-v5/main.rs
index 2408262b..68ed7e70 100644
--- a/candle-examples/examples/stella-en-v5/main.rs
+++ b/candle-examples/examples/stella-en-v5/main.rs
@@ -212,6 +212,14 @@ impl EncodeTask {
}
}
+#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
+enum Which {
+ #[value(name = "1.5b")]
+ Large,
+ #[value(name = "400m")]
+ Small,
+}
+
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
@@ -219,6 +227,9 @@ struct Args {
#[arg(long)]
cpu: bool,
+ #[arg(long)]
+ which: Which,
+
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
@@ -250,24 +261,33 @@ struct Args {
// Tokenizer creation is critical in our case.
// We pad each batch to the `left` for the 1.5B variant and to the `right` for the 400M variant.
-fn create_tokenizer(tokenizer_file: &Path) -> Result<Tokenizer> {
+fn create_tokenizer(tokenizer_file: &Path, which: Which) -> Result<Tokenizer> {
let mut tokenizer = Tokenizer::from_file(tokenizer_file).map_err(E::msg)?;
- let pad_id = if let Some(pad_id) = tokenizer.token_to_id("<|endoftext|>") {
- pad_id
- } else {
- return Err(anyhow!(
- "Tokenizer doesn't contain expected `<|endoftext|>` token"
- ));
- };
- // This part is super important, we are padding the tokens to the *`left`* and not the usual *`right`* padding
- tokenizer.with_padding(Some(PaddingParams {
- strategy: PaddingStrategy::BatchLongest,
- direction: PaddingDirection::Left,
- pad_id,
- pad_token: "<|endoftext|>".to_string(),
- ..Default::default()
- }));
+ if which == Which::Large {
+ let pad_id = if let Some(pad_id) = tokenizer.token_to_id("<|endoftext|>") {
+ pad_id
+ } else {
+ return Err(anyhow!(
+ "Tokenizer doesn't contain expected `<|endoftext|>` token"
+ ));
+ };
+
+ // This part is super important: we pad the tokens to the *`left`*, not the usual *`right`*
+ tokenizer.with_padding(Some(PaddingParams {
+ strategy: PaddingStrategy::BatchLongest,
+ direction: PaddingDirection::Left,
+ pad_id,
+ pad_token: "<|endoftext|>".to_string(),
+ ..Default::default()
+ }));
+ } else {
+ tokenizer.with_padding(Some(PaddingParams {
+ strategy: PaddingStrategy::BatchLongest,
+ direction: PaddingDirection::Right,
+ ..Default::default()
+ }));
+ }
Ok(tokenizer)
}
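(Aside: to make the left-vs-right padding distinction concrete, here is a tiny standalone Rust sketch, independent of this patch and of the `tokenizers` crate; the token ids are made up. Left padding keeps each sequence's final real token in the last slot, which is typically what you want when a decoder-style model such as the 1.5B variant pools the last hidden state; right padding is the usual layout for encoder-style models such as the 400M variant.)

fn main() {
    let pad = 0u32; // made-up pad id
    let batch: Vec<Vec<u32>> = vec![vec![11, 12, 13], vec![21]];
    let max = batch.iter().map(|t| t.len()).max().unwrap();

    // Left padding: pads go in front, so every row's last position
    // holds a real token.
    let left: Vec<Vec<u32>> = batch
        .iter()
        .map(|t| {
            let mut row = vec![pad; max - t.len()];
            row.extend_from_slice(t);
            row
        })
        .collect();

    // Right padding: pads go at the end, after the real tokens.
    let right: Vec<Vec<u32>> = batch
        .iter()
        .map(|t| {
            let mut row = t.clone();
            row.resize(max, pad);
            row
        })
        .collect();

    assert_eq!(left[1], vec![0, 0, 21]);
    assert_eq!(right[1], vec![21, 0, 0]);
    println!("left: {left:?}\nright: {right:?}");
}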
@@ -298,7 +318,19 @@ fn main() -> Result<()> {
Some(d) => d,
None => EmbedDim::Dim1024,
};
- let repo = api.repo(Repo::model("dunzhang/stella_en_1.5B_v5".to_string()));
+
+ let (repo, cfg) = match args.which {
+ Which::Large => (
+ "dunzhang/stella_en_1.5B_v5",
+ Config::new_1_5_b_v5(embed_dim.embed_dim()),
+ ),
+ Which::Small => (
+ "dunzhang/stella_en_400M_v5",
+ Config::new_400_m_v5(embed_dim.embed_dim()),
+ ),
+ };
+
+ let repo = api.repo(Repo::model(repo.to_string()));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
@@ -330,7 +362,7 @@ fn main() -> Result<()> {
println!("retrieved the files in {:?}", start.elapsed());
// Initialize the tokenizer; batch encoding pads to the `left` for 1.5b and to the `right` for 400m
- let tokenizer = create_tokenizer(tokenizer_filename.as_path())?;
+ let tokenizer = create_tokenizer(tokenizer_filename.as_path(), args.which)?;
let start = std::time::Instant::now();
@@ -343,11 +375,7 @@ fn main() -> Result<()> {
let embed_vb =
unsafe { VarBuilder::from_mmaped_safetensors(&embed_weight_files, DType::F32, &device)? };
- let model = EmbeddingModel::new(
- &Config::new_1_5_b_v5(embed_dim.embed_dim()),
- base_vb,
- embed_vb,
- )?;
+ let model = EmbeddingModel::new(&cfg, base_vb, embed_vb)?;
println!("loaded the model in {:?}", start.elapsed());