author     Laurent Mazare <laurent.mazare@gmail.com>  2023-07-31 21:31:38 +0100
committer  GitHub <noreply@github.com>  2023-07-31 21:31:38 +0100
commit     f28558d0b7fa335e2772bd5b32f72b2e4e4c0ab1 (patch)
tree       10b5fec49ce0c6b38fa5c725d34aa18ee8c4145d
parent     6b98b66eb36a484f1a65fbc1c528a8e0b90a1419 (diff)
Evaluate on the pre-tokenized file. (#290)
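This adds an --eval-file flag to the llama2-c example: when the evaluation task is selected and the flag is set, the cross-entropy loss is computed directly over a file of pre-tokenized data instead of tokenizing a text corpus first. A plausible invocation, assuming the example's existing task-selection syntax (only --eval-file comes from this patch; the rest of the command line is illustrative):

    cargo run --release --example llama2-c -- evaluation --eval-file tokens.bin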
-rw-r--r--  candle-examples/examples/llama2-c/main.rs | 59
1 file changed, 58 insertions(+), 1 deletion(-)
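The new run_eval_file below reads the eval file as a flat sequence of little-endian u16 token ids via byteorder's read_u16_into. A minimal sketch of the matching writer side under that assumption (write_tokens and its arguments are illustrative, not part of this patch):

    use byteorder::{LittleEndian, WriteBytesExt};
    use std::io::Write;

    // Write token ids as little-endian u16 values, the format that
    // run_eval_file reads back with read_u16_into::<LittleEndian>.
    fn write_tokens(tokens: &[u32], path: &str) -> std::io::Result<()> {
        let mut out = std::io::BufWriter::new(std::fs::File::create(path)?);
        for &t in tokens {
            // Ids must fit in u16; the llama2.c vocabulary (32000 tokens) does.
            out.write_u16::<LittleEndian>(t as u16)?;
        }
        out.flush()
    }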
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index d710652f..b627bd3d 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -214,6 +214,9 @@ struct Args {
 
     #[arg(long, default_value = "")]
     prompt: String,
+
+    #[arg(long)]
+    eval_file: Option<String>,
 }
 
 fn main() -> anyhow::Result<()> {
@@ -240,12 +243,66 @@ fn main() -> anyhow::Result<()> {
 
     match args.task {
         Task::Inference => run_inference(tokenizer, &config_path, args)?,
-        Task::Evaluation => run_eval(tokenizer, &config_path, args)?,
+        Task::Evaluation => {
+            if let Some(eval_file) = &args.eval_file {
+                run_eval_file(eval_file.into(), &config_path, args)?
+            } else {
+                run_eval(tokenizer, &config_path, args)?
+            }
+        }
         Task::Training => todo!(),
     }
     Ok(())
 }
 
+fn run_eval_file(
+    path: std::path::PathBuf,
+    config_path: &std::path::PathBuf,
+    args: Args,
+) -> Result<()> {
+    use std::io::BufRead;
+
+    let device = candle_examples::device(args.cpu)?;
+    let mut file = std::fs::File::open(config_path)?;
+    let config = Config::from_reader(&mut file)?;
+    let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
+    let vb = weights.var_builder(&config, &device)?;
+    let cache = model::Cache::new(false, &config, vb.pp("rot"))?;
+    let model = Llama::load(vb, &cache, config)?;
+
+    let bytes = std::fs::read(path)?;
+    // Tokens are encoded as u16.
+    let mut tokens = vec![0u16; bytes.len() / 2];
+    std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens)?;
+    let tokens: Vec<u32> = tokens.into_iter().map(|u| u as u32).collect();
+    println!("dataset loaded: {} tokens", tokens.len());
+
+    let seq_len = model.config.seq_len;
+    let batch_size = 32;
+    let mut inputs = vec![];
+    let mut targets = vec![];
+    for start_idx in (0..tokens.len()).step_by(seq_len) {
+        if start_idx + seq_len + 1 > tokens.len() {
+            break;
+        }
+        let tokens = &tokens[start_idx..start_idx + seq_len + 1];
+        let inputs_ = Tensor::new(&tokens[..seq_len], &device)?;
+        let targets_ = Tensor::new(&tokens[1..], &device)?;
+        inputs.push(inputs_);
+        targets.push(targets_);
+        if inputs.len() >= batch_size {
+            let inp = Tensor::stack(&inputs, 0)?;
+            let tgt = Tensor::stack(&targets, 0)?;
+            let logits = model.forward(&inp, 0)?;
+            let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
+            println!("{}", loss.to_vec0::<f32>()?);
+            inputs.clear();
+            targets.clear();
+        }
+    }
+    Ok(())
+}
+
 fn run_eval(tokenizer: Tokenizer, config_path: &std::path::PathBuf, args: Args) -> Result<()> {
     use std::io::BufRead;