Diffstat (limited to 'candle-transformers/src')
-rw-r--r-- | candle-transformers/src/models/llama.rs   |   3 |
-rw-r--r-- | candle-transformers/src/models/mod.rs     |   1 |
-rw-r--r-- | candle-transformers/src/models/rwkv_v5.rs | 317 |
3 files changed, 319 insertions, 2 deletions
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index 7a920cb8..f8126394 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -1,13 +1,12 @@
 use super::with_tracing::{linear_no_bias as linear, Linear};
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{embedding, Embedding, Module, VarBuilder};
-use serde::Deserialize;
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
 
 pub const MAX_SEQ_LEN: usize = 4096;
 
-#[derive(Debug, Clone, Deserialize)]
+#[derive(Debug, Clone, serde::Deserialize)]
 pub struct LlamaConfig {
     pub hidden_size: usize,
     pub intermediate_size: usize,
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 769fd650..8eab4744 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -34,6 +34,7 @@ pub mod quantized_t5;
 pub mod qwen2;
 pub mod repvgg;
 pub mod resnet;
+pub mod rwkv_v5;
 pub mod segment_anything;
 pub mod stable_diffusion;
 pub mod stable_lm;
diff --git a/candle-transformers/src/models/rwkv_v5.rs b/candle-transformers/src/models/rwkv_v5.rs
new file mode 100644
index 00000000..d8ea7b20
--- /dev/null
+++ b/candle-transformers/src/models/rwkv_v5.rs
@@ -0,0 +1,317 @@
+use super::with_tracing::{layer_norm, linear_no_bias as linear, LayerNorm, Linear};
+use candle::{DType, Device, IndexOp, Result, Tensor};
+use candle_nn::{embedding, Embedding, Module, VarBuilder};
+
+fn default_num_attention_heads() -> usize {
+    64
+}
+
+// https://huggingface.co/RWKV/HF_v5-Eagle-7B/blob/main/configuration_rwkv5.py
+#[derive(Debug, Clone, serde::Deserialize)]
+pub struct Config {
+    pub vocab_size: usize,
+    pub hidden_size: usize,
+    pub num_hidden_layers: usize,
+    pub attention_hidden_size: usize,
+    #[serde(default = "default_num_attention_heads")]
+    pub num_attention_heads: usize,
+    pub head_size: usize,
+    pub intermediate_size: Option<usize>,
+    pub layer_norm_epsilon: f64,
+    pub rescale_every: usize,
+}
+
+struct StatePerLayer {
+    extract_key_value: Tensor,
+    linear_attention: Tensor,
+    feed_forward: Tensor,
+}
+
+pub struct State {
+    per_layer: Vec<StatePerLayer>,
+    pos: usize,
+}
+
+impl State {
+    pub fn new(batch_size: usize, cfg: &Config, dev: &Device) -> Result<Self> {
+        let mut per_layer = Vec::with_capacity(cfg.num_hidden_layers);
+        // Certainly a weird convention but taken from modeling_rwkv5.py
+        let num_attention_heads = cfg.hidden_size / cfg.num_attention_heads;
+        for _layer_idx in 0..cfg.num_hidden_layers {
+            let extract_key_value = Tensor::zeros((batch_size, cfg.hidden_size), DType::F32, dev)?;
+            let linear_attention = Tensor::zeros(
+                (
+                    batch_size,
+                    num_attention_heads,
+                    cfg.hidden_size / num_attention_heads,
+                    cfg.hidden_size / num_attention_heads,
+                ),
+                DType::F32,
+                dev,
+            )?;
+            let feed_forward = Tensor::zeros((batch_size, cfg.hidden_size), DType::F32, dev)?;
+            per_layer.push(StatePerLayer {
+                extract_key_value,
+                linear_attention,
+                feed_forward,
+            });
+        }
+        Ok(Self { per_layer, pos: 0 })
+    }
+}
+
+#[derive(Debug, Clone)]
+struct SelfAttention {
+    key: Linear,
+    receptance: Linear,
+    value: Linear,
+    gate: Linear,
+    output: Linear,
+    ln_x: candle_nn::GroupNorm,
+    time_mix_key: Tensor,
+    time_mix_value: Tensor,
+    time_mix_receptance: Tensor,
+    time_decay: Tensor,
+    time_faaaa: Tensor,
+    time_mix_gate: Tensor,
+    layer_id: usize,
+    n_attn_heads: usize,
+}
+
+impl SelfAttention {
+    pub fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let hidden_size = cfg.hidden_size;
+        let attn_hidden_size = cfg.attention_hidden_size;
+        let key = linear(hidden_size, attn_hidden_size, vb.pp("key"))?;
+        let receptance = linear(hidden_size, attn_hidden_size, vb.pp("receptance"))?;
+        let value = linear(hidden_size, attn_hidden_size, vb.pp("value"))?;
+        let gate = linear(hidden_size, attn_hidden_size, vb.pp("gate"))?;
+        let output = linear(attn_hidden_size, hidden_size, vb.pp("output"))?;
+        let ln_x = candle_nn::group_norm(
+            hidden_size / cfg.head_size,
+            hidden_size,
+            1e-5,
+            vb.pp("ln_x"),
+        )?;
+        let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
+        let time_mix_value = vb.get((1, 1, cfg.hidden_size), "time_mix_value")?;
+        let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
+        let n_attn_heads = cfg.hidden_size / cfg.head_size;
+        let time_decay = vb.get((n_attn_heads, cfg.head_size), "time_decay")?;
+        let time_faaaa = vb.get((n_attn_heads, cfg.head_size), "time_faaaa")?;
+        let time_mix_gate = vb.get((1, 1, cfg.hidden_size), "time_mix_gate")?;
+        Ok(Self {
+            key,
+            value,
+            receptance,
+            gate,
+            output,
+            ln_x,
+            time_mix_key,
+            time_mix_value,
+            time_mix_receptance,
+            time_decay,
+            time_faaaa,
+            time_mix_gate,
+            layer_id,
+            n_attn_heads,
+        })
+    }
+
+    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
+        let h = self.time_decay.dim(0)?;
+        let (b, t, s) = xs.dims3()?;
+        let s = s / h;
+        let (receptance, key, value, gate) = {
+            // extract key-value
+            let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
+            let shifted = if shifted.rank() == 2 {
+                shifted.unsqueeze(1)?
+            } else {
+                shifted
+            };
+            let key = ((xs * &self.time_mix_key)? + &shifted * (1.0 - &self.time_mix_key)?)?;
+            let value = ((xs * &self.time_mix_value)? + &shifted * (1.0 - &self.time_mix_value)?)?;
+            let receptance = ((xs * &self.time_mix_receptance)?
+                + &shifted * (1.0 - &self.time_mix_receptance)?)?;
+            let gate = ((xs * &self.time_mix_gate)? + &shifted * (1.0 - &self.time_mix_gate)?)?;
+
+            let key = self.key.forward(&key)?;
+            let value = self.value.forward(&value)?;
+            let receptance = self.receptance.forward(&receptance)?;
+            let gate = candle_nn::ops::silu(&self.gate.forward(&gate)?)?;
+            state.per_layer[self.layer_id].extract_key_value = xs.i((.., t - 1))?;
+            (receptance, key, value, gate)
+        };
+        // linear attention
+        let mut state_ = state.per_layer[self.layer_id].linear_attention.clone();
+        let key = key.reshape((b, t, h, s))?.permute((0, 2, 3, 1))?;
+        let value = value.reshape((b, t, h, s))?.transpose(1, 2)?;
+        let receptance = receptance.reshape((b, t, h, s))?.transpose(1, 2)?;
+
+        let time_decay = self
+            .time_decay
+            .exp()?
+            .neg()?
+            .exp()?
+            .reshape(((), 1, 1))?
+            .reshape((self.n_attn_heads, (), 1))?;
+        let time_faaaa = self
+            .time_faaaa
+            .reshape(((), 1, 1))?
+            .reshape((self.n_attn_heads, (), 1))?;
+
+        let mut out: Vec<Tensor> = Vec::with_capacity(t);
+        for t_ in 0..t {
+            //
+            let rt = receptance.i((.., .., t_..t_ + 1))?;
+            let kt = key.i((.., .., .., t_..t_ + 1))?;
+            let vt = value.i((.., .., t_..t_ + 1))?;
+            let at = kt.matmul(&vt)?;
+            let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
+            let out_ = rt.matmul(&rhs)?.squeeze(2)?;
+            state_ = (&at + time_decay.broadcast_mul(&state_))?;
+            out.push(out_)
+        }
+        let out = Tensor::cat(&out, 1)?.reshape((b * t, h * s, 1))?;
+        let out = out.apply(&self.ln_x)?.reshape((b, t, h * s))?;
+        let out = (out * gate)?.apply(&self.output)?;
+        state.per_layer[self.layer_id].linear_attention = state_;
+        Ok(out)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct FeedForward {
+    time_mix_key: Tensor,
+    time_mix_receptance: Tensor,
+    key: Linear,
+    receptance: Linear,
+    value: Linear,
+    layer_id: usize,
+}
+
+impl FeedForward {
+    pub fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let int_size = cfg
+            .intermediate_size
+            .unwrap_or(((cfg.hidden_size as f64 * 3.5) as usize) / 32 * 32);
+        let key = linear(cfg.hidden_size, int_size, vb.pp("key"))?;
+        let receptance = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("receptance"))?;
+        let value = linear(int_size, cfg.hidden_size, vb.pp("value"))?;
+        let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
+        let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
+        Ok(Self {
+            key,
+            receptance,
+            value,
+            time_mix_key,
+            time_mix_receptance,
+            layer_id,
+        })
+    }
+
+    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
+        let shifted = &state.per_layer[self.layer_id].feed_forward;
+        let key = (xs.broadcast_mul(&self.time_mix_key)?
+            + shifted.broadcast_mul(&(1.0 - &self.time_mix_key)?)?)?;
+        let receptance = (xs.broadcast_mul(&self.time_mix_receptance)?
+            + shifted.broadcast_mul(&(1.0 - &self.time_mix_receptance)?)?)?;
+        let key = key.apply(&self.key)?.relu()?.sqr()?;
+        let value = key.apply(&self.value)?;
+        let receptance = candle_nn::ops::sigmoid(&receptance.apply(&self.receptance)?)?;
+        state.per_layer[self.layer_id].feed_forward = xs.i((.., xs.dim(1)? - 1))?;
+        let xs = (receptance * value)?;
+        Ok(xs)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct Block {
+    pre_ln: Option<LayerNorm>,
+    ln1: LayerNorm,
+    ln2: LayerNorm,
+    attention: SelfAttention,
+    feed_forward: FeedForward,
+}
+
+impl Block {
+    pub fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let ln1 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln1"))?;
+        let ln2 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln2"))?;
+        let pre_ln = if layer_id == 0 {
+            let ln = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("pre_ln"))?;
+            Some(ln)
+        } else {
+            None
+        };
+        let attention = SelfAttention::new(layer_id, cfg, vb.pp("attention"))?;
+        let feed_forward = FeedForward::new(layer_id, cfg, vb.pp("feed_forward"))?;
+        Ok(Self {
+            pre_ln,
+            ln1,
+            ln2,
+            attention,
+            feed_forward,
+        })
+    }
+
+    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
+        let xs = match self.pre_ln.as_ref() {
+            None => xs.clone(),
+            Some(pre_ln) => xs.apply(pre_ln)?,
+        };
+        let attention = self.attention.forward(&xs.apply(&self.ln1)?, state)?;
+        let xs = (xs + attention)?;
+        let feed_forward = self.feed_forward.forward(&xs.apply(&self.ln2)?, state)?;
+        let xs = (xs + feed_forward)?;
+        Ok(xs)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct Model {
+    embeddings: Embedding,
+    blocks: Vec<Block>,
+    ln_out: LayerNorm,
+    head: Linear,
+    rescale_every: usize,
+    layers_are_rescaled: bool,
+}
+
+impl Model {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let vb_m = vb.pp("rwkv");
+        let embeddings = embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embeddings"))?;
+        let mut blocks = Vec::with_capacity(cfg.num_hidden_layers);
+        let vb_b = vb_m.pp("blocks");
+        for block_index in 0..cfg.num_hidden_layers {
+            let block = Block::new(block_index, cfg, vb_b.pp(block_index))?;
+            blocks.push(block)
+        }
+        let ln_out = layer_norm(cfg.hidden_size, 1e-5, vb_m.pp("ln_out"))?;
+        let head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("head"))?;
+        Ok(Self {
+            embeddings,
+            blocks,
+            ln_out,
+            head,
+            rescale_every: cfg.rescale_every,
+            layers_are_rescaled: false, // This seems to only happen for the f16/bf16 dtypes.
+        })
+    }
+
+    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
+        let (_b_size, _seq_len) = xs.dims2()?;
+        let mut xs = xs.apply(&self.embeddings)?;
+        for (block_idx, block) in self.blocks.iter().enumerate() {
+            xs = block.forward(&xs, state)?;
+            if self.layers_are_rescaled && (block_idx + 1) % self.rescale_every == 0 {
+                xs = (xs / 2.)?
+            }
+        }
+        let xs = xs.apply(&self.ln_out)?.apply(&self.head)?;
+        state.pos += 1;
+        Ok(xs)
+    }
+}
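
For orientation, here is a minimal sketch of how the `rwkv_v5` API introduced above (`Config`, `State`, `Model`) might be driven for greedy next-token prediction. Only the `candle_transformers::models::rwkv_v5` types and their signatures come from the diff; the file names, the extra `serde_json` dependency, and the argmax decoding step are illustrative assumptions, not part of this commit.

    // Hypothetical driver for the rwkv_v5 module added in this commit.
    // "config.json", "model.safetensors" and the greedy argmax step are
    // assumptions; only the rwkv_v5 types come from the diff above.
    use candle::{DType, Device, IndexOp, Result, Tensor, D};
    use candle_nn::VarBuilder;
    use candle_transformers::models::rwkv_v5::{Config, Model, State};

    fn next_token(prompt_ids: &[u32]) -> Result<u32> {
        let device = Device::Cpu;

        // Config deserializes from the HF-style configuration_rwkv5 json.
        let cfg: Config =
            serde_json::from_slice(&std::fs::read("config.json").unwrap()).unwrap();

        // Model::new expects the "rwkv.*" / "head.*" tensor prefixes used above.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
        };
        let model = Model::new(&cfg, vb)?;

        // The recurrent state holds the shifted activations and the per-head
        // linear-attention matrices that forward() updates in place.
        let mut state = State::new(1, &cfg, &device)?;

        // Single pass over the whole prompt; logits are (batch, seq, vocab).
        let input = Tensor::new(prompt_ids, &device)?.unsqueeze(0)?;
        let logits = model.forward(&input, &mut state)?;
        let last = logits.i((0, prompt_ids.len() - 1))?;

        // Greedy decoding for illustration only.
        last.argmax(D::Minus1)?.to_scalar::<u32>()
    }

Because the `State` is updated in place on every call, it is meant to be reused across successive `forward` calls: the prompt is fed once, and further tokens can then be passed one at a time with the same state instead of re-running the full sequence.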