summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2023-10-27 05:57:08 +0100
committer: GitHub <noreply@github.com> 2023-10-27 05:57:08 +0100
commit: 70d06ab4b0065576e779a628fc024ef46003cdbc (patch)
tree: 763f43fad04b479397da8ee741a1981ebd303110
parent: 0ec5ebcec429fe2bb85a6a7f780509bb1831b024 (diff)
download: candle-70d06ab4b0065576e779a628fc024ef46003cdbc.tar.gz
candle-70d06ab4b0065576e779a628fc024ef46003cdbc.tar.bz2
candle-70d06ab4b0065576e779a628fc024ef46003cdbc.zip
Add support for the phi-hermes finetuned model. (#1192)
-rw-r--r-- candle-examples/examples/phi/main.rs 14
-rw-r--r-- candle-transformers/src/models/mixformer.rs 17
2 files changed, 28 insertions, 3 deletions
diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs
index 9401299a..720a4441 100644
--- a/candle-examples/examples/phi/main.rs
+++ b/candle-examples/examples/phi/main.rs
@@ -124,6 +124,7 @@ enum WhichModel {
#[value(name = "1.5")]
V1_5,
PuffinPhiV2,
+ PhiHermes,
}
#[derive(Parser, Debug)]
@@ -224,7 +225,9 @@ fn main() -> Result<()> {
match args.model {
WhichModel::V1 => "microsoft/phi-1".to_string(),
WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
- WhichModel::PuffinPhiV2 => "lmz/candle-quantized-phi".to_string(),
+ WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
+ "lmz/candle-quantized-phi".to_string()
+ }
}
}
}
@@ -238,7 +241,7 @@ fn main() -> Result<()> {
match args.model {
WhichModel::V1 => "refs/pr/2".to_string(),
WhichModel::V1_5 => "refs/pr/18".to_string(),
- WhichModel::PuffinPhiV2 => "main".to_string(),
+ WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => "main".to_string(),
}
}
}
@@ -248,7 +251,9 @@ fn main() -> Result<()> {
Some(file) => std::path::PathBuf::from(file),
None => match args.model {
WhichModel::V1 | WhichModel::V1_5 => repo.get("tokenizer.json")?,
- WhichModel::PuffinPhiV2 => repo.get("tokenizer-puffin-phi-v2.json")?,
+ WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
+ repo.get("tokenizer-puffin-phi-v2.json")?
+ }
},
};
let filename = match args.weight_file {
@@ -259,11 +264,13 @@ fn main() -> Result<()> {
WhichModel::V1 => repo.get("model-v1-q4k.gguf")?,
WhichModel::V1_5 => repo.get("model-q4k.gguf")?,
WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2-q4k.gguf")?,
+ WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B-q4k.gguf")?,
}
} else {
match args.model {
WhichModel::V1 | WhichModel::V1_5 => repo.get("model.safetensors")?,
WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2.safetensors")?,
+ WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B.safetensors")?,
}
}
}
@@ -276,6 +283,7 @@ fn main() -> Result<()> {
WhichModel::V1 => Config::v1(),
WhichModel::V1_5 => Config::v1_5(),
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
+ WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
};
let (model, device) = if args.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs
index 33aefbfe..e822ca14 100644
--- a/candle-transformers/src/models/mixformer.rs
+++ b/candle-transformers/src/models/mixformer.rs
@@ -73,6 +73,23 @@ impl Config {
pad_vocab_size_multiple: 64,
}
}
+
+ // Preset for the Phi-Hermes 1.3B model; upstream config: https://huggingface.co/teknium/Phi-Hermes-1.3B/blob/main/config.json
+ pub fn phi_hermes_1_3b() -> Self {
+ Self {
+ vocab_size: 50304,
+ n_positions: 2048,
+ n_embd: 2048,
+ n_layer: 24,
+ n_inner: None,
+ n_head: 32,
+ rotary_dim: usize::min(32, 2048 / 32), // evaluates to 32, since 2048 / 32 = 64
+ activation_function: Activation::NewGelu,
+ layer_norm_epsilon: 1e-5,
+ tie_word_embeddings: false,
+ pad_vocab_size_multiple: 64,
+ }
+ }
}
#[derive(Debug, Clone)]