author    | Jack Shih <develop@kshih.com> | 2024-02-26 04:43:40 +0800
committer | GitHub <noreply@github.com>   | 2024-02-25 21:43:40 +0100
commit    | 918136ba46e426a29ae9dc8318b23daa312d073e (patch)
tree      | 13457de87fdd0e019265eb8aab75294e430606be
parent    | 1a6043af5123bf9e189063d3baf110b39cf47617 (diff)
add quantized rwkv v5 model (#1743)
* add quantized rwkv v5 model
* Integrate the quantized rwkv model in the initial example.
---------
Co-authored-by: laurent <laurent.mazare@gmail.com>
-rw-r--r-- | candle-examples/examples/rwkv/main.rs               |  42
-rw-r--r-- | candle-transformers/src/models/mod.rs               |   1
-rw-r--r-- | candle-transformers/src/models/quantized_rwkv_v5.rs | 286
-rw-r--r-- | candle-transformers/src/models/rwkv_v5.rs           |   3
4 files changed, 326 insertions, 6 deletions
diff --git a/candle-examples/examples/rwkv/main.rs b/candle-examples/examples/rwkv/main.rs
index 0ccf2ec3..771baa03 100644
--- a/candle-examples/examples/rwkv/main.rs
+++ b/candle-examples/examples/rwkv/main.rs
@@ -7,13 +7,28 @@ extern crate accelerate_src;
 use anyhow::Result;
 use clap::{Parser, ValueEnum};
 
-use candle_transformers::models::rwkv_v5::{Config, Model, State, Tokenizer};
+use candle_transformers::models::quantized_rwkv_v5::Model as Q;
+use candle_transformers::models::rwkv_v5::{Config, Model as M, State, Tokenizer};
 
 use candle::{DType, Device, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
 
+enum Model {
+    M(M),
+    Q(Q),
+}
+
+impl Model {
+    fn forward(&self, xs: &Tensor, state: &mut State) -> candle::Result<Tensor> {
+        match self {
+            Self::M(m) => m.forward(xs, state),
+            Self::Q(m) => m.forward(xs, state),
+        }
+    }
+}
+
 struct TextGeneration {
     model: Model,
     config: Config,
@@ -176,6 +191,9 @@ struct Args {
     #[arg(long)]
     config_file: Option<String>,
 
+    #[arg(long)]
+    quantized: bool,
+
     /// Penalty to be applied for repeating tokens, 1. means no penalty.
     #[arg(long, default_value_t = 1.1)]
     repeat_penalty: f32,
@@ -236,7 +254,16 @@ fn main() -> Result<()> {
             .map(std::path::PathBuf::from)
             .collect::<Vec<_>>(),
         None => {
-            vec![repo.get("model.safetensors")?]
+            if args.quantized {
+                let file = match args.which {
+                    Which::World1b5 => "world1b5-q4k.gguf",
+                    Which::World3b => "world3b-q4k.gguf",
+                    Which::Eagle7b => "eagle7b-q4k.gguf",
+                };
+                vec![api.model("lmz/candle-rwkv".to_string()).get(file)?]
+            } else {
+                vec![repo.get("model.safetensors")?]
+            }
         }
     };
     println!("retrieved the files in {:?}", start.elapsed());
@@ -245,8 +272,15 @@ fn main() -> Result<()> {
     let start = std::time::Instant::now();
     let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
     let device = candle_examples::device(args.cpu)?;
-    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
-    let model = Model::new(&config, vb)?;
+    let model = if args.quantized {
+        let filename = &filenames[0];
+        let vb =
+            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
+        Model::Q(Q::new(&config, vb)?)
+    } else {
+        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
+        Model::M(M::new(&config, vb)?)
+    };
     println!("loaded the model in {:?}", start.elapsed());
 
     let mut pipeline = TextGeneration::new(
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index bb59a53f..96627683 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -30,6 +30,7 @@ pub mod quantized_llama2_c;
 pub mod quantized_mistral;
 pub mod quantized_mixformer;
 pub mod quantized_mpt;
+pub mod quantized_rwkv_v5;
 pub mod quantized_stable_lm;
 pub mod quantized_t5;
 pub mod qwen2;
diff --git a/candle-transformers/src/models/quantized_rwkv_v5.rs b/candle-transformers/src/models/quantized_rwkv_v5.rs
new file mode 100644
index 00000000..c41d7b4e
--- /dev/null
+++ b/candle-transformers/src/models/quantized_rwkv_v5.rs
@@ -0,0 +1,286 @@
+use crate::{
+    quantized_nn::{layer_norm, linear_no_bias as linear, Embedding, Linear},
+    quantized_var_builder::VarBuilder,
+};
+use candle::{IndexOp, Result, Tensor};
+use candle_nn::{GroupNorm, LayerNorm, Module};
+
+pub use crate::models::rwkv_v5::{Config, State, Tokenizer};
+
+#[derive(Debug, Clone)]
+struct SelfAttention {
+    key: Linear,
+    receptance: Linear,
+    value: Linear,
+    gate: Linear,
+    output: Linear,
+    ln_x: candle_nn::GroupNorm,
+    time_mix_key: Tensor,
+    time_mix_value: Tensor,
+    time_mix_receptance: Tensor,
+    time_decay: Tensor,
+    time_faaaa: Tensor,
+    time_mix_gate: Tensor,
+    layer_id: usize,
+    n_attn_heads: usize,
+}
+
+impl SelfAttention {
+    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let hidden_size = cfg.hidden_size;
+        let attn_hidden_size = cfg.attention_hidden_size;
+        let key = linear(hidden_size, attn_hidden_size, vb.pp("key"))?;
+        let receptance = linear(hidden_size, attn_hidden_size, vb.pp("receptance"))?;
+        let value = linear(hidden_size, attn_hidden_size, vb.pp("value"))?;
+        let gate = linear(hidden_size, attn_hidden_size, vb.pp("gate"))?;
+        let output = linear(attn_hidden_size, hidden_size, vb.pp("output"))?;
+
+        let vb_x = vb.pp("ln_x");
+        let ln_x_weight = vb_x.get(hidden_size, "weight")?.dequantize(vb.device())?;
+        let ln_x_bias = vb_x.get(hidden_size, "bias")?.dequantize(vb.device())?;
+
+        let ln_x = GroupNorm::new(
+            ln_x_weight,
+            ln_x_bias,
+            hidden_size,
+            hidden_size / cfg.head_size,
+            1e-5,
+        )?;
+
+        let time_mix_key = vb
+            .get((1, 1, cfg.hidden_size), "time_mix_key")?
+            .dequantize(vb.device())?;
+        let time_mix_value = vb
+            .get((1, 1, cfg.hidden_size), "time_mix_value")?
+            .dequantize(vb.device())?;
+        let time_mix_receptance = vb
+            .get((1, 1, cfg.hidden_size), "time_mix_receptance")?
+            .dequantize(vb.device())?;
+        let n_attn_heads = cfg.hidden_size / cfg.head_size;
+        let time_decay = vb
+            .get((n_attn_heads, cfg.head_size), "time_decay")?
+            .dequantize(vb.device())?;
+        let time_faaaa = vb
+            .get((n_attn_heads, cfg.head_size), "time_faaaa")?
+            .dequantize(vb.device())?;
+        let time_mix_gate = vb
+            .get((1, 1, cfg.hidden_size), "time_mix_gate")?
+            .dequantize(vb.device())?;
+        Ok(Self {
+            key,
+            value,
+            receptance,
+            gate,
+            output,
+            ln_x,
+            time_mix_key,
+            time_mix_value,
+            time_mix_receptance,
+            time_decay,
+            time_faaaa,
+            time_mix_gate,
+            layer_id,
+            n_attn_heads,
+        })
+    }
+
+    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
+        let h = self.time_decay.dim(0)?;
+        let (b, t, s) = xs.dims3()?;
+        let s = s / h;
+        let (receptance, key, value, gate) = {
+            // extract key-value
+            let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
+            let shifted = if shifted.rank() == 2 {
+                shifted.unsqueeze(1)?
+            } else {
+                shifted
+            };
+            let key = ((xs * &self.time_mix_key)? + &shifted * (1.0 - &self.time_mix_key)?)?;
+            let value = ((xs * &self.time_mix_value)? + &shifted * (1.0 - &self.time_mix_value)?)?;
+            let receptance = ((xs * &self.time_mix_receptance)?
+                + &shifted * (1.0 - &self.time_mix_receptance)?)?;
+            let gate = ((xs * &self.time_mix_gate)? + &shifted * (1.0 - &self.time_mix_gate)?)?;
+
+            let key = self.key.forward(&key)?;
+            let value = self.value.forward(&value)?;
+            let receptance = self.receptance.forward(&receptance)?;
+            let gate = candle_nn::ops::silu(&self.gate.forward(&gate)?)?;
+            state.per_layer[self.layer_id].extract_key_value = xs.i((.., t - 1))?;
+            (receptance, key, value, gate)
+        };
+        // linear attention
+        let mut state_ = state.per_layer[self.layer_id].linear_attention.clone();
+        let key = key.reshape((b, t, h, s))?.permute((0, 2, 3, 1))?;
+        let value = value.reshape((b, t, h, s))?.transpose(1, 2)?;
+        let receptance = receptance.reshape((b, t, h, s))?.transpose(1, 2)?;
+
+        let time_decay = self
+            .time_decay
+            .exp()?
+            .neg()?
+            .exp()?
+            .reshape(((), 1, 1))?
+            .reshape((self.n_attn_heads, (), 1))?;
+        let time_faaaa =
+            self.time_faaaa
+                .reshape(((), 1, 1))?
+                .reshape((self.n_attn_heads, (), 1))?;
+
+        let mut out: Vec<Tensor> = Vec::with_capacity(t);
+        for t_ in 0..t {
+            let rt = receptance.i((.., .., t_..t_ + 1))?.contiguous()?;
+            let kt = key.i((.., .., .., t_..t_ + 1))?.contiguous()?;
+            let vt = value.i((.., .., t_..t_ + 1))?.contiguous()?;
+            let at = kt.matmul(&vt)?;
+            let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
+            let out_ = rt.matmul(&rhs)?.squeeze(2)?;
+            state_ = (&at + time_decay.broadcast_mul(&state_))?;
+            out.push(out_)
+        }
+        let out = Tensor::cat(&out, 1)?.reshape((b * t, h * s, 1))?;
+        let out = out.apply(&self.ln_x)?.reshape((b, t, h * s))?;
+        let out = (out * gate)?.apply(&self.output)?;
+        state.per_layer[self.layer_id].linear_attention = state_;
+        Ok(out)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct FeedForward {
+    time_mix_key: Tensor,
+    time_mix_receptance: Tensor,
+    key: Linear,
+    receptance: Linear,
+    value: Linear,
+    layer_id: usize,
+}
+
+impl FeedForward {
+    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let int_size = cfg
+            .intermediate_size
+            .unwrap_or(((cfg.hidden_size as f64 * 3.5) as usize) / 32 * 32);
+        let key = linear(cfg.hidden_size, int_size, vb.pp("key"))?;
+        let receptance = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("receptance"))?;
+        let value = linear(int_size, cfg.hidden_size, vb.pp("value"))?;
+        let time_mix_key = vb
+            .get((1, 1, cfg.hidden_size), "time_mix_key")?
+            .dequantize(vb.device())?;
+        let time_mix_receptance = vb
+            .get((1, 1, cfg.hidden_size), "time_mix_receptance")?
+            .dequantize(vb.device())?;
+        Ok(Self {
+            key,
+            receptance,
+            value,
+            time_mix_key,
+            time_mix_receptance,
+            layer_id,
+        })
+    }
+
+    fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
+        let shifted = &state.per_layer[self.layer_id].feed_forward;
+        let key = (xs.broadcast_mul(&self.time_mix_key)?
+            + shifted.broadcast_mul(&(1.0 - &self.time_mix_key)?)?)?;
+        let receptance = (xs.broadcast_mul(&self.time_mix_receptance)?
+            + shifted.broadcast_mul(&(1.0 - &self.time_mix_receptance)?)?)?;
+        let key = key.apply(&self.key)?.relu()?.sqr()?;
+        let value = key.apply(&self.value)?;
+        let receptance = candle_nn::ops::sigmoid(&receptance.apply(&self.receptance)?)?;
+        state.per_layer[self.layer_id].feed_forward = xs.i((.., xs.dim(1)? - 1))?;
+        let xs = (receptance * value)?;
+        Ok(xs)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct Block {
+    pre_ln: Option<LayerNorm>,
+    ln1: LayerNorm,
+    ln2: LayerNorm,
+    attention: SelfAttention,
+    feed_forward: FeedForward,
+}
+
+impl Block {
+    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let ln1 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln1"))?;
+        let ln2 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln2"))?;
+        let pre_ln = if layer_id == 0 {
+            let ln = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("pre_ln"))?;
+            Some(ln)
+        } else {
+            None
+        };
+        let attention = SelfAttention::new(layer_id, cfg, vb.pp("attention"))?;
+        let feed_forward = FeedForward::new(layer_id, cfg, vb.pp("feed_forward"))?;
+        Ok(Self {
+            pre_ln,
+            ln1,
+            ln2,
+            attention,
+            feed_forward,
+        })
+    }
+
+    fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
+        let xs = match self.pre_ln.as_ref() {
+            None => xs.clone(),
+            Some(pre_ln) => xs.apply(pre_ln)?,
+        };
+        let attention = self.attention.forward(&xs.apply(&self.ln1)?, state)?;
+        let xs = (xs + attention)?;
+        let feed_forward = self.feed_forward.forward(&xs.apply(&self.ln2)?, state)?;
+        let xs = (xs + feed_forward)?;
+        Ok(xs)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct Model {
+    embeddings: Embedding,
+    blocks: Vec<Block>,
+    ln_out: LayerNorm,
+    head: Linear,
+    rescale_every: usize,
+    layers_are_rescaled: bool,
+}
+
+impl Model {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let vb_m = vb.pp("rwkv");
+        let embeddings = Embedding::new(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embeddings"))?;
+        let mut blocks = Vec::with_capacity(cfg.num_hidden_layers);
+        let vb_b = vb_m.pp("blocks");
+        for block_index in 0..cfg.num_hidden_layers {
+            let block = Block::new(block_index, cfg, vb_b.pp(block_index))?;
+            blocks.push(block)
+        }
+        let ln_out = layer_norm(cfg.hidden_size, 1e-5, vb_m.pp("ln_out"))?;
+        let head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("head"))?;
+        Ok(Self {
+            embeddings,
+            blocks,
+            ln_out,
+            head,
+            rescale_every: cfg.rescale_every,
+            layers_are_rescaled: false, // This seem to only happen for the f16/bf16 dtypes.
+        })
+    }
+
+    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
+        let (_b_size, _seq_len) = xs.dims2()?;
+        let mut xs = xs.apply(&self.embeddings)?;
+        for (block_idx, block) in self.blocks.iter().enumerate() {
+            xs = block.forward(&xs, state)?;
+            if self.layers_are_rescaled && (block_idx + 1) % self.rescale_every == 0 {
+                xs = (xs / 2.)?
+            }
+        }
+        let xs = xs.apply(&self.ln_out)?.apply(&self.head)?;
+        state.pos += 1;
+        Ok(xs)
+    }
+}
diff --git a/candle-transformers/src/models/rwkv_v5.rs b/candle-transformers/src/models/rwkv_v5.rs
index 38b1e450..eb512731 100644
--- a/candle-transformers/src/models/rwkv_v5.rs
+++ b/candle-transformers/src/models/rwkv_v5.rs
@@ -124,7 +124,7 @@ impl SelfAttention {
         let (b, t, s) = xs.dims3()?;
         let s = s / h;
         let (receptance, key, value, gate) = {
-            // exctract key-value
+            // extract key-value
             let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
             let shifted = if shifted.rank() == 2 {
                 shifted.unsqueeze(1)?
@@ -164,7 +164,6 @@ impl SelfAttention {
 
         let mut out: Vec<Tensor> = Vec::with_capacity(t);
         for t_ in 0..t {
-            //
             let rt = receptance.i((.., .., t_..t_ + 1))?.contiguous()?;
             let kt = key.i((.., .., .., t_..t_ + 1))?.contiguous()?;
             let vt = value.i((.., .., t_..t_ + 1))?.contiguous()?;
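For quick reference, here is a minimal sketch of driving the new quantized model directly, mirroring the example changes above. The gguf and config paths are placeholders, and the `State::new` call follows the non-quantized rwkv example in this repository, so treat the exact signatures as assumptions rather than verified API.

```rust
use candle::{Device, Tensor};
use candle_transformers::models::quantized_rwkv_v5::{Config, Model, State};
use candle_transformers::quantized_var_builder::VarBuilder;

fn run() -> anyhow::Result<()> {
    // Placeholder paths: the example fetches the q4k gguf files from the
    // "lmz/candle-rwkv" hub repo when --quantized is passed.
    let config: Config = serde_json::from_slice(&std::fs::read("config.json")?)?;
    let device = Device::Cpu;

    // Load the gguf weights through the quantized VarBuilder, as in main.rs above.
    let vb = VarBuilder::from_gguf("world1b5-q4k.gguf", &device)?;
    let model = Model::new(&config, vb)?;

    // One recurrent state per sequence; constructor signature assumed from the
    // non-quantized example.
    let mut state = State::new(1, &config, &device)?;

    // Feed a single (batch = 1, seq = 1) token id and read back the vocab logits.
    let xs = Tensor::new(&[0u32], &device)?.unsqueeze(0)?;
    let logits = model.forward(&xs, &mut state)?;
    println!("logits shape: {:?}", logits.shape());
    Ok(())
}
```

Because `Config`, `State`, and `Tokenizer` are re-exported from `quantized_rwkv_v5`, the quantized model slots into the same generation loop as the safetensors model, which is exactly what the `enum Model` dispatch in the example relies on.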