Diffstat (limited to 'candle-examples/examples/llama2-c/main.rs')
-rw-r--r--  candle-examples/examples/llama2-c/main.rs  62
1 file changed, 37 insertions(+), 25 deletions(-)
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index 9b6d1316..20a6267c 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -8,7 +8,7 @@ extern crate intel_mkl_src;
mod model;
use clap::Parser;
-use anyhow::Result;
+use anyhow::{Error as E, Result};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use candle::{DType, Device, Error, IndexOp, Layout, Shape, Tensor};
use candle_nn::{Embedding, Linear, VarBuilder};
@@ -181,38 +181,35 @@ struct Args {
/// Config file in binary format.
#[arg(long)]
- config: String,
+ config: Option<String>,
- /// Tokenizer config file in binary format.
+ /// Tokenizer config file.
#[arg(long)]
- tokenizer: String,
+ tokenizer: Option<String>,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
-}
-
-struct Tokenizer {
- tokens: Vec<String>,
-}
-impl Tokenizer {
- fn from_reader<R: std::io::Read>(r: &mut R, c: &Config) -> Result<Self> {
- let mut tokens = Vec::with_capacity(c.vocab_size);
- for _token_index in 0..c.vocab_size {
- let token_len = read_i32(r)?;
- let mut token = vec![0u8; token_len as usize];
- r.read_exact(&mut token);
- tokens.push(String::from_utf8_lossy(&token).into_owned())
- }
- Ok(Self { tokens })
- }
+ #[arg(long, default_value = "karpathy/tinyllamas")]
+ model_id: String,
}
fn main() -> anyhow::Result<()> {
+ use tokenizers::Tokenizer;
+
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
- let mut file = std::fs::File::open(&args.config)?;
+ let config_path = match &args.config {
+ Some(config) => std::path::PathBuf::from(config),
+ None => {
+ let api = hf_hub::api::sync::Api::new()?;
+ println!("loading the model weights from {}", args.model_id);
+ let api = api.model(args.model_id);
+ api.get("stories15M.bin")?
+ }
+ };
+ let mut file = std::fs::File::open(&config_path)?;
let config = Config::from_reader(&mut file)?;
println!("config: {config:?}");
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
@@ -220,8 +217,16 @@ fn main() -> anyhow::Result<()> {
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, &config)?;
- let mut file = std::fs::File::open(&args.tokenizer)?;
- let tokenizer = Tokenizer::from_reader(&mut file, &config)?;
+ let tokenizer_path = match &args.tokenizer {
+ Some(config) => std::path::PathBuf::from(config),
+ None => {
+ let api = hf_hub::api::sync::Api::new()?;
+ let api = api.model("hf-internal-testing/llama-tokenizer".to_string());
+ api.get("tokenizer.json")?
+ }
+ };
+ println!("{tokenizer_path:?}");
+ let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?;
println!("starting the inference loop");
let mut logits_processor = LogitsProcessor::new(299792458, args.temperature);
@@ -244,8 +249,15 @@ fn main() -> anyhow::Result<()> {
let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
- print!("{}", tokenizer.tokens[next_token as usize]);
- std::io::stdout().flush()?;
+ // Extracting the last token as a string is complicated, here we just apply some simple
+ // heuristics as it seems to work well enough for this example. See the following for more
+ // details:
+ // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
+ if let Some(text) = tokenizer.id_to_token(next_token) {
+ let text = text.replace('▁', " ").replace("<0x0A>", "\n");
+ print!("{text}");
+ std::io::stdout().flush()?;
+ }
}
let dt = start_gen.elapsed();
println!(
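
For context, the two patterns this patch introduces — resolving files from the Hugging Face Hub with hf-hub, and decoding ids with the tokenizers crate — can be exercised on their own. Below is a minimal standalone sketch, assuming the anyhow, hf-hub, and tokenizers crates are available as in the example's Cargo.toml; it uses the same repo ids and file names hard-coded in the patch, and the token id passed to id_to_token is an arbitrary value picked purely for illustration:

use tokenizers::Tokenizer;

fn fetch_and_decode() -> anyhow::Result<()> {
    // Resolve (and locally cache) the files from the Hub, as the patch does
    // when --config / --tokenizer are not given on the command line.
    let api = hf_hub::api::sync::Api::new()?;
    let weights = api
        .model("karpathy/tinyllamas".to_string())
        .get("stories15M.bin")?;
    let tokenizer_path = api
        .model("hf-internal-testing/llama-tokenizer".to_string())
        .get("tokenizer.json")?;
    println!("weights cached at {weights:?}");

    // Tokenizer::from_file returns a boxed error type that does not convert
    // into anyhow::Error via `?`, hence map_err, mirroring the patch.
    let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?;

    // Same display heuristics as the patched inference loop: '▁' is the
    // SentencePiece word-boundary marker and <0x0A> is the byte-fallback
    // newline token.
    let token_id = 450; // arbitrary id, for illustration only
    if let Some(text) = tokenizer.id_to_token(token_id) {
        print!("{}", text.replace('▁', " ").replace("<0x0A>", "\n"));
    }
    Ok(())
}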