diff options
-rw-r--r-- | candle-examples/examples/phi/README.md | 30 | ||||
-rw-r--r-- | candle-examples/examples/phi/main.rs | 7 | ||||
-rw-r--r-- | candle-transformers/src/models/quantized_mixformer.rs | 18 |
3 files changed, 49 insertions, 6 deletions
diff --git a/candle-examples/examples/phi/README.md b/candle-examples/examples/phi/README.md index 566411d1..a84c01f2 100644 --- a/candle-examples/examples/phi/README.md +++ b/candle-examples/examples/phi/README.md @@ -1,14 +1,36 @@ -# candle-phi: 1.3b LLM with state of the art performance for <10b models. +# candle-phi: 1.3b and 2.7b LLM with state of the art performance for <10b models. -[Phi-1.5](https://huggingface.co/microsoft/phi-1_5) is a language model using -only 1.3 billion parameters but with state of the art performance compared to +[Phi-1.5](https://huggingface.co/microsoft/phi-1_5) and +[Phi-2](https://huggingface.co/microsoft/phi-2) are language models using +only 1.3 and 2.7 billion parameters but with state of the art performance compared to models with up to 10 billion parameters. The candle implementation provides both the standard version as well as a quantized variant. -## Running some example +## Running some examples +For the v2 version. +```bash +$ cargo run --example phi --release cuda -- --prompt "def print_prime(n): " --model 2 +def print_prime(n): + if n <= 1: + print("Not a prime number") + else: + for i in range(2, int(n**0.5)+1): + if (n % i) == 0: + print("Not a prime number") + break + else: + print("Prime number") + + +# Driver code +n = 17 +print_prime(n) +``` + +For the v1.5 version. ```bash $ cargo run --example phi --release -- --prompt "def print_prime(n): " diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs index 1dd507ff..321ea5de 100644 --- a/candle-examples/examples/phi/main.rs +++ b/candle-examples/examples/phi/main.rs @@ -268,7 +268,7 @@ fn main() -> Result<()> { match args.model { WhichModel::V1 => vec![repo.get("model-v1-q4k.gguf")?], WhichModel::V1_5 => vec![repo.get("model-q4k.gguf")?], - WhichModel::V2 => anyhow::bail!("phi-2 is not supported in quantized mode"), + WhichModel::V2 => vec![repo.get("model-v2-q4k.gguf")?], WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?], WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?], } @@ -298,7 +298,10 @@ fn main() -> Result<()> { }; let (model, device) = if args.quantized { let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?; - let model = QMixFormer::new(&config, vb)?; + let model = match args.model { + WhichModel::V2 => QMixFormer::new_v2(&config, vb)?, + _ => QMixFormer::new(&config, vb)?, + }; (Model::Quantized(model), Device::Cpu) } else { let device = candle_examples::device(args.cpu)?; diff --git a/candle-transformers/src/models/quantized_mixformer.rs b/candle-transformers/src/models/quantized_mixformer.rs index f11f2036..1a3cd4ac 100644 --- a/candle-transformers/src/models/quantized_mixformer.rs +++ b/candle-transformers/src/models/quantized_mixformer.rs @@ -287,6 +287,24 @@ pub struct MixFormerSequentialForCausalLM { } impl MixFormerSequentialForCausalLM { + pub fn new_v2(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let vb_head = vb.pp("lm_head"); + let vb = vb.pp("transformer"); + let embedding = Embedding::new(cfg, vb.pp("embd"))?; + let mut blocks = Vec::new(); + for i in 0..cfg.n_layer { + let block = ParallelBlock::new(cfg, vb.pp("h").pp(i))?; + blocks.push(block) + } + let head = CausalLMHead::new(cfg, vb_head)?; + Ok(Self { + embedding, + blocks, + head, + span: tracing::span!(tracing::Level::TRACE, "mixformer"), + }) + } + pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { let vb = vb.pp("layers"); let embedding = Embedding::new(cfg, vb.pp(0))?; |