diff options
author | Jani Monoses <jani.monoses@gmail.com> | 2025-01-13 15:35:33 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-01-13 14:35:33 +0100 |
commit | ab7ff7081eab36958b82b98b89cee3eacf877111 (patch) | |
tree | a6768826a260a190bfe774fbbf954d6e85b1c5ae | |
parent | 461e8c1685e003bdddfd1e7d1aa5092786ca9df5 (diff) | |
download | candle-ab7ff7081eab36958b82b98b89cee3eacf877111.tar.gz candle-ab7ff7081eab36958b82b98b89cee3eacf877111.tar.bz2 candle-ab7ff7081eab36958b82b98b89cee3eacf877111.zip |
Fixes for running Phi-4 quantized. (#2714)
-rw-r--r-- | candle-examples/examples/quantized-phi/main.rs | 6 | ||||
-rw-r--r-- | candle-transformers/src/models/quantized_phi3.rs | 2 |
2 files changed, 6 insertions, 2 deletions
```diff
diff --git a/candle-examples/examples/quantized-phi/main.rs b/candle-examples/examples/quantized-phi/main.rs
index f567ce2d..a776e989 100644
--- a/candle-examples/examples/quantized-phi/main.rs
+++ b/candle-examples/examples/quantized-phi/main.rs
@@ -28,6 +28,8 @@ enum Which {
     /// Alternative implementation of phi-3, based on llama.
     #[value(name = "phi-3b")]
     Phi3b,
+    #[value(name = "phi-4")]
+    Phi4,
 }
 
 #[derive(Parser, Debug)]
@@ -104,6 +106,7 @@ impl Args {
                 let repo = match self.which {
                     Which::Phi2 => "microsoft/phi-2",
                     Which::Phi3 | Which::Phi3b => "microsoft/Phi-3-mini-4k-instruct",
+                    Which::Phi4 => "microsoft/phi-4",
                 };
                 let api = api.model(repo.to_string());
                 api.get("tokenizer.json")?
@@ -128,6 +131,7 @@ impl Args {
                         "Phi-3-mini-4k-instruct-q4.gguf",
                         "5eef2ce24766d31909c0b269fe90c817a8f263fb",
                     ),
+                    Which::Phi4 => ("microsoft/phi-4-gguf", "phi-4-q4.gguf", "main"),
                 };
                 let api = hf_hub::api::sync::Api::new()?;
                 api.repo(hf_hub::Repo::with_revision(
@@ -216,7 +220,7 @@ fn main() -> anyhow::Result<()> {
         );
         match args.which {
             Which::Phi2 => Model::Phi2(Phi2::from_gguf(model, &mut file, &device)?),
-            Which::Phi3 => Model::Phi3(Phi3::from_gguf(
+            Which::Phi3 | Which::Phi4 => Model::Phi3(Phi3::from_gguf(
                 args.use_flash_attn,
                 model,
                 &mut file,
diff --git a/candle-transformers/src/models/quantized_phi3.rs b/candle-transformers/src/models/quantized_phi3.rs
index 51a75f38..1ceb48d1 100644
--- a/candle-transformers/src/models/quantized_phi3.rs
+++ b/candle-transformers/src/models/quantized_phi3.rs
@@ -127,7 +127,7 @@ impl LayerWeights {
             .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
             .transpose(1, 2)?;
         let k = k
-            .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
+            .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
             .transpose(1, 2)?;
         let v = v
             .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
```