diff options
Diffstat (limited to 'candle-examples/examples/quantized-phi')
-rw-r--r-- | candle-examples/examples/quantized-phi/main.rs | 18 |
1 files changed, 15 insertions, 3 deletions
diff --git a/candle-examples/examples/quantized-phi/main.rs b/candle-examples/examples/quantized-phi/main.rs
index 7d255f58..e2211844 100644
--- a/candle-examples/examples/quantized-phi/main.rs
+++ b/candle-examples/examples/quantized-phi/main.rs
@@ -13,8 +13,9 @@ use candle::Tensor;
 use candle_transformers::generation::{LogitsProcessor, Sampling};
 
 use candle_examples::token_output_stream::TokenOutputStream;
-use candle_transformers::models::quantized_llama::ModelWeights as Phi3;
+use candle_transformers::models::quantized_llama::ModelWeights as Phi3b;
 use candle_transformers::models::quantized_phi::ModelWeights as Phi2;
+use candle_transformers::models::quantized_phi3::ModelWeights as Phi3;
 
 const DEFAULT_PROMPT: &str = "Write a function to count prime numbers up to N. ";
 
@@ -24,6 +25,9 @@ enum Which {
     Phi2,
     #[value(name = "phi-3")]
     Phi3,
+    /// Alternative implementation of phi-3, based on llama.
+    #[value(name = "phi-3b")]
+    Phi3b,
 }
 
 #[derive(Parser, Debug)]
@@ -84,7 +88,7 @@ struct Args {
     repeat_last_n: usize,
 
     /// The model size to use.
-    #[arg(long, default_value = "phi-2")]
+    #[arg(long, default_value = "phi-3b")]
     which: Which,
 }
 
@@ -96,7 +100,7 @@ impl Args {
                 let api = hf_hub::api::sync::Api::new()?;
                 let repo = match self.which {
                     Which::Phi2 => "microsoft/phi-2",
-                    Which::Phi3 => "microsoft/Phi-3-mini-4k-instruct",
+                    Which::Phi3 | Which::Phi3b => "microsoft/Phi-3-mini-4k-instruct",
                 };
                 let api = api.model(repo.to_string());
                 api.get("tokenizer.json")?
@@ -114,6 +118,11 @@ impl Args {
                     Which::Phi3 => (
                         "microsoft/Phi-3-mini-4k-instruct-gguf",
                         "Phi-3-mini-4k-instruct-q4.gguf",
+                        "main",
+                    ),
+                    Which::Phi3b => (
+                        "microsoft/Phi-3-mini-4k-instruct-gguf",
+                        "Phi-3-mini-4k-instruct-q4.gguf",
                         "5eef2ce24766d31909c0b269fe90c817a8f263fb",
                     ),
                 };
@@ -145,6 +154,7 @@ fn format_size(size_in_bytes: usize) -> String {
 enum Model {
     Phi2(Phi2),
     Phi3(Phi3),
+    Phi3b(Phi3b),
 }
 
 impl Model {
@@ -152,6 +162,7 @@ impl Model {
         match self {
             Self::Phi2(m) => m.forward(xs, pos),
             Self::Phi3(m) => m.forward(xs, pos),
+            Self::Phi3b(m) => m.forward(xs, pos),
         }
     }
 }
@@ -203,6 +214,7 @@ fn main() -> anyhow::Result<()> {
         match args.which {
             Which::Phi2 => Model::Phi2(Phi2::from_gguf(model, &mut file, &device)?),
             Which::Phi3 => Model::Phi3(Phi3::from_gguf(model, &mut file, &device)?),
+            Which::Phi3b => Model::Phi3b(Phi3b::from_gguf(model, &mut file, &device)?),
         }
     };
     println!("model built");
|