author    Laurent Mazare <laurent.mazare@gmail.com>  2024-02-14 10:58:32 +0100
committer GitHub <noreply@github.com>  2024-02-14 10:58:32 +0100
commit    2d5f2a728d9ade10ce4b7b618ee4dba8075064dd (patch)
tree      304d99d8330c116bea92c2997474311c199e579a /candle-transformers/src
parent    68f76558956f7f56cb5014bb5f7c7c5534436b72 (diff)
Add the RWKV model (v5). (#1707)
* Start adding the RWKV model.
* More of the forward step.
* Handle rescaling.
* FeedForward.
* More work on RWKV.
* Better state tracking.
* Finish a first pass on forward.
* Fix the shape mismatches.
* Do not rescale in f32.
* Rename to rwkv-v5.
* Add the new models to the readme.
Diffstat (limited to 'candle-transformers/src')
-rw-r--r--  candle-transformers/src/models/llama.rs      3
-rw-r--r--  candle-transformers/src/models/mod.rs        1
-rw-r--r--  candle-transformers/src/models/rwkv_v5.rs  317
3 files changed, 319 insertions(+), 2 deletions(-)
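Before the diff itself, a minimal, untested sketch of how the new rwkv_v5 module could be driven from an application crate. None of this is part of the commit: the file names and token ids are placeholders, anyhow and serde_json are assumed as dependencies, and it presumes an RWKV v5 checkpoint converted to safetensors with a config.json matching the Config struct added below.

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::rwkv_v5::{Config, Model, State};

fn main() -> anyhow::Result<()> {
    let device = Device::Cpu;
    // Placeholder paths: a converted safetensors checkpoint plus its config.json.
    let cfg: Config = serde_json::from_str(&std::fs::read_to_string("config.json")?)?;
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
    };
    let model = Model::new(&cfg, vb)?;

    // The model is recurrent: keep a State across calls and feed one token per step.
    // Each forward call returns logits of shape (batch, 1, vocab_size).
    let mut state = State::new(1, &cfg, &device)?;
    for &token_id in &[510u32, 3158, 310] {
        // Placeholder token ids standing in for tokenizer output.
        let input = Tensor::new(&[[token_id]], &device)?;
        let _logits = model.forward(&input, &mut state)?;
    }
    Ok(())
}

Because the attention state is updated in place, processing a prompt is just this same loop over the prompt tokens before sampling begins.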
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index 7a920cb8..f8126394 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -1,13 +1,12 @@
use super::with_tracing::{linear_no_bias as linear, Linear};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{embedding, Embedding, Module, VarBuilder};
-use serde::Deserialize;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

pub const MAX_SEQ_LEN: usize = 4096;

-#[derive(Debug, Clone, Deserialize)]
+#[derive(Debug, Clone, serde::Deserialize)]
pub struct LlamaConfig {
pub hidden_size: usize,
pub intermediate_size: usize,
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 769fd650..8eab4744 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -34,6 +34,7 @@ pub mod quantized_t5;
pub mod qwen2;
pub mod repvgg;
pub mod resnet;
+pub mod rwkv_v5;
pub mod segment_anything;
pub mod stable_diffusion;
pub mod stable_lm;
diff --git a/candle-transformers/src/models/rwkv_v5.rs b/candle-transformers/src/models/rwkv_v5.rs
new file mode 100644
index 00000000..d8ea7b20
--- /dev/null
+++ b/candle-transformers/src/models/rwkv_v5.rs
@@ -0,0 +1,317 @@
+use super::with_tracing::{layer_norm, linear_no_bias as linear, LayerNorm, Linear};
+use candle::{DType, Device, IndexOp, Result, Tensor};
+use candle_nn::{embedding, Embedding, Module, VarBuilder};
+
+fn default_num_attention_heads() -> usize {
+ 64
+}
+
+// https://huggingface.co/RWKV/HF_v5-Eagle-7B/blob/main/configuration_rwkv5.py
+#[derive(Debug, Clone, serde::Deserialize)]
+pub struct Config {
+ pub vocab_size: usize,
+ pub hidden_size: usize,
+ pub num_hidden_layers: usize,
+ pub attention_hidden_size: usize,
+ #[serde(default = "default_num_attention_heads")]
+ pub num_attention_heads: usize,
+ pub head_size: usize,
+ pub intermediate_size: Option<usize>,
+ pub layer_norm_epsilon: f64,
+ pub rescale_every: usize,
+}
+
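+// Recurrent state carried across forward calls: the last input seen by the attention
+// and feed-forward token-shifts, plus the running linear-attention matrix per layer.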
+struct StatePerLayer {
+ extract_key_value: Tensor,
+ linear_attention: Tensor,
+ feed_forward: Tensor,
+}
+
+pub struct State {
+ per_layer: Vec<StatePerLayer>,
+ pos: usize,
+}
+
+impl State {
+ pub fn new(batch_size: usize, cfg: &Config, dev: &Device) -> Result<Self> {
+ let mut per_layer = Vec::with_capacity(cfg.num_hidden_layers);
+ // Certainly a weird convention but taken from modeling_rwkv5.py
+ let num_attention_heads = cfg.hidden_size / cfg.num_attention_heads;
+ for _layer_idx in 0..cfg.num_hidden_layers {
+ let extract_key_value = Tensor::zeros((batch_size, cfg.hidden_size), DType::F32, dev)?;
+ let linear_attention = Tensor::zeros(
+ (
+ batch_size,
+ num_attention_heads,
+ cfg.hidden_size / num_attention_heads,
+ cfg.hidden_size / num_attention_heads,
+ ),
+ DType::F32,
+ dev,
+ )?;
+ let feed_forward = Tensor::zeros((batch_size, cfg.hidden_size), DType::F32, dev)?;
+ per_layer.push(StatePerLayer {
+ extract_key_value,
+ linear_attention,
+ feed_forward,
+ });
+ }
+ Ok(Self { per_layer, pos: 0 })
+ }
+}
+
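+// RWKV v5 time-mixing block: the time_mix_* tensors interpolate between the current and
+// previous token (token shift); time_decay and time_faaaa parameterize the per-head
+// linear-attention recurrence in `forward`.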
+#[derive(Debug, Clone)]
+struct SelfAttention {
+ key: Linear,
+ receptance: Linear,
+ value: Linear,
+ gate: Linear,
+ output: Linear,
+ ln_x: candle_nn::GroupNorm,
+ time_mix_key: Tensor,
+ time_mix_value: Tensor,
+ time_mix_receptance: Tensor,
+ time_decay: Tensor,
+ time_faaaa: Tensor,
+ time_mix_gate: Tensor,
+ layer_id: usize,
+ n_attn_heads: usize,
+}
+
+impl SelfAttention {
+ pub fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let hidden_size = cfg.hidden_size;
+ let attn_hidden_size = cfg.attention_hidden_size;
+ let key = linear(hidden_size, attn_hidden_size, vb.pp("key"))?;
+ let receptance = linear(hidden_size, attn_hidden_size, vb.pp("receptance"))?;
+ let value = linear(hidden_size, attn_hidden_size, vb.pp("value"))?;
+ let gate = linear(hidden_size, attn_hidden_size, vb.pp("gate"))?;
+ let output = linear(attn_hidden_size, hidden_size, vb.pp("output"))?;
+ let ln_x = candle_nn::group_norm(
+ hidden_size / cfg.head_size,
+ hidden_size,
+ 1e-5,
+ vb.pp("ln_x"),
+ )?;
+ let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
+ let time_mix_value = vb.get((1, 1, cfg.hidden_size), "time_mix_value")?;
+ let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
+ let n_attn_heads = cfg.hidden_size / cfg.head_size;
+ let time_decay = vb.get((n_attn_heads, cfg.head_size), "time_decay")?;
+ let time_faaaa = vb.get((n_attn_heads, cfg.head_size), "time_faaaa")?;
+ let time_mix_gate = vb.get((1, 1, cfg.hidden_size), "time_mix_gate")?;
+ Ok(Self {
+ key,
+ value,
+ receptance,
+ gate,
+ output,
+ ln_x,
+ time_mix_key,
+ time_mix_value,
+ time_mix_receptance,
+ time_decay,
+ time_faaaa,
+ time_mix_gate,
+ layer_id,
+ n_attn_heads,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
+ let h = self.time_decay.dim(0)?;
+ let (b, t, s) = xs.dims3()?;
+ let s = s / h;
+ let (receptance, key, value, gate) = {
+ // extract key-value
+ let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
+ let shifted = if shifted.rank() == 2 {
+ shifted.unsqueeze(1)?
+ } else {
+ shifted
+ };
+ let key = ((xs * &self.time_mix_key)? + &shifted * (1.0 - &self.time_mix_key)?)?;
+ let value = ((xs * &self.time_mix_value)? + &shifted * (1.0 - &self.time_mix_value)?)?;
+ let receptance = ((xs * &self.time_mix_receptance)?
+ + &shifted * (1.0 - &self.time_mix_receptance)?)?;
+ let gate = ((xs * &self.time_mix_gate)? + &shifted * (1.0 - &self.time_mix_gate)?)?;
+
+ let key = self.key.forward(&key)?;
+ let value = self.value.forward(&value)?;
+ let receptance = self.receptance.forward(&receptance)?;
+ let gate = candle_nn::ops::silu(&self.gate.forward(&gate)?)?;
+ state.per_layer[self.layer_id].extract_key_value = xs.i((.., t - 1))?;
+ (receptance, key, value, gate)
+ };
+ // linear attention
+ let mut state_ = state.per_layer[self.layer_id].linear_attention.clone();
+ let key = key.reshape((b, t, h, s))?.permute((0, 2, 3, 1))?;
+ let value = value.reshape((b, t, h, s))?.transpose(1, 2)?;
+ let receptance = receptance.reshape((b, t, h, s))?.transpose(1, 2)?;
+
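+ // Per-channel decay factor w = exp(-exp(time_decay)), reshaped so it broadcasts per head.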
+ let time_decay = self
+ .time_decay
+ .exp()?
+ .neg()?
+ .exp()?
+ .reshape(((), 1, 1))?
+ .reshape((self.n_attn_heads, (), 1))?;
+ let time_faaaa =
+ self.time_faaaa
+ .reshape(((), 1, 1))?
+ .reshape((self.n_attn_heads, (), 1))?;
+
+ let mut out: Vec<Tensor> = Vec::with_capacity(t);
+ for t_ in 0..t {
+ // One recurrence step: out_t = r_t * (time_faaaa * (k_t v_t^T) + S); then S <- k_t v_t^T + time_decay * S.
+ let rt = receptance.i((.., .., t_..t_ + 1))?;
+ let kt = key.i((.., .., .., t_..t_ + 1))?;
+ let vt = value.i((.., .., t_..t_ + 1))?;
+ let at = kt.matmul(&vt)?;
+ let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
+ let out_ = rt.matmul(&rhs)?.squeeze(2)?;
+ state_ = (&at + time_decay.broadcast_mul(&state_))?;
+ out.push(out_)
+ }
+ let out = Tensor::cat(&out, 1)?.reshape((b * t, h * s, 1))?;
+ let out = out.apply(&self.ln_x)?.reshape((b, t, h * s))?;
+ let out = (out * gate)?.apply(&self.output)?;
+ state.per_layer[self.layer_id].linear_attention = state_;
+ Ok(out)
+ }
+}
+
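+// RWKV channel-mixing block: token-shift interpolation followed by a squared-ReLU key
+// projection gated by a sigmoid of the receptance.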
+#[derive(Debug, Clone)]
+struct FeedForward {
+ time_mix_key: Tensor,
+ time_mix_receptance: Tensor,
+ key: Linear,
+ receptance: Linear,
+ value: Linear,
+ layer_id: usize,
+}
+
+impl FeedForward {
+ pub fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let int_size = cfg
+ .intermediate_size
+ .unwrap_or(((cfg.hidden_size as f64 * 3.5) as usize) / 32 * 32);
+ let key = linear(cfg.hidden_size, int_size, vb.pp("key"))?;
+ let receptance = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("receptance"))?;
+ let value = linear(int_size, cfg.hidden_size, vb.pp("value"))?;
+ let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
+ let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
+ Ok(Self {
+ key,
+ receptance,
+ value,
+ time_mix_key,
+ time_mix_receptance,
+ layer_id,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
+ let shifted = &state.per_layer[self.layer_id].feed_forward;
+ let key = (xs.broadcast_mul(&self.time_mix_key)?
+ + shifted.broadcast_mul(&(1.0 - &self.time_mix_key)?)?)?;
+ let receptance = (xs.broadcast_mul(&self.time_mix_receptance)?
+ + shifted.broadcast_mul(&(1.0 - &self.time_mix_receptance)?)?)?;
+ let key = key.apply(&self.key)?.relu()?.sqr()?;
+ let value = key.apply(&self.value)?;
+ let receptance = candle_nn::ops::sigmoid(&receptance.apply(&self.receptance)?)?;
+ state.per_layer[self.layer_id].feed_forward = xs.i((.., xs.dim(1)? - 1))?;
+ let xs = (receptance * value)?;
+ Ok(xs)
+ }
+}
+
+#[derive(Debug, Clone)]
+struct Block {
+ pre_ln: Option<LayerNorm>,
+ ln1: LayerNorm,
+ ln2: LayerNorm,
+ attention: SelfAttention,
+ feed_forward: FeedForward,
+}
+
+impl Block {
+ pub fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let ln1 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln1"))?;
+ let ln2 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln2"))?;
+ let pre_ln = if layer_id == 0 {
+ let ln = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("pre_ln"))?;
+ Some(ln)
+ } else {
+ None
+ };
+ let attention = SelfAttention::new(layer_id, cfg, vb.pp("attention"))?;
+ let feed_forward = FeedForward::new(layer_id, cfg, vb.pp("feed_forward"))?;
+ Ok(Self {
+ pre_ln,
+ ln1,
+ ln2,
+ attention,
+ feed_forward,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
+ let xs = match self.pre_ln.as_ref() {
+ None => xs.clone(),
+ Some(pre_ln) => xs.apply(pre_ln)?,
+ };
+ let attention = self.attention.forward(&xs.apply(&self.ln1)?, state)?;
+ let xs = (xs + attention)?;
+ let feed_forward = self.feed_forward.forward(&xs.apply(&self.ln2)?, state)?;
+ let xs = (xs + feed_forward)?;
+ Ok(xs)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct Model {
+ embeddings: Embedding,
+ blocks: Vec<Block>,
+ ln_out: LayerNorm,
+ head: Linear,
+ rescale_every: usize,
+ layers_are_rescaled: bool,
+}
+
+impl Model {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let vb_m = vb.pp("rwkv");
+ let embeddings = embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embeddings"))?;
+ let mut blocks = Vec::with_capacity(cfg.num_hidden_layers);
+ let vb_b = vb_m.pp("blocks");
+ for block_index in 0..cfg.num_hidden_layers {
+ let block = Block::new(block_index, cfg, vb_b.pp(block_index))?;
+ blocks.push(block)
+ }
+ let ln_out = layer_norm(cfg.hidden_size, 1e-5, vb_m.pp("ln_out"))?;
+ let head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("head"))?;
+ Ok(Self {
+ embeddings,
+ blocks,
+ ln_out,
+ head,
+ rescale_every: cfg.rescale_every,
+ layers_are_rescaled: false, // This seems to only happen for the f16/bf16 dtypes.
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
+ let (_b_size, _seq_len) = xs.dims2()?;
+ let mut xs = xs.apply(&self.embeddings)?;
+ for (block_idx, block) in self.blocks.iter().enumerate() {
+ xs = block.forward(&xs, state)?;
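+ // When the weights have been rescaled for reduced-precision dtypes, halve the
+ // activations every `rescale_every` blocks to compensate.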
+ if self.layers_are_rescaled && (block_idx + 1) % self.rescale_every == 0 {
+ xs = (xs / 2.)?
+ }
+ }
+ let xs = xs.apply(&self.ln_out)?.apply(&self.head)?;
+ state.pos += 1;
+ Ok(xs)
+ }
+}