summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJani Monoses <jani.monoses@gmail.com>2025-01-13 15:35:33 +0200
committerGitHub <noreply@github.com>2025-01-13 14:35:33 +0100
commitab7ff7081eab36958b82b98b89cee3eacf877111 (patch)
treea6768826a260a190bfe774fbbf954d6e85b1c5ae
parent461e8c1685e003bdddfd1e7d1aa5092786ca9df5 (diff)
downloadcandle-ab7ff7081eab36958b82b98b89cee3eacf877111.tar.gz
candle-ab7ff7081eab36958b82b98b89cee3eacf877111.tar.bz2
candle-ab7ff7081eab36958b82b98b89cee3eacf877111.zip
Fixes for running Phi-4 quantized. (#2714)
-rw-r--r--candle-examples/examples/quantized-phi/main.rs6
-rw-r--r--candle-transformers/src/models/quantized_phi3.rs2
2 files changed, 6 insertions, 2 deletions
diff --git a/candle-examples/examples/quantized-phi/main.rs b/candle-examples/examples/quantized-phi/main.rs
index f567ce2d..a776e989 100644
--- a/candle-examples/examples/quantized-phi/main.rs
+++ b/candle-examples/examples/quantized-phi/main.rs
@@ -28,6 +28,8 @@ enum Which {
/// Alternative implementation of phi-3, based on llama.
#[value(name = "phi-3b")]
Phi3b,
+ #[value(name = "phi-4")]
+ Phi4,
}
#[derive(Parser, Debug)]
@@ -104,6 +106,7 @@ impl Args {
let repo = match self.which {
Which::Phi2 => "microsoft/phi-2",
Which::Phi3 | Which::Phi3b => "microsoft/Phi-3-mini-4k-instruct",
+ Which::Phi4 => "microsoft/phi-4",
};
let api = api.model(repo.to_string());
api.get("tokenizer.json")?
@@ -128,6 +131,7 @@ impl Args {
"Phi-3-mini-4k-instruct-q4.gguf",
"5eef2ce24766d31909c0b269fe90c817a8f263fb",
),
+ Which::Phi4 => ("microsoft/phi-4-gguf", "phi-4-q4.gguf", "main"),
};
let api = hf_hub::api::sync::Api::new()?;
api.repo(hf_hub::Repo::with_revision(
@@ -216,7 +220,7 @@ fn main() -> anyhow::Result<()> {
);
match args.which {
Which::Phi2 => Model::Phi2(Phi2::from_gguf(model, &mut file, &device)?),
- Which::Phi3 => Model::Phi3(Phi3::from_gguf(
+ Which::Phi3 | Which::Phi4 => Model::Phi3(Phi3::from_gguf(
args.use_flash_attn,
model,
&mut file,
diff --git a/candle-transformers/src/models/quantized_phi3.rs b/candle-transformers/src/models/quantized_phi3.rs
index 51a75f38..1ceb48d1 100644
--- a/candle-transformers/src/models/quantized_phi3.rs
+++ b/candle-transformers/src/models/quantized_phi3.rs
@@ -127,7 +127,7 @@ impl LayerWeights {
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
.transpose(1, 2)?;
let k = k
- .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
+ .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
.transpose(1, 2)?;
let v = v
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?