summary refs log tree commit diff
path: root/candle-examples/examples/replit-code
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-10-18 16:29:38 +0100
committerGitHub <noreply@github.com>2023-10-18 16:29:38 +0100
commit86e7d539d253740d5a0579e6f53acc12e30d3e4c (patch)
tree42186de2fb48a03d1473486b3f7a643704ea624a /candle-examples/examples/replit-code
parentcb034506cdbf6f650038893762ac815669ddbb10 (diff)
downloadcandle-86e7d539d253740d5a0579e6f53acc12e30d3e4c.tar.gz
candle-86e7d539d253740d5a0579e6f53acc12e30d3e4c.tar.bz2
candle-86e7d539d253740d5a0579e6f53acc12e30d3e4c.zip
Add the quantized mpt model. (#1123)
* Add the quantized mpt model. * Support the quantized model for replit-code.
Diffstat (limited to 'candle-examples/examples/replit-code')
-rw-r--r--candle-examples/examples/replit-code/main.rs41
1 file changed, 36 insertions, 5 deletions
diff --git a/candle-examples/examples/replit-code/main.rs b/candle-examples/examples/replit-code/main.rs
index 82c6c980..0f72b862 100644
--- a/candle-examples/examples/replit-code/main.rs
+++ b/candle-examples/examples/replit-code/main.rs
@@ -7,7 +7,8 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
-use candle_transformers::models::mpt::{Config, Model};
+use candle_transformers::models::mpt::{Config, Model as M};
+use candle_transformers::models::quantized_mpt::Model as Q;
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
@@ -15,6 +16,20 @@ use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
+enum Model {
+ M(M),
+ Q(Q),
+}
+
+impl Model {
+ fn forward(&mut self, xs: &Tensor) -> candle::Result<Tensor> {
+ match self {
+ Self::M(model) => model.forward(xs),
+ Self::Q(model) => model.forward(xs),
+ }
+ }
+}
+
struct TextGeneration {
model: Model,
device: Device,
@@ -149,6 +164,9 @@ struct Args {
revision: Option<String>,
#[arg(long)]
+ quantized: bool,
+
+ #[arg(long)]
weight_file: Option<String>,
#[arg(long)]
@@ -206,16 +224,29 @@ fn main() -> Result<()> {
};
let filename = match args.weight_file {
Some(weight_file) => std::path::PathBuf::from(weight_file),
- None => repo.get("model.safetensors")?,
+ None => {
+ if args.quantized {
+ repo.get("model-replit-code-v1_5-q4k.gguf")?
+ } else {
+ repo.get("model.safetensors")?
+ }
+ }
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = Config::replit_code_v1_5_3b();
- let device = candle_examples::device(args.cpu)?;
- let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
- let model = Model::new(&config, vb.pp("transformer"))?;
+ let (model, device) = if args.quantized {
+ let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
+ let model = Model::Q(Q::new(&config, vb.pp("transformer"))?);
+ (model, Device::Cpu)
+ } else {
+ let device = candle_examples::device(args.cpu)?;
+ let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
+ let model = Model::M(M::new(&config, vb.pp("transformer"))?);
+ (model, device)
+ };
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(