summary refs log tree commit diff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-12-13 21:16:34 -0600
committerGitHub <noreply@github.com>2023-12-13 21:16:34 -0600
commit5e33c85c8f7d2ae8c5fe8de557b69c036e4f080a (patch)
tree786b7d376b517fd5f353fbfdfcb1a1236178ae26
parent2b3a018be7596d6c72aaee8a469312ce865498d5 (diff)
downloadcandle-5e33c85c8f7d2ae8c5fe8de557b69c036e4f080a.tar.gz
candle-5e33c85c8f7d2ae8c5fe8de557b69c036e4f080a.tar.bz2
candle-5e33c85c8f7d2ae8c5fe8de557b69c036e4f080a.zip
Quantized version for phi-v2. (#1430)
* Quantized version for phi-v2.
* More quantized support.
-rw-r--r--candle-examples/examples/phi/README.md30
-rw-r--r--candle-examples/examples/phi/main.rs7
-rw-r--r--candle-transformers/src/models/quantized_mixformer.rs18
3 files changed, 49 insertions, 6 deletions
diff --git a/candle-examples/examples/phi/README.md b/candle-examples/examples/phi/README.md
index 566411d1..a84c01f2 100644
--- a/candle-examples/examples/phi/README.md
+++ b/candle-examples/examples/phi/README.md
@@ -1,14 +1,36 @@
-# candle-phi: 1.3b LLM with state of the art performance for <10b models.
+# candle-phi: 1.3b and 2.7b LLM with state of the art performance for <10b models.
-[Phi-1.5](https://huggingface.co/microsoft/phi-1_5) is a language model using
-only 1.3 billion parameters but with state of the art performance compared to
+[Phi-1.5](https://huggingface.co/microsoft/phi-1_5) and
+[Phi-2](https://huggingface.co/microsoft/phi-2) are language models using
+only 1.3 and 2.7 billion parameters but with state of the art performance compared to
models with up to 10 billion parameters.
The candle implementation provides both the standard version as well as a
quantized variant.
-## Running some example
+## Running some examples
+For the v2 version.
+```bash
+$ cargo run --example phi --release cuda -- --prompt "def print_prime(n): " --model 2
+def print_prime(n):
+ if n <= 1:
+ print("Not a prime number")
+ else:
+ for i in range(2, int(n**0.5)+1):
+ if (n % i) == 0:
+ print("Not a prime number")
+ break
+ else:
+ print("Prime number")
+
+
+# Driver code
+n = 17
+print_prime(n)
+```
+
+For the v1.5 version.
```bash
$ cargo run --example phi --release -- --prompt "def print_prime(n): "
diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs
index 1dd507ff..321ea5de 100644
--- a/candle-examples/examples/phi/main.rs
+++ b/candle-examples/examples/phi/main.rs
@@ -268,7 +268,7 @@ fn main() -> Result<()> {
match args.model {
WhichModel::V1 => vec![repo.get("model-v1-q4k.gguf")?],
WhichModel::V1_5 => vec![repo.get("model-q4k.gguf")?],
- WhichModel::V2 => anyhow::bail!("phi-2 is not supported in quantized mode"),
+ WhichModel::V2 => vec![repo.get("model-v2-q4k.gguf")?],
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
}
@@ -298,7 +298,10 @@ fn main() -> Result<()> {
};
let (model, device) = if args.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?;
- let model = QMixFormer::new(&config, vb)?;
+ let model = match args.model {
+ WhichModel::V2 => QMixFormer::new_v2(&config, vb)?,
+ _ => QMixFormer::new(&config, vb)?,
+ };
(Model::Quantized(model), Device::Cpu)
} else {
let device = candle_examples::device(args.cpu)?;
diff --git a/candle-transformers/src/models/quantized_mixformer.rs b/candle-transformers/src/models/quantized_mixformer.rs
index f11f2036..1a3cd4ac 100644
--- a/candle-transformers/src/models/quantized_mixformer.rs
+++ b/candle-transformers/src/models/quantized_mixformer.rs
@@ -287,6 +287,24 @@ pub struct MixFormerSequentialForCausalLM {
}
impl MixFormerSequentialForCausalLM {
+ pub fn new_v2(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let vb_head = vb.pp("lm_head");
+ let vb = vb.pp("transformer");
+ let embedding = Embedding::new(cfg, vb.pp("embd"))?;
+ let mut blocks = Vec::new();
+ for i in 0..cfg.n_layer {
+ let block = ParallelBlock::new(cfg, vb.pp("h").pp(i))?;
+ blocks.push(block)
+ }
+ let head = CausalLMHead::new(cfg, vb_head)?;
+ Ok(Self {
+ embedding,
+ blocks,
+ head,
+ span: tracing::span!(tracing::Level::TRACE, "mixformer"),
+ })
+ }
+
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb = vb.pp("layers");
let embedding = Embedding::new(cfg, vb.pp(0))?;