diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-18 16:29:38 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-18 16:29:38 +0100 |
commit | 86e7d539d253740d5a0579e6f53acc12e30d3e4c (patch) | |
tree | 42186de2fb48a03d1473486b3f7a643704ea624a /candle-examples/examples/replit-code | |
parent | cb034506cdbf6f650038893762ac815669ddbb10 (diff) | |
download | candle-86e7d539d253740d5a0579e6f53acc12e30d3e4c.tar.gz candle-86e7d539d253740d5a0579e6f53acc12e30d3e4c.tar.bz2 candle-86e7d539d253740d5a0579e6f53acc12e30d3e4c.zip |
Add the quantized mpt model. (#1123)
* Add the quantized mpt model.
* Support the quantized model for replit-code.
Diffstat (limited to 'candle-examples/examples/replit-code')
-rw-r--r-- | candle-examples/examples/replit-code/main.rs | 41 |
1 file changed, 36 insertions(+), 5 deletions(-)
diff --git a/candle-examples/examples/replit-code/main.rs b/candle-examples/examples/replit-code/main.rs
index 82c6c980..0f72b862 100644
--- a/candle-examples/examples/replit-code/main.rs
+++ b/candle-examples/examples/replit-code/main.rs
@@ -7,7 +7,8 @@ extern crate accelerate_src;
 use anyhow::{Error as E, Result};
 use clap::Parser;
 
-use candle_transformers::models::mpt::{Config, Model};
+use candle_transformers::models::mpt::{Config, Model as M};
+use candle_transformers::models::quantized_mpt::Model as Q;
 use candle::{DType, Device, Tensor};
 use candle_nn::VarBuilder;
@@ -15,6 +16,20 @@ use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::Tokenizer;
 
+enum Model {
+    M(M),
+    Q(Q),
+}
+
+impl Model {
+    fn forward(&mut self, xs: &Tensor) -> candle::Result<Tensor> {
+        match self {
+            Self::M(model) => model.forward(xs),
+            Self::Q(model) => model.forward(xs),
+        }
+    }
+}
+
 struct TextGeneration {
     model: Model,
     device: Device,
@@ -149,6 +164,9 @@ struct Args {
     revision: Option<String>,
 
     #[arg(long)]
+    quantized: bool,
+
+    #[arg(long)]
     weight_file: Option<String>,
 
     #[arg(long)]
@@ -206,16 +224,29 @@ fn main() -> Result<()> {
     };
     let filename = match args.weight_file {
        Some(weight_file) => std::path::PathBuf::from(weight_file),
-        None => repo.get("model.safetensors")?,
+        None => {
+            if args.quantized {
+                repo.get("model-replit-code-v1_5-q4k.gguf")?
+            } else {
+                repo.get("model.safetensors")?
+            }
+        }
     };
     println!("retrieved the files in {:?}", start.elapsed());
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
 
     let start = std::time::Instant::now();
     let config = Config::replit_code_v1_5_3b();
-    let device = candle_examples::device(args.cpu)?;
-    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
-    let model = Model::new(&config, vb.pp("transformer"))?;
+    let (model, device) = if args.quantized {
+        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
+        let model = Model::Q(Q::new(&config, vb.pp("transformer"))?);
+        (model, Device::Cpu)
+    } else {
+        let device = candle_examples::device(args.cpu)?;
+        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
+        let model = Model::M(M::new(&config, vb.pp("transformer"))?);
+        (model, device)
+    };
     println!("loaded the model in {:?}", start.elapsed());
 
     let mut pipeline = TextGeneration::new(