summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-transformers/src/models/gemma.rs16
-rw-r--r--candle-transformers/src/models/llama.rs12
-rw-r--r--candle-transformers/src/models/mistral.rs16
-rw-r--r--candle-transformers/src/models/mixtral.rs16
-rw-r--r--candle-transformers/src/models/mpt.rs18
-rw-r--r--candle-transformers/src/models/phi.rs10
-rw-r--r--candle-transformers/src/models/quantized_llama.rs20
-rw-r--r--candle-transformers/src/models/quantized_mistral.rs16
-rw-r--r--candle-transformers/src/models/quantized_mpt.rs4
-rw-r--r--candle-transformers/src/models/quantized_stable_lm.rs17
-rw-r--r--candle-transformers/src/models/qwen2.rs17
-rw-r--r--candle-transformers/src/models/qwen2_moe.rs17
-rw-r--r--candle-transformers/src/models/stable_lm.rs17
-rw-r--r--candle-transformers/src/models/starcoder2.rs16
-rw-r--r--candle-transformers/src/models/yi.rs16
-rw-r--r--candle-transformers/src/utils.rs14
16 files changed, 47 insertions, 195 deletions
diff --git a/candle-transformers/src/models/gemma.rs b/candle-transformers/src/models/gemma.rs
index 15e4dccb..58b5f1e1 100644
--- a/candle-transformers/src/models/gemma.rs
+++ b/candle-transformers/src/models/gemma.rs
@@ -191,18 +191,6 @@ impl Attention {
})
}
- fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
- let n_rep = self.num_kv_groups;
- if n_rep == 1 {
- Ok(xs)
- } else {
- let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
- xs.unsqueeze(2)?
- .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
- .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
- }
- }
-
fn forward(
&mut self,
xs: &Tensor,
@@ -239,8 +227,8 @@ impl Attention {
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
- let key_states = self.repeat_kv(key_states)?.contiguous()?;
- let value_states = self.repeat_kv(value_states)?.contiguous()?;
+ let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
+ let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
let attn_output = {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index 73671cdc..f3d482eb 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -256,17 +256,7 @@ impl CausalSelfAttention {
}
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
- let n_rep = self.num_attention_heads / self.num_key_value_heads;
- if n_rep == 1 {
- Ok(x)
- } else {
- let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
- let x = x
- .unsqueeze(2)?
- .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
- .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
- Ok(x)
- }
+ crate::utils::repeat_kv(x, self.num_attention_heads / self.num_key_value_heads)
}
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
diff --git a/candle-transformers/src/models/mistral.rs b/candle-transformers/src/models/mistral.rs
index d899c712..1cb55f9e 100644
--- a/candle-transformers/src/models/mistral.rs
+++ b/candle-transformers/src/models/mistral.rs
@@ -216,18 +216,6 @@ impl Attention {
})
}
- fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
- let n_rep = self.num_kv_groups;
- if n_rep == 1 {
- Ok(xs)
- } else {
- let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
- xs.unsqueeze(2)?
- .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
- .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
- }
- }
-
fn forward(
&mut self,
xs: &Tensor,
@@ -266,8 +254,8 @@ impl Attention {
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
- let key_states = self.repeat_kv(key_states)?;
- let value_states = self.repeat_kv(value_states)?;
+ let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
+ let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
let attn_output = if self.use_flash_attn {
// flash-attn expects (b_sz, seq_len, nheads, head_dim)
diff --git a/candle-transformers/src/models/mixtral.rs b/candle-transformers/src/models/mixtral.rs
index f69c68e3..a578d6fe 100644
--- a/candle-transformers/src/models/mixtral.rs
+++ b/candle-transformers/src/models/mixtral.rs
@@ -158,18 +158,6 @@ impl Attention {
})
}
- fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
- let n_rep = self.num_kv_groups;
- if n_rep == 1 {
- Ok(xs)
- } else {
- let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
- xs.unsqueeze(2)?
- .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
- .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
- }
- }
-
fn forward(
&mut self,
xs: &Tensor,
@@ -206,8 +194,8 @@ impl Attention {
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
- let key_states = self.repeat_kv(key_states)?;
- let value_states = self.repeat_kv(value_states)?;
+ let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
+ let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
let attn_output = if self.use_flash_attn {
// flash-attn expects (b_sz, seq_len, nheads, head_dim)
diff --git a/candle-transformers/src/models/mpt.rs b/candle-transformers/src/models/mpt.rs
index 093e177c..d46524fc 100644
--- a/candle-transformers/src/models/mpt.rs
+++ b/candle-transformers/src/models/mpt.rs
@@ -104,8 +104,8 @@ impl GroupedQueryAttention {
};
self.kv_cache = Some((key.clone(), value.clone()));
let query = query.contiguous()?;
- let key = repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
- let value = repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
+ let key = crate::utils::repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
+ let value = crate::utils::repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
let attn_weights = (query.matmul(&key)? * self.softmax_scale)?;
let attn_bias = {
let s_q = query.dim(D::Minus2)?;
@@ -134,20 +134,6 @@ impl GroupedQueryAttention {
}
}
-// This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
-// The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
-// (batch, num_attention_heads, seqlen, head_dim)
-pub(crate) fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
- if n_rep == 1 {
- Ok(xs)
- } else {
- let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
- xs.unsqueeze(2)?
- .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
- .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
- }
-}
-
#[derive(Debug, Clone)]
struct Ffn {
up_proj: Linear,
diff --git a/candle-transformers/src/models/phi.rs b/candle-transformers/src/models/phi.rs
index 8bf357e7..2c7fccef 100644
--- a/candle-transformers/src/models/phi.rs
+++ b/candle-transformers/src/models/phi.rs
@@ -174,15 +174,7 @@ impl Attention {
}
fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
- let n_rep = self.num_heads / self.num_kv_heads;
- if n_rep == 1 {
- Ok(xs)
- } else {
- let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
- xs.unsqueeze(2)?
- .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
- .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
- }
+ crate::utils::repeat_kv(xs, self.num_heads / self.num_kv_heads)
}
fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs
index e1519b2d..6b326fbe 100644
--- a/candle-transformers/src/models/quantized_llama.rs
+++ b/candle-transformers/src/models/quantized_llama.rs
@@ -205,9 +205,9 @@ impl LayerWeights {
};
self.kv_cache = Some((k.clone(), v.clone()));
- // Support for MQA, useful for 70B models.
- let k = self.repeat_kv(k)?;
- let v = self.repeat_kv(v)?;
+ // Support for MQA, useful for 70B models and mistral.
+ let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
+ let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let att = match mask {
@@ -224,20 +224,6 @@ impl LayerWeights {
let y = self.attention_wo.forward(&y)?;
Ok(y)
}
-
- fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
- let n_rep = self.n_head / self.n_kv_head;
- if n_rep == 1 {
- Ok(x)
- } else {
- let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
- let x = x
- .unsqueeze(2)?
- .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
- .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
- Ok(x)
- }
- }
}
#[derive(Debug, Clone)]
diff --git a/candle-transformers/src/models/quantized_mistral.rs b/candle-transformers/src/models/quantized_mistral.rs
index e37785de..0583810a 100644
--- a/candle-transformers/src/models/quantized_mistral.rs
+++ b/candle-transformers/src/models/quantized_mistral.rs
@@ -122,18 +122,6 @@ impl Attention {
})
}
- fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
- let n_rep = self.num_kv_groups;
- if n_rep == 1 {
- Ok(xs)
- } else {
- let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
- xs.unsqueeze(2)?
- .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
- .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
- }
- }
-
fn forward(
&mut self,
xs: &Tensor,
@@ -172,8 +160,8 @@ impl Attention {
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
- let key_states = self.repeat_kv(key_states)?;
- let value_states = self.repeat_kv(value_states)?;
+ let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
+ let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
let attn_output = {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
diff --git a/candle-transformers/src/models/quantized_mpt.rs b/candle-transformers/src/models/quantized_mpt.rs
index 70a9e125..056fcac2 100644
--- a/candle-transformers/src/models/quantized_mpt.rs
+++ b/candle-transformers/src/models/quantized_mpt.rs
@@ -71,8 +71,8 @@ impl GroupedQueryAttention {
};
self.kv_cache = Some((key.clone(), value.clone()));
let query = query.contiguous()?;
- let key = super::mpt::repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
- let value = super::mpt::repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
+ let key = crate::utils::repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
+ let value = crate::utils::repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
let attn_weights = (query.matmul(&key)? * self.softmax_scale)?;
let attn_bias = {
let s_q = query.dim(D::Minus2)?;
diff --git a/candle-transformers/src/models/quantized_stable_lm.rs b/candle-transformers/src/models/quantized_stable_lm.rs
index 7d4385a7..da447522 100644
--- a/candle-transformers/src/models/quantized_stable_lm.rs
+++ b/candle-transformers/src/models/quantized_stable_lm.rs
@@ -94,18 +94,6 @@ impl Attention {
})
}
- fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
- let n_rep = self.num_kv_groups;
- if n_rep == 1 {
- Ok(xs)
- } else {
- let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
- xs.unsqueeze(2)?
- .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
- .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
- }
- }
-
fn forward(
&mut self,
xs: &Tensor,
@@ -152,8 +140,9 @@ impl Attention {
self.kv_cache = Some((key_states.clone(), value_states.clone()));
}
- let key_states = self.repeat_kv(key_states)?.contiguous()?;
- let value_states = self.repeat_kv(value_states)?.contiguous()?;
+ let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
+ let value_states =
+ crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
let attn_output = {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
diff --git a/candle-transformers/src/models/qwen2.rs b/candle-transformers/src/models/qwen2.rs
index 9a12eba5..06f9069a 100644
--- a/candle-transformers/src/models/qwen2.rs
+++ b/candle-transformers/src/models/qwen2.rs
@@ -146,18 +146,6 @@ impl Attention {
})
}
- fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
- let n_rep = self.num_kv_groups;
- if n_rep == 1 {
- Ok(xs)
- } else {
- let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
- xs.unsqueeze(2)?
- .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
- .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
- }
- }
-
fn forward(
&mut self,
xs: &Tensor,
@@ -194,8 +182,9 @@ impl Attention {
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
- let key_states = self.repeat_kv(key_states)?.contiguous()?;
- let value_states = self.repeat_kv(value_states)?.contiguous()?;
+ let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
+ let value_states =
+ crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
let attn_output = {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
diff --git a/candle-transformers/src/models/qwen2_moe.rs b/candle-transformers/src/models/qwen2_moe.rs
index d6566e90..5650e350 100644
--- a/candle-transformers/src/models/qwen2_moe.rs
+++ b/candle-transformers/src/models/qwen2_moe.rs
@@ -151,18 +151,6 @@ impl Attention {
})
}
- fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
- let n_rep = self.num_kv_groups;
- if n_rep == 1 {
- Ok(xs)
- } else {
- let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
- xs.unsqueeze(2)?
- .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
- .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
- }
- }
-
fn forward(
&mut self,
xs: &Tensor,
@@ -199,8 +187,9 @@ impl Attention {
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
- let key_states = self.repeat_kv(key_states)?.contiguous()?;
- let value_states = self.repeat_kv(value_states)?.contiguous()?;
+ let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
+ let value_states =
+ crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
let attn_output = {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
diff --git a/candle-transformers/src/models/stable_lm.rs b/candle-transformers/src/models/stable_lm.rs
index a1d58936..2b46e8a1 100644
--- a/candle-transformers/src/models/stable_lm.rs
+++ b/candle-transformers/src/models/stable_lm.rs
@@ -217,18 +217,6 @@ impl Attention {
})
}
- fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
- let n_rep = self.num_kv_groups;
- if n_rep == 1 {
- Ok(xs)
- } else {
- let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
- xs.unsqueeze(2)?
- .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
- .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
- }
- }
-
fn forward(
&mut self,
xs: &Tensor,
@@ -275,8 +263,9 @@ impl Attention {
self.kv_cache = Some((key_states.clone(), value_states.clone()));
}
- let key_states = self.repeat_kv(key_states)?.contiguous()?;
- let value_states = self.repeat_kv(value_states)?.contiguous()?;
+ let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
+ let value_states =
+ crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
let attn_output = if self.use_flash_attn {
// flash-attn expects (b_sz, seq_len, nheads, head_dim)
diff --git a/candle-transformers/src/models/starcoder2.rs b/candle-transformers/src/models/starcoder2.rs
index da3f6799..d108d062 100644
--- a/candle-transformers/src/models/starcoder2.rs
+++ b/candle-transformers/src/models/starcoder2.rs
@@ -139,18 +139,6 @@ impl Attention {
})
}
- fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
- let n_rep = self.num_kv_groups;
- if n_rep == 1 {
- Ok(xs)
- } else {
- let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
- xs.unsqueeze(2)?
- .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
- .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
- }
- }
-
fn forward(
&mut self,
xs: &Tensor,
@@ -187,8 +175,8 @@ impl Attention {
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
- let key_states = self.repeat_kv(key_states)?;
- let value_states = self.repeat_kv(value_states)?;
+ let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
+ let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
diff --git a/candle-transformers/src/models/yi.rs b/candle-transformers/src/models/yi.rs
index 99d9de1b..df78ddce 100644
--- a/candle-transformers/src/models/yi.rs
+++ b/candle-transformers/src/models/yi.rs
@@ -175,18 +175,6 @@ impl Attention {
})
}
- fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
- let n_rep = self.num_kv_groups;
- if n_rep == 1 {
- Ok(xs)
- } else {
- let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
- xs.unsqueeze(2)?
- .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
- .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
- }
- }
-
fn forward(
&mut self,
xs: &Tensor,
@@ -223,8 +211,8 @@ impl Attention {
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
- let key_states = self.repeat_kv(key_states)?;
- let value_states = self.repeat_kv(value_states)?;
+ let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
+ let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
let attn_output = {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
diff --git a/candle-transformers/src/utils.rs b/candle-transformers/src/utils.rs
index d29995ed..17e83694 100644
--- a/candle-transformers/src/utils.rs
+++ b/candle-transformers/src/utils.rs
@@ -20,3 +20,17 @@ pub fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> R
let logits_len = logits.len();
Tensor::from_vec(logits, logits_len, device)
}
+
+/// Repeats a key or value tensor for grouped query attention
+/// The input tensor should have a shape `(batch, num_kv_heads, seq_len, head_dim)`,
+pub fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
+ if n_rep == 1 {
+ Ok(xs)
+ } else {
+ let (b_sz, n_kv_head, seq_len, head_dim) = xs.dims4()?;
+ // Using cat is faster than a broadcast as it avoids going through a potentially
+ // strided copy.
+ // https://github.com/huggingface/candle/pull/2043
+ Tensor::cat(&vec![&xs; n_rep], 2)?.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
+ }
+}