summary | refs | log | tree | commit | diff
path: root/candle-examples/examples/quantized-phi
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/quantized-phi')
-rw-r--r-- candle-examples/examples/quantized-phi/main.rs | 18
1 files changed, 15 insertions, 3 deletions
diff --git a/candle-examples/examples/quantized-phi/main.rs b/candle-examples/examples/quantized-phi/main.rs
index 7d255f58..e2211844 100644
--- a/candle-examples/examples/quantized-phi/main.rs
+++ b/candle-examples/examples/quantized-phi/main.rs
@@ -13,8 +13,9 @@ use candle::Tensor;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_examples::token_output_stream::TokenOutputStream;
-use candle_transformers::models::quantized_llama::ModelWeights as Phi3;
+use candle_transformers::models::quantized_llama::ModelWeights as Phi3b;
use candle_transformers::models::quantized_phi::ModelWeights as Phi2;
+use candle_transformers::models::quantized_phi3::ModelWeights as Phi3;
const DEFAULT_PROMPT: &str = "Write a function to count prime numbers up to N. ";
@@ -24,6 +25,9 @@ enum Which {
Phi2,
#[value(name = "phi-3")]
Phi3,
+ /// Alternative implementation of phi-3, based on llama.
+ #[value(name = "phi-3b")]
+ Phi3b,
}
#[derive(Parser, Debug)]
@@ -84,7 +88,7 @@ struct Args {
repeat_last_n: usize,
/// The model size to use.
- #[arg(long, default_value = "phi-2")]
+ #[arg(long, default_value = "phi-3b")]
which: Which,
}
@@ -96,7 +100,7 @@ impl Args {
let api = hf_hub::api::sync::Api::new()?;
let repo = match self.which {
Which::Phi2 => "microsoft/phi-2",
- Which::Phi3 => "microsoft/Phi-3-mini-4k-instruct",
+ Which::Phi3 | Which::Phi3b => "microsoft/Phi-3-mini-4k-instruct",
};
let api = api.model(repo.to_string());
api.get("tokenizer.json")?
@@ -114,6 +118,11 @@ impl Args {
Which::Phi3 => (
"microsoft/Phi-3-mini-4k-instruct-gguf",
"Phi-3-mini-4k-instruct-q4.gguf",
+ "main",
+ ),
+ Which::Phi3b => (
+ "microsoft/Phi-3-mini-4k-instruct-gguf",
+ "Phi-3-mini-4k-instruct-q4.gguf",
"5eef2ce24766d31909c0b269fe90c817a8f263fb",
),
};
@@ -145,6 +154,7 @@ fn format_size(size_in_bytes: usize) -> String {
enum Model {
Phi2(Phi2),
Phi3(Phi3),
+ Phi3b(Phi3b),
}
impl Model {
@@ -152,6 +162,7 @@ impl Model {
match self {
Self::Phi2(m) => m.forward(xs, pos),
Self::Phi3(m) => m.forward(xs, pos),
+ Self::Phi3b(m) => m.forward(xs, pos),
}
}
}
@@ -203,6 +214,7 @@ fn main() -> anyhow::Result<()> {
match args.which {
Which::Phi2 => Model::Phi2(Phi2::from_gguf(model, &mut file, &device)?),
Which::Phi3 => Model::Phi3(Phi3::from_gguf(model, &mut file, &device)?),
+ Which::Phi3b => Model::Phi3b(Phi3b::from_gguf(model, &mut file, &device)?),
}
};
println!("model built");