summaryrefslogtreecommitdiff
path: root/candle-transformers
diff options
context:
space:
mode:
authorJani Monoses <jani.monoses@gmail.com>2024-08-12 22:21:19 +0300
committerGitHub <noreply@github.com>2024-08-12 21:21:19 +0200
commit35e5f313977b6b1006ae98ee4443e0a27d14528d (patch)
tree40935d856959844968e44dc6ae5477f45432429d /candle-transformers
parentd3fe989d086a6317734e602b5106c9eccdb8745e (diff)
downloadcandle-35e5f313977b6b1006ae98ee4443e0a27d14528d.tar.gz
candle-35e5f313977b6b1006ae98ee4443e0a27d14528d.tar.bz2
candle-35e5f313977b6b1006ae98ee4443e0a27d14528d.zip
Add Based LLM from Hazy Research. (#2411)
Diffstat (limited to 'candle-transformers')
-rw-r--r--candle-transformers/src/models/based.rs589
-rw-r--r--candle-transformers/src/models/mod.rs1
2 files changed, 590 insertions, 0 deletions
diff --git a/candle-transformers/src/models/based.rs b/candle-transformers/src/models/based.rs
new file mode 100644
index 00000000..aa28f523
--- /dev/null
+++ b/candle-transformers/src/models/based.rs
@@ -0,0 +1,589 @@
+//! Based from the Stanford Hazy Research group.
+//!
+//! See "Simple linear attention language models balance the recall-throughput tradeoff", Arora et al. 2024
+//! <https://arxiv.org/abs/2402.18668>
+
+//! Original code:
+//! https://github.com/HazyResearch/based
+
+use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
+use candle_nn::{
+ conv1d_no_bias, linear, linear_no_bias, ops::softmax_last_dim, rms_norm, Conv1d, Conv1dConfig,
+ Func, Linear, RmsNorm, VarBuilder,
+};
+use std::sync::Arc;
+
+#[derive(Debug, Clone, serde::Deserialize)]
+pub struct LinearAttentionFeatureMapConfig {
+ input_dim: usize,
+}
+
+#[derive(Debug, Clone, serde::Deserialize)]
+pub struct LinearAttentionConfig {
+ num_heads: usize,
+ feature_dim: usize,
+ feature_map: LinearAttentionFeatureMapConfig,
+}
+
+#[derive(Debug, Clone, serde::Deserialize)]
+pub struct SlidingWindowAttentionConfig {
+ num_heads: usize,
+ window_size: usize,
+}
+
+#[derive(Debug, Clone, serde::Deserialize)]
+pub struct Config {
+ vocab_size: usize,
+ #[serde(rename = "n_embd")]
+ hidden_size: usize,
+ #[serde(rename = "n_inner")]
+ intermediate_size: usize,
+ #[serde(rename = "n_layer")]
+ num_hidden_layers: usize,
+ #[serde(rename = "n_head")]
+ num_attention_heads: usize,
+
+ layer_norm_epsilon: f64,
+ #[serde(default = "default_rope", rename = "rotary_emb_base")]
+ rope_theta: f64,
+
+ alt_mixer_layers: Vec<usize>,
+ alt_mixer_2_layers: Vec<usize>,
+ #[serde(rename = "alt_mixer")]
+ la: LinearAttentionConfig,
+ #[serde(rename = "alt_mixer_2")]
+ swa: SlidingWindowAttentionConfig,
+}
+
+fn default_rope() -> f64 {
+ 10_000.0
+}
+
+#[derive(Debug, Clone)]
+#[allow(clippy::upper_case_acronyms)]
+struct MLP {
+ fc1: Linear,
+ fc2: Linear,
+}
+
+impl MLP {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let fc1 = linear_no_bias(cfg.hidden_size, cfg.hidden_size * 4, vb.pp("fc1"))?;
+ let fc2 = linear_no_bias(cfg.intermediate_size, cfg.hidden_size, vb.pp("fc2"))?;
+ Ok(Self { fc1, fc2 })
+ }
+}
+
+// Swiglu implementation.
+// Not using Activation::Swiglu because this has the gate and y arguments switched compared to the version in candle-nn/src/ops.rs
+fn swiglu(xs: &Tensor) -> Result<Tensor> {
+ let xs = xs.chunk(2, D::Minus1)?;
+ &xs[1].silu()? * &xs[0]
+}
+
+impl Module for MLP {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let xs = xs.apply(&self.fc1)?;
+ let xs = swiglu(&xs)?;
+ let xs = xs.apply(&self.fc2)?;
+ Ok(xs)
+ }
+}
+
+// A gated convolutional block.
+#[derive(Debug, Clone)]
+struct BasedConv {
+ in_proj: Linear,
+ out_proj: Linear,
+ conv: Conv1d,
+ state: Tensor,
+}
+
+impl BasedConv {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let dim = cfg.hidden_size * 2;
+
+ let conv1d_cfg = Conv1dConfig {
+ groups: dim,
+ padding: 2,
+ ..Default::default()
+ };
+
+ let in_proj = linear(cfg.hidden_size, cfg.hidden_size * 4, vb.pp("in_proj"))?;
+ let out_proj = linear(dim, cfg.hidden_size, vb.pp("out_proj"))?;
+ let conv = conv1d_no_bias(dim, dim, 3, conv1d_cfg, vb.pp("conv.conv"))?;
+ let state = Tensor::zeros((1, dim, 3), vb.dtype(), vb.device())?;
+ Ok(Self {
+ in_proj,
+ out_proj,
+ conv,
+ state,
+ })
+ }
+
+ fn step(&mut self, xs: &Tensor) -> Result<Tensor> {
+ self.state = self.state.roll(-1, D::Minus1)?;
+ let (_, _, l) = self.state.dims3()?;
+ self.state = self.state.narrow(D::Minus1, 0, l - 1)?;
+ self.state = Tensor::cat(&[&self.state, &xs.transpose(1, 2)?], 2)?;
+
+ let xs = (&self.state * self.conv.weight().permute((1, 0, 2))?)?
+ .sum_keepdim(0)?
+ .sum(D::Minus1)?;
+
+ let xs = xs.unsqueeze(1)?;
+
+ Ok(xs)
+ }
+
+ fn forward(&mut self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
+ let xs = xs.apply(&self.in_proj)?;
+ let us = xs.chunk(2, D::Minus1)?;
+ let (_b, l, _d) = us[0].dims3()?;
+ let u_conv = if seqlen_offset > 0 {
+ self.step(&us[0])?
+ } else {
+ let k = std::cmp::min(3, l);
+ self.state = self.state.narrow(D::Minus1, 0, 3 - k)?;
+ let xs = us[0].narrow(1, l - k, k)?.transpose(1, 2)?;
+ self.state = Tensor::cat(&[&self.state, &xs], 2)?;
+
+ us[0]
+ .transpose(1, 2)?
+ .apply(&self.conv)?
+ .narrow(D::Minus1, 0, l)?
+ .transpose(1, 2)?
+ };
+
+ let u_conv = u_conv.silu()?;
+ let v = u_conv.broadcast_mul(&us[1])?;
+ let xs = v.apply(&self.out_proj)?;
+
+ Ok(xs)
+ }
+}
+
+// Linear attention approximating softmax using second order Taylor polynomials.
+#[derive(Debug, Clone)]
+struct LinearAttention {
+ proj_q: Linear,
+ proj_k: Linear,
+ proj_v: Linear,
+ out_proj: Linear,
+ feature_dim: usize,
+ num_heads: usize,
+ input_dim: usize,
+ k_state: Tensor,
+ kv_state: Tensor,
+}
+
+impl LinearAttention {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let input_dim = cfg.la.feature_map.input_dim;
+ let out_proj = linear_no_bias(cfg.hidden_size, cfg.hidden_size, vb.pp("out_proj"))?;
+ let proj_k = linear_no_bias(
+ cfg.hidden_size,
+ cfg.la.num_heads * cfg.la.feature_dim,
+ vb.pp("proj_k"),
+ )?;
+ let proj_q = linear_no_bias(
+ cfg.hidden_size,
+ cfg.la.num_heads * cfg.la.feature_dim,
+ vb.pp("proj_q"),
+ )?;
+
+ let proj_v = linear_no_bias(cfg.hidden_size, cfg.hidden_size, vb.pp("proj_v"))?;
+ let expanded_size = cfg.la.feature_dim.pow(2) + cfg.la.feature_dim + 1;
+ let k_state = Tensor::zeros(
+ (1, cfg.la.num_heads, 1, 1, expanded_size),
+ vb.dtype(),
+ vb.device(),
+ )?;
+ let kv_state = Tensor::zeros(
+ (1, cfg.la.num_heads, cfg.la.feature_dim, expanded_size),
+ vb.dtype(),
+ vb.device(),
+ )?;
+
+ Ok(Self {
+ proj_q,
+ proj_k,
+ proj_v,
+ out_proj,
+ feature_dim: cfg.la.feature_dim,
+ num_heads: cfg.la.num_heads,
+ input_dim,
+ k_state,
+ kv_state,
+ })
+ }
+
+ fn taylor_expansion(&self) -> Result<Func<'static>> {
+ let r2 = std::f64::consts::SQRT_2;
+ let rd = (self.input_dim as f64).sqrt();
+ let rrd = rd.sqrt();
+
+ Ok(Func::new(move |xs| {
+ let dims = xs.dims();
+ let mut d = dims.to_vec();
+ if let Some(last) = d.last_mut() {
+ *last = 1;
+ };
+
+ let x = xs
+ .unsqueeze(D::Minus1)?
+ .broadcast_mul(&xs.unsqueeze(D::Minus2)?)?;
+ let x = (x.flatten_from(D::Minus2)? / r2)?;
+ let o = Tensor::ones(d, xs.dtype(), xs.device())?;
+ let x = Tensor::cat(&[o, (xs / rrd)?, (&x / rd)?], D::Minus1)?;
+
+ Ok(x)
+ }))
+ }
+
+ fn forward(&mut self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
+ let eps = 1e-12;
+
+ let feature_map = self.taylor_expansion()?;
+
+ let (b, l, d) = xs.dims3()?;
+ let q = xs.apply(&self.proj_q)?;
+ let k = xs.apply(&self.proj_k)?;
+ let v = xs.apply(&self.proj_v)?;
+
+ let q = q
+ .reshape((b, l, self.num_heads, self.feature_dim))?
+ .transpose(1, 2)?
+ .contiguous()?;
+ let k = k
+ .reshape((b, l, self.num_heads, self.feature_dim))?
+ .transpose(1, 2)?
+ .contiguous()?;
+ let v = v
+ .reshape((b, l, self.num_heads, d / self.num_heads))?
+ .transpose(1, 2)?
+ .contiguous()?;
+
+ let q = feature_map.forward(&q)?;
+ let k = feature_map.forward(&k)?;
+
+ let y = if seqlen_offset > 0 {
+ let (_b, _h, l, _d) = k.dims4()?;
+ let q = q.unsqueeze(D::Minus2)?;
+ let k = k.unsqueeze(D::Minus2)?;
+ let v = v.unsqueeze(D::Minus1)?;
+ let kn = k.narrow(D::Minus1, l - 1, 1)?;
+ let vn = v.narrow(D::Minus1, l - 1, 1)?;
+
+ self.k_state = self.k_state.broadcast_add(&kn)?;
+ self.kv_state = self.kv_state.broadcast_add(&kn.broadcast_mul(&vn)?)?;
+
+ let num = q.broadcast_mul(&self.kv_state)?.sum(D::Minus1)?;
+ let den = (q.broadcast_mul(&self.k_state)?.sum(D::Minus1)? + eps)?;
+ num.broadcast_div(&den)?
+ } else {
+ self.k_state = k.sum(2)?.unsqueeze(2)?.unsqueeze(3)?;
+ self.kv_state = k
+ .transpose(2, 3)?
+ .matmul(&v)?
+ .transpose(2, 3)?
+ .unsqueeze(2)?;
+ let aqk = q.matmul(&k.transpose(D::Minus1, D::Minus2)?)?;
+ let tril = Tensor::tril2(l, aqk.dtype(), aqk.device())?;
+ let aqk = aqk.broadcast_mul(&tril)?.matmul(&v)?;
+
+ let z = (1f64 / (q.mul(&k.cumsum(2)?)?.sum(D::Minus1)? + eps)?)?;
+ aqk.broadcast_mul(&z.unsqueeze(D::Minus1)?)?
+ };
+
+ let (b, h, l, d) = y.dims4()?;
+ let y = y.permute((0, 2, 1, 3))?.reshape((b, l, h * d))?;
+ let y = self.out_proj.forward(&y)?;
+
+ Ok(y)
+ }
+}
+
+// Rotary embeddings used in local attention.
+#[derive(Debug, Clone)]
+struct RotaryEmbedding {
+ sin: Tensor,
+ cos: Tensor,
+}
+
+impl RotaryEmbedding {
+ fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
+ let dim = cfg.hidden_size / cfg.num_attention_heads;
+ let max_seq_len = 2048; // Hardcoded, missing from config.
+ let inv_freq: Vec<_> = (0..dim)
+ .step_by(2)
+ .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
+ .collect();
+ let inv_freq_len = inv_freq.len();
+ let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
+ let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
+ .to_dtype(dtype)?
+ .reshape((max_seq_len, 1))?;
+ let freqs = t.matmul(&inv_freq)?;
+ Ok(Self {
+ sin: freqs.sin()?,
+ cos: freqs.cos()?,
+ })
+ }
+
+ fn apply_rotary_emb_qkv(
+ &self,
+ q: &Tensor,
+ k: &Tensor,
+ seqlen_offset: usize,
+ ) -> Result<(Tensor, Tensor)> {
+ let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
+ let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
+ let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
+ let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
+ let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
+ Ok((q_embed, k_embed))
+ }
+}
+
+// Local attention using a small sliding window.
+#[derive(Debug, Clone)]
+struct SlidingWindowAttention {
+ wqkv: Linear,
+ out_proj: Linear,
+ num_heads: usize,
+ head_dim: usize,
+ hidden_size: usize,
+ rotary_emb: Arc<RotaryEmbedding>,
+ kv_cache: Option<(Tensor, Tensor)>,
+}
+
+impl SlidingWindowAttention {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let hidden_size = cfg.hidden_size;
+ let num_heads = cfg.swa.num_heads;
+ let head_dim = hidden_size / num_heads;
+ let out_proj = linear_no_bias(hidden_size, hidden_size, vb.pp("out_proj"))?;
+ let wqkv = linear_no_bias(hidden_size, hidden_size * 3, vb.pp("Wqkv"))?;
+ let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?);
+ Ok(Self {
+ wqkv,
+ out_proj,
+ hidden_size,
+ num_heads,
+ head_dim,
+ rotary_emb,
+ kv_cache: None,
+ })
+ }
+
+ fn forward(
+ &mut self,
+ xs: &Tensor,
+ attention_mask: Option<&Tensor>,
+ seqlen_offset: usize,
+ ) -> Result<Tensor> {
+ let (b_sz, q_len, _) = xs.dims3()?;
+
+ let qkv = xs.apply(&self.wqkv)?;
+ let qkv = qkv.reshape((b_sz, q_len, 3, (), self.head_dim))?;
+
+ let q = qkv.i((.., .., 0))?;
+ let k = qkv.i((.., .., 1))?;
+ let v = qkv.i((.., .., 2))?;
+
+ let q = q
+ .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
+ .transpose(1, 2)?;
+ let k = k
+ .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
+ .transpose(1, 2)?;
+ let v = v
+ .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
+ .transpose(1, 2)?;
+
+ let (q, k) = self
+ .rotary_emb
+ .apply_rotary_emb_qkv(&q, &k, seqlen_offset)?;
+
+ let (k, v) = match &self.kv_cache {
+ None => (k, v),
+ Some((prev_k, prev_v)) => {
+ let k = Tensor::cat(&[prev_k, &k], 2)?;
+ let v = Tensor::cat(&[prev_v, &v], 2)?;
+ (k, v)
+ }
+ };
+ self.kv_cache = Some((k.clone(), v.clone()));
+
+ let scale = 1f64 / f64::sqrt(self.head_dim as f64);
+ let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
+
+ let attn_weights = match attention_mask {
+ None => attn_weights,
+ Some(mask) => attn_weights.broadcast_add(mask)?,
+ };
+ let attn_weights = softmax_last_dim(&attn_weights)?;
+ let attn_output = attn_weights.matmul(&v)?;
+ let out = attn_output
+ .transpose(1, 2)?
+ .reshape((b_sz, q_len, self.hidden_size))?
+ .apply(&self.out_proj)?;
+
+ Ok(out)
+ }
+}
+
+// The model layers use three types of mixers.
+#[derive(Debug, Clone)]
+enum SequenceMixer {
+ Based(BasedConv),
+ Linear(LinearAttention),
+ Sliding(SlidingWindowAttention),
+}
+
+impl SequenceMixer {
+ fn forward(
+ &mut self,
+ xs: &Tensor,
+ attention_mask: Option<&Tensor>,
+ pos: usize,
+ ) -> Result<Tensor> {
+ match self {
+ Self::Based(b) => b.forward(xs, pos),
+ Self::Linear(b) => b.forward(xs, pos),
+ Self::Sliding(b) => b.forward(xs, attention_mask, pos),
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+struct DecoderLayer {
+ mlp: MLP,
+ norm1: RmsNorm,
+ norm2: RmsNorm,
+ mixer: SequenceMixer,
+}
+
+impl DecoderLayer {
+ fn new(layer_idx: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let mlp = MLP::new(cfg, vb.pp("mlp"))?;
+ let norm1 = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("norm1"))?;
+ let norm2 = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("norm2"))?;
+
+ let l_attn = cfg.alt_mixer_layers.contains(&layer_idx);
+ let sw_attn = cfg.alt_mixer_2_layers.contains(&layer_idx);
+
+ let mixer = if l_attn {
+ SequenceMixer::Linear(LinearAttention::new(cfg, vb.pp("mixer"))?)
+ } else if sw_attn {
+ SequenceMixer::Sliding(SlidingWindowAttention::new(cfg, vb.pp("mixer"))?)
+ } else {
+ SequenceMixer::Based(BasedConv::new(cfg, vb.pp("mixer"))?)
+ };
+
+ Ok(Self {
+ mlp,
+ norm1,
+ norm2,
+ mixer,
+ })
+ }
+
+ fn forward(
+ &mut self,
+ xs: &Tensor,
+ attention_mask: Option<&Tensor>,
+ seqlen_offset: usize,
+ ) -> Result<Tensor> {
+ let residual = xs;
+ let xs = self.norm1.forward(xs)?;
+ let xs = self.mixer.forward(&xs, attention_mask, seqlen_offset)?;
+ let xs = (xs + residual)?;
+ let residual = &xs;
+ let xs = xs.apply(&self.norm2)?.apply(&self.mlp)?;
+ residual + xs
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct Model {
+ embed_tokens: super::with_tracing::Embedding,
+ layers: Vec<DecoderLayer>,
+ norm: RmsNorm,
+ lm_head: Linear,
+ sliding_window: usize,
+ device: Device,
+ dtype: DType,
+}
+
+impl Model {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let vocab_size = cfg.vocab_size + (8 - cfg.vocab_size % 8) % 8;
+ let lm_head = linear_no_bias(cfg.hidden_size, vocab_size, vb.pp("lm_head"))?;
+ let embed_tokens = super::with_tracing::Embedding::from_weights(lm_head.weight().clone())?;
+ let vb_m = vb.pp("transformer");
+ let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
+ let vb_l = vb_m.pp("layers");
+ for layer_idx in 0..cfg.num_hidden_layers {
+ let layer = DecoderLayer::new(layer_idx, cfg, vb_l.pp(layer_idx))?;
+ layers.push(layer)
+ }
+ let norm = rms_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb_m.pp("ln_f"))?;
+ Ok(Self {
+ embed_tokens,
+ layers,
+ norm,
+ lm_head,
+ sliding_window: cfg.swa.window_size,
+ device: vb.device().clone(),
+ dtype: vb.dtype(),
+ })
+ }
+
+ fn prepare_decoder_attention_mask(
+ &self,
+ b_size: usize,
+ tgt_len: usize,
+ seqlen_offset: usize,
+ ) -> Result<Tensor> {
+ let sliding_window = self.sliding_window / 2;
+ let mask: Vec<_> = (0..tgt_len)
+ .flat_map(|i| {
+ (0..tgt_len).map(move |j| {
+ if i < j || j + sliding_window < i {
+ f32::NEG_INFINITY
+ } else {
+ 0.
+ }
+ })
+ })
+ .collect();
+ let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
+ let mask = if seqlen_offset > 0 {
+ let mask0 = Tensor::zeros((tgt_len, seqlen_offset), self.dtype, &self.device)?;
+ Tensor::cat(&[&mask0, &mask], D::Minus1)?
+ } else {
+ mask
+ };
+ mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
+ .to_dtype(self.dtype)
+ }
+
+ pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
+ let (b_size, seq_len) = input_ids.dims2()?;
+ let attention_mask = if seq_len <= 1 {
+ None
+ } else {
+ let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
+ Some(mask)
+ };
+ let mut xs = self.embed_tokens.forward(input_ids)?;
+ for layer in self.layers.iter_mut() {
+ xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
+ }
+ xs.narrow(1, seq_len - 1, 1)?
+ .apply(&self.norm)?
+ .apply(&self.lm_head)
+ }
+}
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index c0de550b..7baaaf72 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -1,3 +1,4 @@
+pub mod based;
pub mod beit;
pub mod bert;
pub mod bigcode;