diff options
author | Jani Monoses <jani.monoses@gmail.com> | 2024-05-23 14:33:17 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-23 13:33:17 +0200 |
commit | 77ea479a1847d909ca5e4f27a36f5c8e302cd529 (patch) | |
tree | 81ef65034e5429508157f5f9ad3fbfd8bb698f39 /candle-examples | |
parent | 72e7ca529a3c243bef844f822a9668eaf8e36807 (diff) | |
download | candle-77ea479a1847d909ca5e4f27a36f5c8e302cd529.tar.gz candle-77ea479a1847d909ca5e4f27a36f5c8e302cd529.tar.bz2 candle-77ea479a1847d909ca5e4f27a36f5c8e302cd529.zip |
Add Phi-3 Medium (#2205)
Diffstat (limited to 'candle-examples')
-rw-r--r-- | candle-examples/examples/phi/main.rs | 19 |
1 files changed, 13 insertions, 6 deletions
diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs index 371b389f..1cfeb443 100644 --- a/candle-examples/examples/phi/main.rs +++ b/candle-examples/examples/phi/main.rs @@ -141,6 +141,8 @@ enum WhichModel { V2, #[value(name = "3")] V3, + #[value(name = "3-medium")] + V3Medium, #[value(name = "2-old")] V2Old, PuffinPhiV2, @@ -254,6 +256,7 @@ fn main() -> Result<()> { WhichModel::V1_5 => "microsoft/phi-1_5".to_string(), WhichModel::V2 | WhichModel::V2Old => "microsoft/phi-2".to_string(), WhichModel::V3 => "microsoft/Phi-3-mini-4k-instruct".to_string(), + WhichModel::V3Medium => "microsoft/Phi-3-medium-4k-instruct".to_string(), WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => { "lmz/candle-quantized-phi".to_string() } @@ -273,6 +276,7 @@ fn main() -> Result<()> { WhichModel::V2Old => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(), WhichModel::V2 | WhichModel::V3 + | WhichModel::V3Medium | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => "main".to_string(), } @@ -287,7 +291,8 @@ fn main() -> Result<()> { | WhichModel::V1_5 | WhichModel::V2 | WhichModel::V2Old - | WhichModel::V3 => repo.get("tokenizer.json")?, + | WhichModel::V3 + | WhichModel::V3Medium => repo.get("tokenizer.json")?, WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => { repo.get("tokenizer-puffin-phi-v2.json")? } @@ -303,14 +308,14 @@ fn main() -> Result<()> { WhichModel::V2 | WhichModel::V2Old => vec![repo.get("model-v2-q4k.gguf")?], WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?], WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?], - WhichModel::V3 => anyhow::bail!( + WhichModel::V3 | WhichModel::V3Medium => anyhow::bail!( "use the quantized or quantized-phi examples for quantized phi-v3" ), } } else { match args.model { WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?], - WhichModel::V2 | WhichModel::V2Old | WhichModel::V3 => { + WhichModel::V2 | WhichModel::V2Old | WhichModel::V3 | WhichModel::V3Medium => { candle_examples::hub_load_safetensors( &repo, "model.safetensors.index.json", @@ -332,7 +337,7 @@ fn main() -> Result<()> { WhichModel::V2 | WhichModel::V2Old => Config::v2(), WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(), WhichModel::PhiHermes => Config::phi_hermes_1_3b(), - WhichModel::V3 => { + WhichModel::V3 | WhichModel::V3Medium => { panic!("use the quantized or quantized-phi examples for quantized phi-v3") } }; @@ -352,7 +357,9 @@ fn main() -> Result<()> { let dtype = match args.dtype { Some(dtype) => std::str::FromStr::from_str(&dtype)?, None => { - if args.model == WhichModel::V3 && device.is_cuda() { + if (args.model == WhichModel::V3 || args.model == WhichModel::V3Medium) + && device.is_cuda() + { DType::BF16 } else { DType::F32 @@ -368,7 +375,7 @@ fn main() -> Result<()> { let phi = Phi::new(&config, vb)?; Model::Phi(phi) } - WhichModel::V3 => { + WhichModel::V3 | WhichModel::V3Medium => { let config_filename = repo.get("config.json")?; let config = std::fs::read_to_string(config_filename)?; let config: Phi3Config = serde_json::from_str(&config)?; |