summaryrefslogtreecommitdiff
path: root/candle-examples/examples/phi/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/phi/main.rs')
-rw-r--r--candle-examples/examples/phi/main.rs49
1 files changed, 33 insertions, 16 deletions
diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs
index 720a4441..52d453b5 100644
--- a/candle-examples/examples/phi/main.rs
+++ b/candle-examples/examples/phi/main.rs
@@ -123,6 +123,8 @@ enum WhichModel {
V1,
#[value(name = "1.5")]
V1_5,
+ #[value(name = "2")]
+ V2,
PuffinPhiV2,
PhiHermes,
}
@@ -158,7 +160,7 @@ struct Args {
seed: u64,
/// The length of the sample to generate (in tokens).
- #[arg(long, short = 'n', default_value_t = 100)]
+ #[arg(long, short = 'n', default_value_t = 5000)]
sample_len: usize,
#[arg(long)]
@@ -225,6 +227,7 @@ fn main() -> Result<()> {
match args.model {
WhichModel::V1 => "microsoft/phi-1".to_string(),
WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
+ WhichModel::V2 => "microsoft/phi-2".to_string(),
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
"lmz/candle-quantized-phi".to_string()
}
@@ -241,7 +244,9 @@ fn main() -> Result<()> {
match args.model {
WhichModel::V1 => "refs/pr/2".to_string(),
WhichModel::V1_5 => "refs/pr/18".to_string(),
- WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => "main".to_string(),
+ WhichModel::V2 | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
+ "main".to_string()
+ }
}
}
}
@@ -250,27 +255,32 @@ fn main() -> Result<()> {
let tokenizer_filename = match args.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => match args.model {
- WhichModel::V1 | WhichModel::V1_5 => repo.get("tokenizer.json")?,
+ WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => repo.get("tokenizer.json")?,
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
repo.get("tokenizer-puffin-phi-v2.json")?
}
},
};
- let filename = match args.weight_file {
- Some(weight_file) => std::path::PathBuf::from(weight_file),
+ let filenames = match args.weight_file {
+ Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
None => {
if args.quantized {
match args.model {
- WhichModel::V1 => repo.get("model-v1-q4k.gguf")?,
- WhichModel::V1_5 => repo.get("model-q4k.gguf")?,
- WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2-q4k.gguf")?,
- WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B-q4k.gguf")?,
+ WhichModel::V1 => vec![repo.get("model-v1-q4k.gguf")?],
+ WhichModel::V1_5 => vec![repo.get("model-q4k.gguf")?],
+ WhichModel::V2 => vec![repo.get("model-v2-q4k.gguf")?],
+ WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
+ WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
}
} else {
match args.model {
- WhichModel::V1 | WhichModel::V1_5 => repo.get("model.safetensors")?,
- WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2.safetensors")?,
- WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B.safetensors")?,
+ WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
+ WhichModel::V2 => vec![
+ repo.get("model-00001-of-00002.safetensors")?,
+ repo.get("model-00002-of-00002.safetensors")?,
+ ],
+ WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2.safetensors")?],
+ WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B.safetensors")?],
}
}
}
@@ -282,17 +292,24 @@ fn main() -> Result<()> {
let config = match args.model {
WhichModel::V1 => Config::v1(),
WhichModel::V1_5 => Config::v1_5(),
+ WhichModel::V2 => Config::v2(),
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
};
let (model, device) = if args.quantized {
- let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
- let model = QMixFormer::new(&config, vb)?;
+ let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?;
+ let model = match args.model {
+ WhichModel::V2 => QMixFormer::new_v2(&config, vb)?,
+ _ => QMixFormer::new(&config, vb)?,
+ };
(Model::Quantized(model), Device::Cpu)
} else {
let device = candle_examples::device(args.cpu)?;
- let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
- let model = MixFormer::new(&config, vb)?;
+ let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
+ let model = match args.model {
+ WhichModel::V2 => MixFormer::new_v2(&config, vb)?,
+ _ => MixFormer::new(&config, vb)?,
+ };
(Model::MixFormer(model), device)
};
println!("loaded the model in {:?}", start.elapsed());