| author | Jack Shih <develop@kshih.com> | 2024-02-26 04:43:40 +0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-02-25 21:43:40 +0100 |
| commit | 918136ba46e426a29ae9dc8318b23daa312d073e (patch) | |
| tree | 13457de87fdd0e019265eb8aab75294e430606be /candle-examples | |
| parent | 1a6043af5123bf9e189063d3baf110b39cf47617 (diff) | |
| download | candle-918136ba46e426a29ae9dc8318b23daa312d073e.tar.gz candle-918136ba46e426a29ae9dc8318b23daa312d073e.tar.bz2 candle-918136ba46e426a29ae9dc8318b23daa312d073e.zip | |
add quantized rwkv v5 model (#1743)
* add quantized rwkv v5 model
* Integrate the quantized rwkv model in the initial example.
---------
Co-authored-by: laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-examples')
-rw-r--r-- | candle-examples/examples/rwkv/main.rs | 42
1 file changed, 38 insertions(+), 4 deletions(-)
diff --git a/candle-examples/examples/rwkv/main.rs b/candle-examples/examples/rwkv/main.rs
index 0ccf2ec3..771baa03 100644
--- a/candle-examples/examples/rwkv/main.rs
+++ b/candle-examples/examples/rwkv/main.rs
@@ -7,13 +7,28 @@ extern crate accelerate_src;
 use anyhow::Result;
 use clap::{Parser, ValueEnum};
 
-use candle_transformers::models::rwkv_v5::{Config, Model, State, Tokenizer};
+use candle_transformers::models::quantized_rwkv_v5::Model as Q;
+use candle_transformers::models::rwkv_v5::{Config, Model as M, State, Tokenizer};
 
 use candle::{DType, Device, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
 
+enum Model {
+    M(M),
+    Q(Q),
+}
+
+impl Model {
+    fn forward(&self, xs: &Tensor, state: &mut State) -> candle::Result<Tensor> {
+        match self {
+            Self::M(m) => m.forward(xs, state),
+            Self::Q(m) => m.forward(xs, state),
+        }
+    }
+}
+
 struct TextGeneration {
     model: Model,
     config: Config,
@@ -176,6 +191,9 @@ struct Args {
     #[arg(long)]
     config_file: Option<String>,
 
+    #[arg(long)]
+    quantized: bool,
+
     /// Penalty to be applied for repeating tokens, 1. means no penalty.
     #[arg(long, default_value_t = 1.1)]
     repeat_penalty: f32,
@@ -236,7 +254,16 @@ fn main() -> Result<()> {
             .map(std::path::PathBuf::from)
             .collect::<Vec<_>>(),
         None => {
-            vec![repo.get("model.safetensors")?]
+            if args.quantized {
+                let file = match args.which {
+                    Which::World1b5 => "world1b5-q4k.gguf",
+                    Which::World3b => "world3b-q4k.gguf",
+                    Which::Eagle7b => "eagle7b-q4k.gguf",
+                };
+                vec![api.model("lmz/candle-rwkv".to_string()).get(file)?]
+            } else {
+                vec![repo.get("model.safetensors")?]
+            }
         }
     };
     println!("retrieved the files in {:?}", start.elapsed());
@@ -245,8 +272,15 @@ fn main() -> Result<()> {
     let start = std::time::Instant::now();
     let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
     let device = candle_examples::device(args.cpu)?;
-    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
-    let model = Model::new(&config, vb)?;
+    let model = if args.quantized {
+        let filename = &filenames[0];
+        let vb =
+            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
+        Model::Q(Q::new(&config, vb)?)
+    } else {
+        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
+        Model::M(M::new(&config, vb)?)
+    };
     println!("loaded the model in {:?}", start.elapsed());
 
     let mut pipeline = TextGeneration::new(
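For context, the quantized path introduced above can also be exercised on its own. The sketch below is a minimal illustration, not part of the commit: it assumes it is built inside the candle workspace (so `candle`, `candle_transformers`, `anyhow`, and `serde_json` resolve as in `candle-examples`), and the `load_quantized` helper name plus the local `config.json` / `world1b5-q4k.gguf` paths are hypothetical stand-ins for the files the example fetches via hf-hub. Only APIs visible in the diff are used; building the `State` and tokenizer needed to actually generate text follows the rest of `main.rs` and is not shown.

```rust
use anyhow::Result;
use candle::Device;
use candle_transformers::models::quantized_rwkv_v5::Model as Q;
use candle_transformers::models::rwkv_v5::Config;

// Hypothetical helper mirroring the `args.quantized` branch added in this
// commit: read the JSON config, then build the model from a local gguf file.
fn load_quantized(config_path: &str, gguf_path: &str, device: &Device) -> Result<Q> {
    // Same config handling as the example: serde_json over the hub's config file.
    let config: Config = serde_json::from_slice(&std::fs::read(config_path)?)?;
    // Quantized weights go through the gguf-backed VarBuilder instead of the
    // unsafe mmaped-safetensors path used for the full-precision model.
    let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(gguf_path, device)?;
    Ok(Q::new(&config, vb)?)
}

fn main() -> Result<()> {
    // Assumed local copies of the files the example would otherwise download.
    let _model = load_quantized("config.json", "world1b5-q4k.gguf", &Device::Cpu)?;
    println!("quantized rwkv v5 model loaded");
    Ok(())
}
```

Within the example itself, the same branch is reached by something like `cargo run --example rwkv --release -- --quantized`, with the q4k gguf pulled from the `lmz/candle-rwkv` hub repo as shown in the hunk above.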