 candle-examples/examples/llama2-c/main.rs | 81
 1 file changed, 51 insertions(+), 30 deletions(-)
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index b627bd3d..ac17aab1 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -215,8 +215,13 @@ struct Args {
#[arg(long, default_value = "")]
prompt: String,
+ /// A directory with the pre-tokenized dataset in the format generated by the tinystories.py
+ /// script from llama2.c https://github.com/karpathy/llama2.c
#[arg(long)]
- eval_file: Option<String>,
+ pretokenized_dir: Option<String>,
+
+ #[arg(long, default_value_t = 32)]
+ batch_size: usize,
}
fn main() -> anyhow::Result<()> {
@@ -243,13 +248,7 @@ fn main() -> anyhow::Result<()> {
match args.task {
Task::Inference => run_inference(tokenizer, &config_path, args)?,
- Task::Evaluation => {
- if let Some(eval_file) = &args.eval_file {
- run_eval_file(eval_file.into(), &config_path, args)?
- } else {
- run_eval(tokenizer, &config_path, args)?
- }
- }
+ Task::Evaluation => run_eval(tokenizer, &config_path, args)?,
Task::Training => todo!(),
}
Ok(())
@@ -278,7 +277,6 @@ fn run_eval_file(
println!("dataset loaded: {} tokens", tokens.len());
let seq_len = model.config.seq_len;
- let batch_size = 32;
let mut inputs = vec![];
let mut targets = vec![];
for start_idx in (0..tokens.len()).step_by(seq_len) {
@@ -290,7 +288,7 @@ fn run_eval_file(
let targets_ = Tensor::new(&tokens[1..], &device)?;
inputs.push(inputs_);
targets.push(targets_);
- if inputs.len() >= batch_size {
+ if inputs.len() >= args.batch_size {
let inp = Tensor::stack(&inputs, 0)?;
let tgt = Tensor::stack(&targets, 0)?;
let logits = model.forward(&inp, 0)?;
@@ -314,32 +312,55 @@ fn run_eval(tokenizer: Tokenizer, config_path: &std::path::PathBuf, args: Args)
let cache = model::Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, config)?;
- let api = hf_hub::api::sync::Api::new()?;
- let model_id = "roneneldan/TinyStories"; // TODO: Make this configurable.
- println!("loading the evaluation dataset from {}", model_id);
- let api = api.dataset(model_id.to_string());
- let dataset_path = api.get("TinyStories-valid.txt")?;
- let file = std::fs::File::open(dataset_path)?;
- let file = std::io::BufReader::new(file);
- let mut tokens = vec![];
- for line in file.lines() {
- let line = line?.replace("<|endoftext|>", "");
- let line = tokenizer.encode(line, false).map_err(E::msg)?;
- tokens.push(line.get_ids().to_vec())
- }
- let tokens = tokens.concat();
+ let tokens = match args.pretokenized_dir {
+ None => {
+ let api = hf_hub::api::sync::Api::new()?;
+ let model_id = "roneneldan/TinyStories"; // TODO: Make this configurable.
+ println!("loading the evaluation dataset from {}", model_id);
+ let api = api.dataset(model_id.to_string());
+ let dataset_path = api.get("TinyStories-valid.txt")?;
+ let file = std::fs::File::open(dataset_path)?;
+ let file = std::io::BufReader::new(file);
+ let mut tokens = vec![];
+ for line in file.lines() {
+ let line = line?.replace("<|endoftext|>", "<s>");
+ let line = tokenizer.encode(line, false).map_err(E::msg)?;
+ tokens.push(line.get_ids().to_vec())
+ }
+ tokens.concat()
+ }
+ Some(pretokenized_dir) => {
+ let path = std::path::PathBuf::from(pretokenized_dir).join("data00.bin");
+ let bytes = std::fs::read(path)?;
+ // Tokens are encoded as u16.
+ let mut tokens = vec![0u16; bytes.len() / 2];
+ std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens)?;
+ tokens.into_iter().map(|u| u as u32).collect::<Vec<u32>>()
+ }
+ };
println!("dataset loaded and encoded: {} tokens", tokens.len());
- let seq_len = 256;
+
+ let seq_len = model.config.seq_len;
+ let mut inputs = vec![];
+ let mut targets = vec![];
for start_idx in (0..tokens.len()).step_by(seq_len) {
if start_idx + seq_len + 1 > tokens.len() {
break;
}
let tokens = &tokens[start_idx..start_idx + seq_len + 1];
- let inputs = Tensor::new(&tokens[..seq_len], &device)?.unsqueeze(0)?;
- let targets = Tensor::new(&tokens[1..], &device)?;
- let logits = model.forward(&inputs, 0)?.squeeze(0)?;
- let loss = candle_nn::loss::cross_entropy(&logits, &targets)?;
- println!("{start_idx} {}", loss.to_vec0::<f32>()?);
+ let inputs_ = Tensor::new(&tokens[..seq_len], &device)?;
+ let targets_ = Tensor::new(&tokens[1..], &device)?;
+ inputs.push(inputs_);
+ targets.push(targets_);
+ if inputs.len() >= args.batch_size {
+ let inp = Tensor::stack(&inputs, 0)?;
+ let tgt = Tensor::stack(&targets, 0)?;
+ let logits = model.forward(&inp, 0)?;
+ let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
+ println!("{}", loss.to_vec0::<f32>()?);
+ inputs.clear();
+ targets.clear();
+ }
}
Ok(())
}
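
The Some(pretokenized_dir) branch above reads a shard produced by llama2.c's tinystories.py, which stores token ids as little-endian u16 values in files such as data00.bin. A minimal standalone sketch of that decoding step, assuming the byteorder crate is available; read_pretokenized is a hypothetical helper name, not part of the commit:

use byteorder::{LittleEndian, ReadBytesExt};

fn read_pretokenized(path: &std::path::Path) -> std::io::Result<Vec<u32>> {
    let bytes = std::fs::read(path)?;
    // Each token id is stored as a little-endian u16, i.e. two bytes per token.
    let mut tokens = vec![0u16; bytes.len() / 2];
    std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens)?;
    // Widen to u32 so the ids can be passed to Tensor::new, as run_eval does.
    Ok(tokens.into_iter().map(u32::from).collect())
}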
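The reworked evaluation loop stacks batch_size sequences and computes a single cross-entropy over the flattened batch instead of a per-sequence loss. A small shape-check sketch of that flatten_to(1) + cross_entropy pattern, assuming the candle_core and candle_nn crates, with random stand-in logits and all-zero targets; the helper name and dimensions are illustrative only:

use candle_core::{DType, Device, Tensor};

fn batched_loss_shapes() -> candle_core::Result<()> {
    let device = Device::Cpu;
    let (batch_size, seq_len, vocab_size) = (4, 8, 16);
    // Stand-in for model.forward(&inp, 0): logits with shape (batch, seq, vocab).
    let logits = Tensor::randn(0f32, 1f32, (batch_size, seq_len, vocab_size), &device)?;
    // Stand-in for Tensor::stack(&targets, 0): token ids with shape (batch, seq).
    let tgt = Tensor::zeros((batch_size, seq_len), DType::U32, &device)?;
    // flatten_to(1) merges the batch and sequence dims, giving (batch * seq, vocab)
    // logits and (batch * seq,) targets, the layout cross_entropy expects.
    let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
    println!("loss: {}", loss.to_vec0::<f32>()?);
    Ok(())
}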