-rw-r--r--   candle-examples/Cargo.toml               1
-rw-r--r--   candle-examples/examples/phi/main.rs     117
2 files changed, 105 insertions, 13 deletions
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index 0c4bf20e..8ae828bd 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -28,6 +28,7 @@ safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
+csv = "1.3.0"
[dev-dependencies]
anyhow = { workspace = true }
diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs
index 52d453b5..3574b1f2 100644
--- a/candle-examples/examples/phi/main.rs
+++ b/candle-examples/examples/phi/main.rs
@@ -145,7 +145,10 @@ struct Args {
verbose_prompt: bool,

#[arg(long)]
- prompt: String,
+ prompt: Option<String>,
+
+ #[arg(long)]
+ mmlu_dir: Option<String>,

/// The temperature used to generate samples.
#[arg(long)]
@@ -314,17 +317,105 @@ fn main() -> Result<()> {
};
println!("loaded the model in {:?}", start.elapsed());
- let mut pipeline = TextGeneration::new(
- model,
- tokenizer,
- args.seed,
- args.temperature,
- args.top_p,
- args.repeat_penalty,
- args.repeat_last_n,
- args.verbose_prompt,
- &device,
- );
- pipeline.run(&args.prompt, args.sample_len)?;
+ match (args.prompt, args.mmlu_dir) {
+ (None, None) | (Some(_), Some(_)) => {
+ anyhow::bail!("exactly one of --prompt and --mmlu-dir must be specified")
+ }
+ (Some(prompt), None) => {
+ let mut pipeline = TextGeneration::new(
+ model,
+ tokenizer,
+ args.seed,
+ args.temperature,
+ args.top_p,
+ args.repeat_penalty,
+ args.repeat_last_n,
+ args.verbose_prompt,
+ &device,
+ );
+ pipeline.run(&prompt, args.sample_len)?;
+ }
+ (None, Some(mmlu_dir)) => mmlu(model, tokenizer, &device, mmlu_dir)?,
+ }
+ Ok(())
+}
+
+fn mmlu<P: AsRef<std::path::Path>>(
+ mut model: Model,
+ tokenizer: Tokenizer,
+ device: &Device,
+ mmlu_dir: P,
+) -> anyhow::Result<()> {
+ for dir_entry in mmlu_dir.as_ref().read_dir()?.flatten() {
+ let dir_entry = dir_entry.path();
+ let theme = match dir_entry.file_stem().and_then(|v| v.to_str()) {
+ None => "".to_string(),
+ Some(v) => match v.strip_suffix("_test") {
+ None => v.replace('_', " "),
+ Some(v) => v.replace('_', " "),
+ },
+ };
+ if dir_entry.extension().as_ref().and_then(|v| v.to_str()) != Some("csv") {
+ continue;
+ }
+ println!("reading {dir_entry:?}");
+ let dir_entry = std::fs::File::open(dir_entry)?;
+ let mut reader = csv::ReaderBuilder::new()
+ .has_headers(false)
+ .from_reader(dir_entry);
+ let token_a = tokenizer.token_to_id("A").unwrap();
+ let token_b = tokenizer.token_to_id("B").unwrap();
+ let token_c = tokenizer.token_to_id("C").unwrap();
+ let token_d = tokenizer.token_to_id("D").unwrap();
+ for row in reader.records() {
+ let row = match row {
+ Err(_) => continue,
+ Ok(row) => row,
+ };
+ if row.len() < 5 {
+ continue;
+ }
+ let question = row.get(0).unwrap();
+ let answer_a = row.get(1).unwrap();
+ let answer_b = row.get(2).unwrap();
+ let answer_c = row.get(3).unwrap();
+ let answer_d = row.get(4).unwrap();
+ let answer = row.get(5).unwrap();
+ let prompt = format!(
+ "{} {theme}.\n{question}\nA. {answer_a}\nB. {answer_b}\nC. {answer_c}\nD. {answer_d}\nAnswer:\n",
+ "The following are multiple choice questions (with answers) about"
+ );
+ let tokens = tokenizer.encode(prompt.as_str(), true).map_err(E::msg)?;
+ let tokens = tokens.get_ids().to_vec();
+ let input = Tensor::new(tokens, device)?.unsqueeze(0)?;
+ let logits = match &mut model {
+ Model::MixFormer(m) => {
+ m.clear_kv_cache();
+ m.forward(&input)?
+ }
+ Model::Quantized(m) => {
+ m.clear_kv_cache();
+ m.forward(&input)?
+ }
+ };
+ let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
+ let logits_v: Vec<f32> = logits.to_vec1()?;
+ let pr_a = logits_v[token_a as usize];
+ let pr_b = logits_v[token_b as usize];
+ let pr_c = logits_v[token_c as usize];
+ let pr_d = logits_v[token_d as usize];
+ let model_answer = if pr_a > pr_b && pr_a > pr_c && pr_a > pr_d {
+ "A"
+ } else if pr_b > pr_c && pr_b > pr_d {
+ "B"
+ } else if pr_c > pr_d {
+ "C"
+ } else {
+ "D"
+ };
+
+ println!("{prompt}\n -> {model_answer} vs {answer}");
+ }
+ }
Ok(())
}
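
With this change the example runs in one of two modes: the existing single-prompt generation (--prompt) or the new MMLU evaluation (--mmlu-dir); passing both flags, or neither, bails with an error. A plausible invocation for the new path, assuming the usual cargo example workflow and a local directory holding the MMLU *_test.csv files (the path below is only illustrative):

    cargo run --example phi --release -- --mmlu-dir data/mmlu/test

For each CSV row the example builds a multiple-choice prompt, runs a single forward pass with a cleared KV cache, and compares the logits of the "A"/"B"/"C"/"D" tokens against the reference answer.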