summaryrefslogtreecommitdiff
path: root/candle-examples
diff options
context:
space:
mode:
authorJack Shih <develop@kshih.com>2024-02-26 04:43:40 +0800
committerGitHub <noreply@github.com>2024-02-25 21:43:40 +0100
commit918136ba46e426a29ae9dc8318b23daa312d073e (patch)
tree13457de87fdd0e019265eb8aab75294e430606be /candle-examples
parent1a6043af5123bf9e189063d3baf110b39cf47617 (diff)
downloadcandle-918136ba46e426a29ae9dc8318b23daa312d073e.tar.gz
candle-918136ba46e426a29ae9dc8318b23daa312d073e.tar.bz2
candle-918136ba46e426a29ae9dc8318b23daa312d073e.zip
add quantized rwkv v5 model (#1743)
* add quantized rwkv v5 model * Integrate the quantized rwkv model in the initial example. --------- Co-authored-by: laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-examples')
-rw-r--r--candle-examples/examples/rwkv/main.rs42
1 file changed, 38 insertions, 4 deletions
diff --git a/candle-examples/examples/rwkv/main.rs b/candle-examples/examples/rwkv/main.rs
index 0ccf2ec3..771baa03 100644
--- a/candle-examples/examples/rwkv/main.rs
+++ b/candle-examples/examples/rwkv/main.rs
@@ -7,13 +7,28 @@ extern crate accelerate_src;
use anyhow::Result;
use clap::{Parser, ValueEnum};
-use candle_transformers::models::rwkv_v5::{Config, Model, State, Tokenizer};
+use candle_transformers::models::quantized_rwkv_v5::Model as Q;
+use candle_transformers::models::rwkv_v5::{Config, Model as M, State, Tokenizer};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
+enum Model {
+ M(M),
+ Q(Q),
+}
+
+impl Model {
+ fn forward(&self, xs: &Tensor, state: &mut State) -> candle::Result<Tensor> {
+ match self {
+ Self::M(m) => m.forward(xs, state),
+ Self::Q(m) => m.forward(xs, state),
+ }
+ }
+}
+
struct TextGeneration {
model: Model,
config: Config,
@@ -176,6 +191,9 @@ struct Args {
#[arg(long)]
config_file: Option<String>,
+ #[arg(long)]
+ quantized: bool,
+
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
@@ -236,7 +254,16 @@ fn main() -> Result<()> {
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => {
- vec![repo.get("model.safetensors")?]
+ if args.quantized {
+ let file = match args.which {
+ Which::World1b5 => "world1b5-q4k.gguf",
+ Which::World3b => "world3b-q4k.gguf",
+ Which::Eagle7b => "eagle7b-q4k.gguf",
+ };
+ vec![api.model("lmz/candle-rwkv".to_string()).get(file)?]
+ } else {
+ vec![repo.get("model.safetensors")?]
+ }
}
};
println!("retrieved the files in {:?}", start.elapsed());
@@ -245,8 +272,15 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let device = candle_examples::device(args.cpu)?;
- let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
- let model = Model::new(&config, vb)?;
+ let model = if args.quantized {
+ let filename = &filenames[0];
+ let vb =
+ candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
+ Model::Q(Q::new(&config, vb)?)
+ } else {
+ let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
+ Model::M(M::new(&config, vb)?)
+ };
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(