summaryrefslogtreecommitdiff
path: root/candle-examples/examples/phi/main.rs
diff options
context:
space:
mode:
authorIvar Flakstad <69173633+ivarflakstad@users.noreply.github.com>2024-01-14 17:10:54 +0100
committerIvar Flakstad <69173633+ivarflakstad@users.noreply.github.com>2024-01-14 17:10:54 +0100
commitecf88a6d381e40c8db1c643dff2753fd877fae92 (patch)
tree9b6db2ec9a37a48185f323ab4c5e8b0baaa20221 /candle-examples/examples/phi/main.rs
parente06e8d0dbea3a052195f4ca27fb5ddcdbf1cd30c (diff)
parente6d86b081980196745e5f0b0eda8ce5334c0ff67 (diff)
downloadcandle-ecf88a6d381e40c8db1c643dff2753fd877fae92.tar.gz
candle-ecf88a6d381e40c8db1c643dff2753fd877fae92.tar.bz2
candle-ecf88a6d381e40c8db1c643dff2753fd877fae92.zip
Merge branch 'main' into ivarflakstad/metal-prng
Diffstat (limited to 'candle-examples/examples/phi/main.rs')
-rw-r--r--candle-examples/examples/phi/main.rs47
1 files changed, 35 insertions, 12 deletions
diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs
index c529867b..ea99c706 100644
--- a/candle-examples/examples/phi/main.rs
+++ b/candle-examples/examples/phi/main.rs
@@ -8,6 +8,7 @@ use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};
use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
+use candle_transformers::models::phi::{Config as PhiConfig, Model as Phi};
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
use candle::{DType, Device, Tensor};
@@ -18,6 +19,7 @@ use tokenizers::Tokenizer;
enum Model {
MixFormer(MixFormer),
+ Phi(Phi),
Quantized(QMixFormer),
}
@@ -84,6 +86,7 @@ impl TextGeneration {
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = match &mut self.model {
Model::MixFormer(m) => m.forward(&input)?,
+ Model::Phi(m) => m.forward(&input)?,
Model::Quantized(m) => m.forward(&input)?,
};
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
@@ -117,7 +120,7 @@ impl TextGeneration {
}
}
-#[derive(Clone, Copy, Debug, ValueEnum)]
+#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)]
enum WhichModel {
#[value(name = "1")]
V1,
@@ -125,6 +128,9 @@ enum WhichModel {
V1_5,
#[value(name = "2")]
V2,
+ // TODO: Make this the default once it has been battle tested.
+ #[value(name = "2-new")]
+ V2New,
PuffinPhiV2,
PhiHermes,
}
@@ -169,7 +175,7 @@ struct Args {
#[arg(long)]
model_id: Option<String>,
- #[arg(long, default_value = "1.5")]
+ #[arg(long, default_value = "2")]
model: WhichModel,
#[arg(long)]
@@ -230,7 +236,7 @@ fn main() -> Result<()> {
match args.model {
WhichModel::V1 => "microsoft/phi-1".to_string(),
WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
- WhichModel::V2 => "microsoft/phi-2".to_string(),
+ WhichModel::V2 | WhichModel::V2New => "microsoft/phi-2".to_string(),
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
"lmz/candle-quantized-phi".to_string()
}
@@ -247,7 +253,8 @@ fn main() -> Result<()> {
match args.model {
WhichModel::V1 => "refs/pr/2".to_string(),
WhichModel::V1_5 => "refs/pr/18".to_string(),
- WhichModel::V2 | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
+ WhichModel::V2 => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(),
+ WhichModel::V2New | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
"main".to_string()
}
}
@@ -258,7 +265,9 @@ fn main() -> Result<()> {
let tokenizer_filename = match args.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => match args.model {
- WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => repo.get("tokenizer.json")?,
+ WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 | WhichModel::V2New => {
+ repo.get("tokenizer.json")?
+ }
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
repo.get("tokenizer-puffin-phi-v2.json")?
}
@@ -271,14 +280,14 @@ fn main() -> Result<()> {
match args.model {
WhichModel::V1 => vec![repo.get("model-v1-q4k.gguf")?],
WhichModel::V1_5 => vec![repo.get("model-q4k.gguf")?],
- WhichModel::V2 => vec![repo.get("model-v2-q4k.gguf")?],
+ WhichModel::V2 | WhichModel::V2New => vec![repo.get("model-v2-q4k.gguf")?],
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
}
} else {
match args.model {
WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
- WhichModel::V2 => candle_examples::hub_load_safetensors(
+ WhichModel::V2 | WhichModel::V2New => candle_examples::hub_load_safetensors(
&repo,
"model.safetensors.index.json",
)?,
@@ -292,25 +301,35 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
- let config = match args.model {
+ let config = || match args.model {
WhichModel::V1 => Config::v1(),
WhichModel::V1_5 => Config::v1_5(),
- WhichModel::V2 => Config::v2(),
+ WhichModel::V2 | WhichModel::V2New => Config::v2(),
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
};
- let (model, device) = if args.quantized {
+ let (model, device) = if args.model == WhichModel::V2New {
+ let device = candle_examples::device(args.cpu)?;
+ let config_filename = repo.get("config.json")?;
+ let config = std::fs::read_to_string(config_filename)?;
+ let config: PhiConfig = serde_json::from_str(&config)?;
+ let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
+ let phi = Phi::new(&config, vb)?;
+ (Model::Phi(phi), device)
+ } else if args.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?;
+ let config = config();
let model = match args.model {
- WhichModel::V2 => QMixFormer::new_v2(&config, vb)?,
+ WhichModel::V2 | WhichModel::V2New => QMixFormer::new_v2(&config, vb)?,
_ => QMixFormer::new(&config, vb)?,
};
(Model::Quantized(model), Device::Cpu)
} else {
let device = candle_examples::device(args.cpu)?;
+ let config = config();
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let model = match args.model {
- WhichModel::V2 => MixFormer::new_v2(&config, vb)?,
+ WhichModel::V2 | WhichModel::V2New => MixFormer::new_v2(&config, vb)?,
_ => MixFormer::new(&config, vb)?,
};
(Model::MixFormer(model), device)
@@ -393,6 +412,10 @@ fn mmlu<P: AsRef<std::path::Path>>(
m.clear_kv_cache();
m.forward(&input)?
}
+ Model::Phi(m) => {
+ m.clear_kv_cache();
+ m.forward(&input)?
+ }
Model::Quantized(m) => {
m.clear_kv_cache();
m.forward(&input)?