diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-04-27 20:17:35 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-27 20:17:35 +0200 |
commit | 96a48e5cc42b3c94d9d9687bb29987953df36db8 (patch) | |
tree | 4f1f391e6e6a8c1b865c4ab40e67aaf84dd21499 /candle-transformers | |
parent | 6cf82fd7a34641601264ad1e0256ecadb7222474 (diff) | |
download | candle-96a48e5cc42b3c94d9d9687bb29987953df36db8.tar.gz candle-96a48e5cc42b3c94d9d9687bb29987953df36db8.tar.bz2 candle-96a48e5cc42b3c94d9d9687bb29987953df36db8.zip |
Add argsort. (#2132)
* Add the argsort cuda kernels.
* CPU version of arg-sort.
* Hook the cuda kernel + rework the cpu bits.
* Add some dedicated test.
* Working cuda kernel.
* Metal kernel.
* Metal adjustments.
* Bugfix.
* Use the fast rope in qwen.
* Rework the expert selection in qwen.
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/models/qwen2.rs | 14 | ||||
-rw-r--r-- | candle-transformers/src/models/qwen2_moe.rs | 50 |
2 files changed, 21 insertions, 43 deletions
diff --git a/candle-transformers/src/models/qwen2.rs b/candle-transformers/src/models/qwen2.rs index 06f9069a..c9b5ae01 100644 --- a/candle-transformers/src/models/qwen2.rs +++ b/candle-transformers/src/models/qwen2.rs @@ -27,13 +27,6 @@ struct RotaryEmbedding { cos: Tensor, } -fn rotate_half(xs: &Tensor) -> Result<Tensor> { - let last_dim = xs.dim(D::Minus1)?; - let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?; - let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?; - Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1) -} - impl RotaryEmbedding { fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> { let dim = cfg.hidden_size / cfg.num_attention_heads; @@ -48,7 +41,6 @@ impl RotaryEmbedding { .to_dtype(dtype)? .reshape((max_seq_len, 1))?; let freqs = t.matmul(&inv_freq)?; - let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?; Ok(Self { sin: freqs.sin()?, cos: freqs.cos()?, @@ -64,10 +56,8 @@ impl RotaryEmbedding { let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; let cos = self.cos.narrow(0, seqlen_offset, seq_len)?; let sin = self.sin.narrow(0, seqlen_offset, seq_len)?; - let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) - let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) - let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?; - let k_embed = (k.broadcast_mul(&cos)? 
+ rotate_half(k)?.broadcast_mul(&sin))?; + let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?; + let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?; Ok((q_embed, k_embed)) } } diff --git a/candle-transformers/src/models/qwen2_moe.rs b/candle-transformers/src/models/qwen2_moe.rs index 5650e350..8d1d2f70 100644 --- a/candle-transformers/src/models/qwen2_moe.rs +++ b/candle-transformers/src/models/qwen2_moe.rs @@ -33,13 +33,6 @@ struct RotaryEmbedding { cos: Tensor, } -fn rotate_half(xs: &Tensor) -> Result<Tensor> { - let last_dim = xs.dim(D::Minus1)?; - let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?; - let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?; - Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1) -} - impl RotaryEmbedding { fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> { let dim = cfg.hidden_size / cfg.num_attention_heads; @@ -54,7 +47,6 @@ impl RotaryEmbedding { .to_dtype(dtype)? .reshape((max_seq_len, 1))?; let freqs = t.matmul(&inv_freq)?; - let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?; Ok(Self { sin: freqs.sin()?, cos: freqs.cos()?, @@ -70,10 +62,8 @@ impl RotaryEmbedding { let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; let cos = self.cos.narrow(0, seqlen_offset, seq_len)?; let sin = self.sin.narrow(0, seqlen_offset, seq_len)?; - let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) - let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim) - let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?; - let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?; + let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?; + let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?; Ok((q_embed, k_embed)) } } @@ -259,30 +249,28 @@ impl Module for SparseMoeBlock { // In order to extract topk, we extract the data from the tensor and manipulate it // directly. 
Maybe we will want to use some custom ops instead at some point. - let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?; + let experts_per_tok = routing_weights + .arg_sort_last_dim(false)? + .narrow(D::Minus1, 0, self.num_experts_per_tok)? + .contiguous()?; + let routing_weights = routing_weights.gather(&experts_per_tok, D::Minus1)?; // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) // top_x contains the row indexes to evaluate for each expert. + let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?; + let experts_per_tok = experts_per_tok.to_vec2::<u32>()?; let mut top_x = vec![vec![]; self.experts.len()]; let mut selected_experts = vec![vec![]; self.experts.len()]; - for (row_idx, rw) in routing_weights.iter().enumerate() { - let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>(); - dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize])); - let mut sum_routing_weights = 0f32; - for &expert_idx in dst.iter().take(self.num_experts_per_tok) { - let expert_idx = expert_idx as usize; - let routing_weight = rw[expert_idx]; - sum_routing_weights += routing_weight; - top_x[expert_idx].push(row_idx as u32); - } - for &expert_idx in dst.iter().take(self.num_experts_per_tok) { - let expert_idx = expert_idx as usize; - let routing_weight = if self.norm_topk_prob { - rw[expert_idx] / sum_routing_weights - } else { - rw[expert_idx] - }; - selected_experts[expert_idx].push(routing_weight) + for (row_idx, (rw, expert_idxs)) in routing_weights + .iter() + .zip(experts_per_tok.iter()) + .enumerate() + { + let sum_rw = rw.iter().sum::<f32>(); + for (&rw, &expert_idx) in rw.iter().zip(expert_idxs.iter()) { + top_x[expert_idx as usize].push(row_idx as u32); + let rw = if self.norm_topk_prob { rw / sum_rw } else { rw }; + selected_experts[expert_idx as usize].push(rw) } } |