summary refs log tree commit diff
path: root/candle-transformers
diff options
context:
space:
mode:
author Laurent Mazare <laurent.mazare@gmail.com> 2024-04-27 20:17:35 +0200
committer GitHub <noreply@github.com> 2024-04-27 20:17:35 +0200
commit 96a48e5cc42b3c94d9d9687bb29987953df36db8 (patch)
tree 4f1f391e6e6a8c1b865c4ab40e67aaf84dd21499 /candle-transformers
parent 6cf82fd7a34641601264ad1e0256ecadb7222474 (diff)
download candle-96a48e5cc42b3c94d9d9687bb29987953df36db8.tar.gz
download candle-96a48e5cc42b3c94d9d9687bb29987953df36db8.tar.bz2
download candle-96a48e5cc42b3c94d9d9687bb29987953df36db8.zip
Add argsort. (#2132)
* Add the argsort cuda kernels. * CPU version of arg-sort. * Hook the cuda kernel + rework the cpu bits. * Add some dedicated test. * Working cuda kernel. * Metal kernel. * Metal adjustments. * Bugfix. * Use the fast rope in qwen. * Rework the expert selection in qwen.
Diffstat (limited to 'candle-transformers')
-rw-r--r--candle-transformers/src/models/qwen2.rs14
-rw-r--r--candle-transformers/src/models/qwen2_moe.rs50
2 files changed, 21 insertions, 43 deletions
diff --git a/candle-transformers/src/models/qwen2.rs b/candle-transformers/src/models/qwen2.rs
index 06f9069a..c9b5ae01 100644
--- a/candle-transformers/src/models/qwen2.rs
+++ b/candle-transformers/src/models/qwen2.rs
@@ -27,13 +27,6 @@ struct RotaryEmbedding {
cos: Tensor,
}
-fn rotate_half(xs: &Tensor) -> Result<Tensor> {
- let last_dim = xs.dim(D::Minus1)?;
- let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
- let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
- Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
-}
-
impl RotaryEmbedding {
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let dim = cfg.hidden_size / cfg.num_attention_heads;
@@ -48,7 +41,6 @@ impl RotaryEmbedding {
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
- let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
@@ -64,10 +56,8 @@ impl RotaryEmbedding {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
- let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
- let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
- let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
- let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
+ let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
+ let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
}
}
diff --git a/candle-transformers/src/models/qwen2_moe.rs b/candle-transformers/src/models/qwen2_moe.rs
index 5650e350..8d1d2f70 100644
--- a/candle-transformers/src/models/qwen2_moe.rs
+++ b/candle-transformers/src/models/qwen2_moe.rs
@@ -33,13 +33,6 @@ struct RotaryEmbedding {
cos: Tensor,
}
-fn rotate_half(xs: &Tensor) -> Result<Tensor> {
- let last_dim = xs.dim(D::Minus1)?;
- let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
- let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
- Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
-}
-
impl RotaryEmbedding {
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let dim = cfg.hidden_size / cfg.num_attention_heads;
@@ -54,7 +47,6 @@ impl RotaryEmbedding {
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
- let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
@@ -70,10 +62,8 @@ impl RotaryEmbedding {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
- let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
- let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
- let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
- let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
+ let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
+ let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
}
}
@@ -259,30 +249,28 @@ impl Module for SparseMoeBlock {
// In order to extract topk, we extract the data from the tensor and manipulate it
// directly. Maybe we will want to use some custom ops instead at some point.
- let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
+ let experts_per_tok = routing_weights
+ .arg_sort_last_dim(false)?
+ .narrow(D::Minus1, 0, self.num_experts_per_tok)?
+ .contiguous()?;
+ let routing_weights = routing_weights.gather(&experts_per_tok, D::Minus1)?;
// routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
// top_x contains the row indexes to evaluate for each expert.
+ let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
+ let experts_per_tok = experts_per_tok.to_vec2::<u32>()?;
let mut top_x = vec![vec![]; self.experts.len()];
let mut selected_experts = vec![vec![]; self.experts.len()];
- for (row_idx, rw) in routing_weights.iter().enumerate() {
- let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
- dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
- let mut sum_routing_weights = 0f32;
- for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
- let expert_idx = expert_idx as usize;
- let routing_weight = rw[expert_idx];
- sum_routing_weights += routing_weight;
- top_x[expert_idx].push(row_idx as u32);
- }
- for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
- let expert_idx = expert_idx as usize;
- let routing_weight = if self.norm_topk_prob {
- rw[expert_idx] / sum_routing_weights
- } else {
- rw[expert_idx]
- };
- selected_experts[expert_idx].push(routing_weight)
+ for (row_idx, (rw, expert_idxs)) in routing_weights
+ .iter()
+ .zip(experts_per_tok.iter())
+ .enumerate()
+ {
+ let sum_rw = rw.iter().sum::<f32>();
+ for (&rw, &expert_idx) in rw.iter().zip(expert_idxs.iter()) {
+ top_x[expert_idx as usize].push(row_idx as u32);
+ let rw = if self.norm_topk_prob { rw / sum_rw } else { rw };
+ selected_experts[expert_idx as usize].push(rw)
}
}