author    Jani Monoses <jani.monoses@gmail.com>  2024-05-23 14:33:17 +0300
committer GitHub <noreply@github.com>            2024-05-23 13:33:17 +0200
commit    77ea479a1847d909ca5e4f27a36f5c8e302cd529 (patch)
tree      81ef65034e5429508157f5f9ad3fbfd8bb698f39 /candle-examples
parent    72e7ca529a3c243bef844f822a9668eaf8e36807 (diff)
Add Phi-3 Medium (#2205)
Diffstat (limited to 'candle-examples')
-rw-r--r--  candle-examples/examples/phi/main.rs | 19
1 file changed, 13 insertions(+), 6 deletions(-)
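
This commit adds a V3Medium variant to the example's WhichModel enum and threads it through the same match arms as V3: model id, revision, tokenizer, weight files, dtype selection, and model construction. Below is a minimal, self-contained sketch of the clap ValueEnum pattern the diff extends; it assumes clap 4 with the derive feature, and the model_id helper is illustrative only (the real example inlines this match inside main).

// Minimal sketch of the pattern this commit extends
// (assumption: clap 4 with the `derive` feature enabled).
use clap::ValueEnum;

#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
enum WhichModel {
    // The CLI value names differ from the Rust identifiers:
    // these variants are selected as `--model 3` and `--model 3-medium`.
    #[value(name = "3")]
    V3,
    #[value(name = "3-medium")]
    V3Medium,
}

// Illustrative helper: map the selected variant to its Hugging Face model id,
// mirroring the match arms added in this diff.
fn model_id(which: WhichModel) -> String {
    match which {
        WhichModel::V3 => "microsoft/Phi-3-mini-4k-instruct".to_string(),
        WhichModel::V3Medium => "microsoft/Phi-3-medium-4k-instruct".to_string(),
    }
}

fn main() {
    for which in [WhichModel::V3, WhichModel::V3Medium] {
        println!("{which:?} -> {}", model_id(which));
    }
}

In the candle example itself, the new model should then be selectable with --model 3-medium; as the hunks below show, V3 and V3Medium share the BF16-on-CUDA dtype default and the Phi3 (non-quantized) loading path.
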
diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs
index 371b389f..1cfeb443 100644
--- a/candle-examples/examples/phi/main.rs
+++ b/candle-examples/examples/phi/main.rs
@@ -141,6 +141,8 @@ enum WhichModel {
V2,
#[value(name = "3")]
V3,
+ #[value(name = "3-medium")]
+ V3Medium,
#[value(name = "2-old")]
V2Old,
PuffinPhiV2,
@@ -254,6 +256,7 @@ fn main() -> Result<()> {
WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
WhichModel::V2 | WhichModel::V2Old => "microsoft/phi-2".to_string(),
WhichModel::V3 => "microsoft/Phi-3-mini-4k-instruct".to_string(),
+ WhichModel::V3Medium => "microsoft/Phi-3-medium-4k-instruct".to_string(),
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
"lmz/candle-quantized-phi".to_string()
}
@@ -273,6 +276,7 @@ fn main() -> Result<()> {
WhichModel::V2Old => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(),
WhichModel::V2
| WhichModel::V3
+ | WhichModel::V3Medium
| WhichModel::PuffinPhiV2
| WhichModel::PhiHermes => "main".to_string(),
}
@@ -287,7 +291,8 @@ fn main() -> Result<()> {
| WhichModel::V1_5
| WhichModel::V2
| WhichModel::V2Old
- | WhichModel::V3 => repo.get("tokenizer.json")?,
+ | WhichModel::V3
+ | WhichModel::V3Medium => repo.get("tokenizer.json")?,
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
repo.get("tokenizer-puffin-phi-v2.json")?
}
@@ -303,14 +308,14 @@ fn main() -> Result<()> {
WhichModel::V2 | WhichModel::V2Old => vec![repo.get("model-v2-q4k.gguf")?],
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
- WhichModel::V3 => anyhow::bail!(
+ WhichModel::V3 | WhichModel::V3Medium => anyhow::bail!(
"use the quantized or quantized-phi examples for quantized phi-v3"
),
}
} else {
match args.model {
WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
- WhichModel::V2 | WhichModel::V2Old | WhichModel::V3 => {
+ WhichModel::V2 | WhichModel::V2Old | WhichModel::V3 | WhichModel::V3Medium => {
candle_examples::hub_load_safetensors(
&repo,
"model.safetensors.index.json",
@@ -332,7 +337,7 @@ fn main() -> Result<()> {
WhichModel::V2 | WhichModel::V2Old => Config::v2(),
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
- WhichModel::V3 => {
+ WhichModel::V3 | WhichModel::V3Medium => {
panic!("use the quantized or quantized-phi examples for quantized phi-v3")
}
};
@@ -352,7 +357,9 @@ fn main() -> Result<()> {
let dtype = match args.dtype {
Some(dtype) => std::str::FromStr::from_str(&dtype)?,
None => {
- if args.model == WhichModel::V3 && device.is_cuda() {
+ if (args.model == WhichModel::V3 || args.model == WhichModel::V3Medium)
+ && device.is_cuda()
+ {
DType::BF16
} else {
DType::F32
@@ -368,7 +375,7 @@ fn main() -> Result<()> {
let phi = Phi::new(&config, vb)?;
Model::Phi(phi)
}
- WhichModel::V3 => {
+ WhichModel::V3 | WhichModel::V3Medium => {
let config_filename = repo.get("config.json")?;
let config = std::fs::read_to_string(config_filename)?;
let config: Phi3Config = serde_json::from_str(&config)?;