author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-01-13 17:38:27 +0100
committer | GitHub <noreply@github.com> | 2024-01-13 17:38:27 +0100
commit | 539ead927a12a485637f7f04f8212cfdabe00fa4 (patch)
tree | 567e9c869059e9cb93ecc9ef83ad252a105f6381 /candle-examples
parent | a46864bd5650c4707753f3d95d7b4ff6b0905995 (diff)
download | candle-539ead927a12a485637f7f04f8212cfdabe00fa4.tar.gz, candle-539ead927a12a485637f7f04f8212cfdabe00fa4.tar.bz2, candle-539ead927a12a485637f7f04f8212cfdabe00fa4.zip
Update the Phi model to use the updated architecture. (#1580)
* Update the Phi model to use the updated architecture.
* Add more of the phi model.
* Repeat KV + caching (see the sketch after this list).
* Apply the rotary embeddings.
* Add support for the new phi model in the phi example.
* Fix a couple glitches.
* Fix a couple more glitches.
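
The "Repeat KV + caching" step refers to grouped-query attention: the updated phi architecture caches a reduced set of key/value heads and repeats each of them across the query heads before the attention product. The sketch below illustrates that repetition with candle's tensor API; it is a minimal sketch for orientation, not the code added by this commit, and the helper name `repeat_kv` and the `(batch, kv_heads, seq_len, head_dim)` layout are assumptions.

```rust
use candle::{Result, Tensor};

/// Repeat each key/value head `n_rep` times so that grouped-query
/// attention lines up with the number of query heads.
/// Hypothetical helper for illustration, not this commit's code.
fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
    if n_rep == 1 {
        return Ok(xs);
    }
    // Assumed layout: (batch, kv_heads, seq_len, head_dim).
    let (b, kv_heads, seq_len, head_dim) = xs.dims4()?;
    xs.unsqueeze(2)? // (b, kv_heads, 1, seq_len, head_dim)
        .expand((b, kv_heads, n_rep, seq_len, head_dim))? // broadcast the unit axis
        .reshape((b, kv_heads * n_rep, seq_len, head_dim)) // fold into the head axis
}
```

When a KV cache is in play, the newly computed keys/values are first concatenated onto the cached ones, and the repetition runs on the combined tensor at each decoding step.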
Diffstat (limited to 'candle-examples')
-rw-r--r-- | candle-examples/examples/phi/main.rs | 46
1 file changed, 35 insertions, 11 deletions
```diff
diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs
index c5c7de28..ea99c706 100644
--- a/candle-examples/examples/phi/main.rs
+++ b/candle-examples/examples/phi/main.rs
@@ -8,6 +8,7 @@
 use anyhow::{Error as E, Result};
 use clap::{Parser, ValueEnum};
 
 use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
+use candle_transformers::models::phi::{Config as PhiConfig, Model as Phi};
 use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
 use candle::{DType, Device, Tensor};
@@ -18,6 +19,7 @@ use tokenizers::Tokenizer;
 
 enum Model {
     MixFormer(MixFormer),
+    Phi(Phi),
     Quantized(QMixFormer),
 }
@@ -84,6 +86,7 @@ impl TextGeneration {
             let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
             let logits = match &mut self.model {
                 Model::MixFormer(m) => m.forward(&input)?,
+                Model::Phi(m) => m.forward(&input)?,
                 Model::Quantized(m) => m.forward(&input)?,
             };
             let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
@@ -117,7 +120,7 @@ impl TextGeneration {
     }
 }
 
-#[derive(Clone, Copy, Debug, ValueEnum)]
+#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)]
 enum WhichModel {
     #[value(name = "1")]
     V1,
@@ -125,6 +128,9 @@ enum WhichModel {
     V1_5,
     #[value(name = "2")]
     V2,
+    // TODO: Make this the default once it has been battle tested.
+    #[value(name = "2-new")]
+    V2New,
     PuffinPhiV2,
     PhiHermes,
 }
@@ -230,7 +236,7 @@ fn main() -> Result<()> {
             match args.model {
                 WhichModel::V1 => "microsoft/phi-1".to_string(),
                 WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
-                WhichModel::V2 => "microsoft/phi-2".to_string(),
+                WhichModel::V2 | WhichModel::V2New => "microsoft/phi-2".to_string(),
                 WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
                     "lmz/candle-quantized-phi".to_string()
                 }
@@ -248,7 +254,9 @@ fn main() -> Result<()> {
                 WhichModel::V1 => "refs/pr/2".to_string(),
                 WhichModel::V1_5 => "refs/pr/18".to_string(),
                 WhichModel::V2 => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(),
-                WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => "main".to_string(),
+                WhichModel::V2New | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
+                    "main".to_string()
+                }
             }
         }
     }
@@ -257,7 +265,9 @@ fn main() -> Result<()> {
     let tokenizer_filename = match args.tokenizer {
         Some(file) => std::path::PathBuf::from(file),
         None => match args.model {
-            WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => repo.get("tokenizer.json")?,
+            WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 | WhichModel::V2New => {
+                repo.get("tokenizer.json")?
+            }
             WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
                 repo.get("tokenizer-puffin-phi-v2.json")?
             }
@@ -270,14 +280,14 @@ fn main() -> Result<()> {
             match args.model {
                 WhichModel::V1 => vec![repo.get("model-v1-q4k.gguf")?],
                 WhichModel::V1_5 => vec![repo.get("model-q4k.gguf")?],
-                WhichModel::V2 => vec![repo.get("model-v2-q4k.gguf")?],
+                WhichModel::V2 | WhichModel::V2New => vec![repo.get("model-v2-q4k.gguf")?],
                 WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
                 WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
             }
         } else {
             match args.model {
                 WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
-                WhichModel::V2 => candle_examples::hub_load_safetensors(
+                WhichModel::V2 | WhichModel::V2New => candle_examples::hub_load_safetensors(
                     &repo,
                     "model.safetensors.index.json",
                 )?,
@@ -291,25 +301,35 @@ fn main() -> Result<()> {
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
 
     let start = std::time::Instant::now();
-    let config = match args.model {
+    let config = || match args.model {
         WhichModel::V1 => Config::v1(),
         WhichModel::V1_5 => Config::v1_5(),
-        WhichModel::V2 => Config::v2(),
+        WhichModel::V2 | WhichModel::V2New => Config::v2(),
         WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
         WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
     };
-    let (model, device) = if args.quantized {
+    let (model, device) = if args.model == WhichModel::V2New {
+        let device = candle_examples::device(args.cpu)?;
+        let config_filename = repo.get("config.json")?;
+        let config = std::fs::read_to_string(config_filename)?;
+        let config: PhiConfig = serde_json::from_str(&config)?;
+        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
+        let phi = Phi::new(&config, vb)?;
+        (Model::Phi(phi), device)
+    } else if args.quantized {
         let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?;
+        let config = config();
         let model = match args.model {
-            WhichModel::V2 => QMixFormer::new_v2(&config, vb)?,
+            WhichModel::V2 | WhichModel::V2New => QMixFormer::new_v2(&config, vb)?,
            _ => QMixFormer::new(&config, vb)?,
         };
         (Model::Quantized(model), Device::Cpu)
     } else {
         let device = candle_examples::device(args.cpu)?;
+        let config = config();
         let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
         let model = match args.model {
-            WhichModel::V2 => MixFormer::new_v2(&config, vb)?,
+            WhichModel::V2 | WhichModel::V2New => MixFormer::new_v2(&config, vb)?,
             _ => MixFormer::new(&config, vb)?,
         };
         (Model::MixFormer(model), device)
@@ -392,6 +412,10 @@ fn mmlu<P: AsRef<std::path::Path>>(
                 m.clear_kv_cache();
                 m.forward(&input)?
             }
+            Model::Phi(m) => {
+                m.clear_kv_cache();
+                m.forward(&input)?
+            }
             Model::Quantized(m) => {
                 m.clear_kv_cache();
                 m.forward(&input)?
```
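
Two details in the hunks above are easy to miss. `config` becomes a closure (`let config = || match args.model { ... }`) so that the hard-coded MixFormer configurations are only built on the branches that still use them; the `2-new` path instead downloads `config.json` from the hub repo and deserializes it into `PhiConfig` with serde_json. And since `2-new` is the clap value name attached to `V2New`, the new architecture should be reachable with something like `cargo run --example phi --release -- --model 2-new` (assuming the example's existing CLI; only the flags visible in this diff are implied here), while `--model 2` keeps the previous MixFormer code path.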