Diffstat (limited to 'candle-examples/examples/phi/main.rs')
-rw-r--r-- | candle-examples/examples/phi/main.rs | 117 |
1 file changed, 104 insertions, 13 deletions
diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs
index 52d453b5..3574b1f2 100644
--- a/candle-examples/examples/phi/main.rs
+++ b/candle-examples/examples/phi/main.rs
@@ -145,7 +145,10 @@ struct Args {
     verbose_prompt: bool,

     #[arg(long)]
-    prompt: String,
+    prompt: Option<String>,
+
+    #[arg(long)]
+    mmlu_dir: Option<String>,

     /// The temperature used to generate samples.
     #[arg(long)]
@@ -314,17 +317,105 @@ fn main() -> Result<()> {
     };
     println!("loaded the model in {:?}", start.elapsed());

-    let mut pipeline = TextGeneration::new(
-        model,
-        tokenizer,
-        args.seed,
-        args.temperature,
-        args.top_p,
-        args.repeat_penalty,
-        args.repeat_last_n,
-        args.verbose_prompt,
-        &device,
-    );
-    pipeline.run(&args.prompt, args.sample_len)?;
+    match (args.prompt, args.mmlu_dir) {
+        (None, None) | (Some(_), Some(_)) => {
+            anyhow::bail!("exactly one of --prompt and --mmlu-dir must be specified")
+        }
+        (Some(prompt), None) => {
+            let mut pipeline = TextGeneration::new(
+                model,
+                tokenizer,
+                args.seed,
+                args.temperature,
+                args.top_p,
+                args.repeat_penalty,
+                args.repeat_last_n,
+                args.verbose_prompt,
+                &device,
+            );
+            pipeline.run(&prompt, args.sample_len)?;
+        }
+        (None, Some(mmlu_dir)) => mmlu(model, tokenizer, &device, mmlu_dir)?,
+    }
+    Ok(())
+}
+
+fn mmlu<P: AsRef<std::path::Path>>(
+    mut model: Model,
+    tokenizer: Tokenizer,
+    device: &Device,
+    mmlu_dir: P,
+) -> anyhow::Result<()> {
+    for dir_entry in mmlu_dir.as_ref().read_dir()?.flatten() {
+        let dir_entry = dir_entry.path();
+        let theme = match dir_entry.file_stem().and_then(|v| v.to_str()) {
+            None => "".to_string(),
+            Some(v) => match v.strip_suffix("_test") {
+                None => v.replace('_', " "),
+                Some(v) => v.replace('_', " "),
+            },
+        };
+        if dir_entry.extension().as_ref().and_then(|v| v.to_str()) != Some("csv") {
+            continue;
+        }
+        println!("reading {dir_entry:?}");
+        let dir_entry = std::fs::File::open(dir_entry)?;
+        let mut reader = csv::ReaderBuilder::new()
+            .has_headers(false)
+            .from_reader(dir_entry);
+        let token_a = tokenizer.token_to_id("A").unwrap();
+        let token_b = tokenizer.token_to_id("B").unwrap();
+        let token_c = tokenizer.token_to_id("C").unwrap();
+        let token_d = tokenizer.token_to_id("D").unwrap();
+        for row in reader.records() {
+            let row = match row {
+                Err(_) => continue,
+                Ok(row) => row,
+            };
+            if row.len() < 6 {
+                continue;
+            }
+            let question = row.get(0).unwrap();
+            let answer_a = row.get(1).unwrap();
+            let answer_b = row.get(2).unwrap();
+            let answer_c = row.get(3).unwrap();
+            let answer_d = row.get(4).unwrap();
+            let answer = row.get(5).unwrap();
+            let prompt = format!(
+                "{} {theme}.\n{question}\nA. {answer_a}\nB. {answer_b}\nC. {answer_c}\nD. {answer_d}\nAnswer:\n",
+                "The following are multiple choice questions (with answers) about"
+            );
+            let tokens = tokenizer.encode(prompt.as_str(), true).map_err(E::msg)?;
+            let tokens = tokens.get_ids().to_vec();
+            let input = Tensor::new(tokens, device)?.unsqueeze(0)?;
+            let logits = match &mut model {
+                Model::MixFormer(m) => {
+                    m.clear_kv_cache();
+                    m.forward(&input)?
+                }
+                Model::Quantized(m) => {
+                    m.clear_kv_cache();
+                    m.forward(&input)?
+                }
+            };
+            let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
+            let logits_v: Vec<f32> = logits.to_vec1()?;
+            let pr_a = logits_v[token_a as usize];
+            let pr_b = logits_v[token_b as usize];
+            let pr_c = logits_v[token_c as usize];
+            let pr_d = logits_v[token_d as usize];
+            let model_answer = if pr_a > pr_b && pr_a > pr_c && pr_a > pr_d {
+                "A"
+            } else if pr_b > pr_c && pr_b > pr_d {
+                "B"
+            } else if pr_c > pr_d {
+                "C"
+            } else {
+                "D"
+            };
+
+            println!("{prompt}\n -> {model_answer} vs {answer}");
+        }
+    }
     Ok(())
 }
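The chained if/else if at the end of the new mmlu function is an argmax over the logits of the four answer-letter tokens at the last position of the prompt. A minimal, candle-free sketch of just that selection step, with hypothetical logit values (ties resolve to the later letter, as the strict comparisons in the patch do):

fn pick_answer(logits: [f32; 4]) -> &'static str {
    // Pick the letter whose final-position logit is largest. Using >=
    // lets ties fall through to the later letter, matching the chained
    // strict > comparisons in the patch above.
    let letters = ["A", "B", "C", "D"];
    let mut best = 0;
    for i in 1..4 {
        if logits[i] >= logits[best] {
            best = i;
        }
    }
    letters[best]
}

fn main() {
    // Hypothetical values, for illustration only.
    assert_eq!(pick_answer([1.0, 3.5, 2.0, -0.5]), "B");
    assert_eq!(pick_answer([2.0, 2.0, 1.0, 0.0]), "B"); // tie: later letter wins
    println!("ok");
}

With this change the example is presumably invoked as cargo run --example phi --release -- --mmlu-dir <dir>, where <dir> holds the MMLU *_test.csv files (six columns per row: question, four answer texts, answer letter); obtaining those CSVs is outside the scope of this patch.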