author     Laurent Mazare <laurent.mazare@gmail.com>    2024-04-27 20:17:35 +0200
committer  GitHub <noreply@github.com>                  2024-04-27 20:17:35 +0200
commit     96a48e5cc42b3c94d9d9687bb29987953df36db8 (patch)
tree       4f1f391e6e6a8c1b865c4ab40e67aaf84dd21499
parent     6cf82fd7a34641601264ad1e0256ecadb7222474 (diff)
Add argsort. (#2132)
* Add the argsort cuda kernels.
* CPU version of arg-sort.
* Hook the cuda kernel + rework the cpu bits.
* Add some dedicated test.
* Working cuda kernel.
* Metal kernel.
* Metal adjustments.
* Bugfix.
* Use the fast rope in qwen.
* Rework the expert selection in qwen.
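For reference, the new Tensor::arg_sort_last_dim API added by this commit can be exercised as in the asort test added to candle-core/tests/tensor_tests.rs. A minimal standalone sketch, assuming the crate is used as candle_core:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // Argsort is applied independently to each row (the last dimension).
    let data = &[[3f32, 1., 4., 1.1, 5.], [2.1, 1., 7., 8., 2.]];
    let tensor = Tensor::new(data, &Device::Cpu)?;
    // Ascending: the index of the smallest element comes first.
    let asc = tensor.arg_sort_last_dim(true)?;
    assert_eq!(asc.to_vec2::<u32>()?, [[1, 3, 0, 2, 4], [1, 4, 0, 2, 3]]);
    // Descending: the index of the largest element comes first.
    let desc = tensor.arg_sort_last_dim(false)?;
    assert_eq!(desc.to_vec2::<u32>()?, [[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]]);
    Ok(())
}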
-rw-r--r--   candle-core/src/lib.rs                      |   1
-rw-r--r--   candle-core/src/metal_backend/mod.rs        |   2
-rw-r--r--   candle-core/src/sort.rs                     | 222
-rw-r--r--   candle-core/tests/tensor_tests.rs           |  17
-rw-r--r--   candle-kernels/src/lib.rs                   |   1
-rw-r--r--   candle-kernels/src/sort.cu                  |  88
-rw-r--r--   candle-metal-kernels/src/lib.rs             |  40
-rw-r--r--   candle-metal-kernels/src/quantized.metal    |   1
-rw-r--r--   candle-metal-kernels/src/sort.metal         |  97
-rw-r--r--   candle-transformers/src/models/qwen2.rs     |  14
-rw-r--r--   candle-transformers/src/models/qwen2_moe.rs |  50

11 files changed, 489 insertions, 44 deletions
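Both the CUDA and Metal kernels below run a bitonic sort over one row per block/threadgroup, so the column count is padded up to the next power of two and one u32 index per padded column is kept in shared memory. A small sketch of that launch arithmetic (standalone; it mirrors the next_power_of_2 helper and LaunchConfig in the diff, the shapes are purely illustrative):

fn launch_dims(nrows: usize, ncols: usize) -> (usize, usize, usize) {
    // One block/threadgroup per row, ncols_pad threads per group.
    let ncols_pad = ncols.next_power_of_two();
    // Shared memory holds one u32 index per padded column.
    let shared_mem_bytes = ncols_pad * std::mem::size_of::<u32>();
    (nrows, ncols_pad, shared_mem_bytes)
}

fn main() {
    // Argsorting the rows of a 4 x 60 tensor: 4 groups of 64 threads, 256 bytes of shared memory each.
    assert_eq!(launch_dims(4, 60), (4, 64, 256));
}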
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index bafad1b6..44788ddc 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -63,6 +63,7 @@ pub mod quantized;
 pub mod safetensors;
 pub mod scalar;
 pub mod shape;
+mod sort;
 mod storage;
 mod strided_index;
 mod tensor;
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs
index 1396899b..c0f6a844 100644
--- a/candle-core/src/metal_backend/mod.rs
+++ b/candle-core/src/metal_backend/mod.rs
@@ -11,7 +11,7 @@ use std::sync::{Arc, Mutex, RwLock, TryLockError};
 mod device;
 pub use device::{DeviceId, MetalDevice};

-fn buffer_o<'a>(buffer: &'a Buffer, l: &Layout, dtype: DType) -> BufferOffset<'a> {
+pub fn buffer_o<'a>(buffer: &'a Buffer, l: &Layout, dtype: DType) -> BufferOffset<'a> {
     BufferOffset {
         buffer,
         offset_in_bytes: l.start_offset() * dtype.size_in_bytes(),
diff --git a/candle-core/src/sort.rs b/candle-core/src/sort.rs
new file mode 100644
index 00000000..bcd098e3
--- /dev/null
+++ b/candle-core/src/sort.rs
@@ -0,0 +1,222 @@
+use crate::{Result, Tensor};
+use rayon::prelude::*;
+
+#[derive(Debug, Clone, Copy)]
+struct ArgSort {
+    asc: bool,
+    last_dim: usize,
+}
+
+impl ArgSort {
+    fn asort<T: crate::WithDType>(&self, vs: &[T], layout: &crate::Layout) -> Vec<u32> {
+        #[allow(clippy::uninit_vec)]
+        // Safety: indexes are set later in the parallelized section.
+        let mut sort_indexes = unsafe {
+            let el_count = layout.shape().elem_count();
+            let mut v = Vec::with_capacity(el_count);
+            v.set_len(el_count);
+            v
+        };
+        if self.asc {
+            sort_indexes
+                .par_chunks_exact_mut(self.last_dim)
+                .zip(vs.par_chunks_exact(self.last_dim))
+                .for_each(|(indexes, vs)| {
+                    indexes
+                        .iter_mut()
+                        .enumerate()
+                        .for_each(|(i, v)| *v = i as u32);
+                    indexes.sort_by(|&i, &j| {
+                        vs[i as usize]
+                            .partial_cmp(&vs[j as usize])
+                            .unwrap_or(std::cmp::Ordering::Greater)
+                    })
+                });
+        } else {
+            sort_indexes
+                .par_chunks_exact_mut(self.last_dim)
+                .zip(vs.par_chunks_exact(self.last_dim))
+                .for_each(|(indexes, vs)| {
+                    indexes
+                        .iter_mut()
+                        .enumerate()
+                        .for_each(|(i, v)| *v = i as u32);
+                    indexes.sort_by(|&j, &i| {
+                        vs[i as usize]
+                            .partial_cmp(&vs[j as usize])
+                            .unwrap_or(std::cmp::Ordering::Greater)
+                    })
+                });
+        }
+        sort_indexes
+    }
+}
+
+impl crate::CustomOp1 for ArgSort {
+    fn name(&self) -> &'static str {
+        "argsort"
+    }
+
+    fn cpu_fwd(
+        &self,
+        storage: &crate::CpuStorage,
+        layout: &crate::Layout,
+    ) -> Result<(crate::CpuStorage, crate::Shape)> {
+        let sort_indexes = match storage {
+            crate::CpuStorage::U8(vs) => self.asort(vs, layout),
+            crate::CpuStorage::U32(vs) => self.asort(vs, layout),
+            crate::CpuStorage::I64(vs) => self.asort(vs, layout),
+            crate::CpuStorage::BF16(vs) => self.asort(vs, layout),
+            crate::CpuStorage::F16(vs) => self.asort(vs, layout),
+            crate::CpuStorage::F32(vs) => self.asort(vs, layout),
+            crate::CpuStorage::F64(vs) => self.asort(vs, layout),
+        };
+        let sort_indexes = crate::CpuStorage::U32(sort_indexes);
+        Ok((sort_indexes, layout.shape().into()))
+    }
+
+    #[cfg(feature = "cuda")]
+    fn cuda_fwd(
+        &self,
+        storage: &crate::CudaStorage,
+        layout: &crate::Layout,
+    ) -> Result<(crate::CudaStorage, crate::Shape)> {
+        use crate::cuda_backend::cudarc::driver::{
+            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
+        };
+        use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, Map1Any, WrapErr};
+        use crate::{CudaDevice, WithDType};
+
+        impl Map1Any for ArgSort {
+            fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
+                &self,
+                src: &CudaSlice<T>,
+                dev: &CudaDevice,
+                layout: &crate::Layout,
+                _wrap: W,
+            ) -> Result<S> {
+                let slice = match layout.contiguous_offsets() {
+                    None => crate::bail!("input has to be contiguous"),
+                    Some((o1, o2)) => src.slice(o1..o2),
+                };
+                let elem_count = layout.shape().elem_count();
+                let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
+                let func = if self.asc {
+                    dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
+                } else {
+                    dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
+                };
+                let ncols = self.last_dim;
+                let nrows = elem_count / ncols;
+                let ncols_pad = next_power_of_2(ncols);
+                let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
+                let cfg = LaunchConfig {
+                    grid_dim: (1, nrows as u32, 1),
+                    block_dim: (ncols_pad as u32, 1, 1),
+                    shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
+                };
+                unsafe { func.launch(cfg, params) }.w()?;
+                Ok(S::U32(dst))
+            }
+        }
+
+        use crate::backend::BackendStorage;
+        let dev = storage.device();
+        let slice = self.map(&storage.slice, dev, layout)?;
+        let dst = crate::cuda_backend::CudaStorage {
+            slice,
+            device: dev.clone(),
+        };
+        Ok((dst, layout.shape().clone()))
+    }
+
+    #[cfg(feature = "metal")]
+    fn metal_fwd(
+        &self,
+        storage: &crate::MetalStorage,
+        layout: &crate::Layout,
+    ) -> Result<(crate::MetalStorage, crate::Shape)> {
+        use crate::backend::BackendStorage;
+        use crate::DType;
+
+        let name = {
+            if self.asc {
+                match storage.dtype() {
+                    DType::BF16 => "asort_asc_bf16",
+                    DType::F16 => "asort_asc_f16",
+                    DType::F32 => "asort_asc_f32",
+                    DType::F64 => "asort_asc_f64",
+                    DType::U8 => "asort_asc_u8",
+                    DType::U32 => "asort_asc_u32",
+                    DType::I64 => "asort_asc_i64",
+                }
+            } else {
+                match storage.dtype() {
+                    DType::BF16 => "asort_desc_bf16",
+                    DType::F16 => "asort_desc_f16",
+                    DType::F32 => "asort_desc_f32",
+                    DType::F64 => "asort_desc_f64",
+                    DType::U8 => "asort_desc_u8",
+                    DType::U32 => "asort_desc_u32",
+                    DType::I64 => "asort_desc_i64",
+                }
+            }
+        };
+        let device = storage.device();
+        let kernels = device.kernels();
+        let command_buffer = device.command_buffer()?;
+        let el = layout.shape().elem_count();
+        let ncols = self.last_dim;
+        let nrows = el / ncols;
+        let src = crate::metal_backend::buffer_o(storage.buffer(), layout, storage.dtype());
+        let dst = device.new_buffer(el, DType::U32, "asort")?;
+        let mut ncols_pad = 1;
+        while ncols_pad < ncols {
+            ncols_pad *= 2;
+        }
+        candle_metal_kernels::call_arg_sort(
+            device.metal_device(),
+            &command_buffer,
+            kernels,
+            &name,
+            nrows,
+            ncols,
+            ncols_pad,
+            src,
+            &dst,
+        )
+        .map_err(crate::Error::wrap)?;
+        let dst = crate::MetalStorage::new(dst, device.clone(), el, DType::U32);
+        Ok((dst, layout.shape().clone()))
+    }
+}
+
+#[allow(unused)]
+fn next_power_of_2(x: usize) -> usize {
+    let mut n = 1;
+    while n < x {
+        n *= 2
+    }
+    n
+}
+
+impl Tensor {
+    /// Returns the indices that sort the tensor along the last dimension.
+    ///
+    /// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
+    /// descending order. The sort is unstable so there is no guarantees on the final order when it
+    /// comes to ties.
+    pub fn arg_sort_last_dim(&self, asc: bool) -> Result<Tensor> {
+        if !self.is_contiguous() {
+            return Err(crate::Error::RequiresContiguous {
+                op: "arg_sort_last_dim",
+            });
+        }
+        let last_dim = match self.dims().last() {
+            None => crate::bail!("empty last-dim in arg-sort"),
+            Some(last_dim) => *last_dim,
+        };
+        // No need for a backward pass for arg sort.
+        self.apply_op1_no_bwd(&ArgSort { asc, last_dim })
+    }
+}
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index e9b1b367..4971f337 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -96,6 +96,22 @@ fn clamp(device: &Device) -> Result<()> {
     Ok(())
 }

+fn asort(device: &Device) -> Result<()> {
+    let data = &[[3f32, 1., 4., 1.1, 5.], [2.1, 1., 7., 8., 2.]];
+    let tensor = Tensor::new(data, device)?;
+    let indexes = tensor.arg_sort_last_dim(true)?;
+    assert_eq!(
+        indexes.to_vec2::<u32>()?,
+        [[1, 3, 0, 2, 4], [1, 4, 0, 2, 3]],
+    );
+    let indexes = tensor.arg_sort_last_dim(false)?;
+    assert_eq!(
+        indexes.to_vec2::<u32>()?,
+        [[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]],
+    );
+    Ok(())
+}
+
 fn unary_op(device: &Device) -> Result<()> {
     let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
     let tensor = Tensor::new(data, device)?;
@@ -1151,6 +1167,7 @@ test_device!(
 );
 test_device!(randn, randn_cpu, randn_gpu, randn_metal);
 test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
+test_device!(asort, asort_cpu, asort_gpu, asort_metal);
 test_device!(var, var_cpu, var_gpu, var_metal);
 test_device!(zero_dim, zero_dim_cpu, zero_dim_gpu, zero_dim_metal);
diff --git a/candle-kernels/src/lib.rs b/candle-kernels/src/lib.rs
index dc1195cb..1c73d6b7 100644
--- a/candle-kernels/src/lib.rs
+++ b/candle-kernels/src/lib.rs
@@ -6,5 +6,6 @@ pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
 pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
 pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
 pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
+pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx"));
 pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
 pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
diff --git a/candle-kernels/src/sort.cu b/candle-kernels/src/sort.cu
new file mode 100644
index 00000000..08f1f9fc
--- /dev/null
+++ b/candle-kernels/src/sort.cu
@@ -0,0 +1,88 @@
+// Adapted from https://github.com/ggerganov/llama.cpp/blob/master/ggml-cuda/argsort.cu
+#define SORT_ORDER_ASC 1
+#define SORT_ORDER_DESC 0
+#include "cuda_utils.cuh"
+#include<stdint.h>
+
+template<typename T>
+static inline __device__ void ggml_cuda_swap(T & a, T & b) {
+    T tmp = a;
+    a = b;
+    b = tmp;
+}
+
+template<int order, typename T>
+static __device__ void k_argsort(const T * x, uint32_t * dst, const int ncols, int ncols_pad) {
+    // bitonic sort
+    int col = threadIdx.x;
+    int row = blockIdx.y;
+
+    if (col >= ncols_pad) {
+        return;
+    }
+
+    const T * x_row = x + row * ncols;
+    extern __shared__ int dst_row[];
+
+    // initialize indices
+    dst_row[col] = col;
+
+    __syncthreads();
+
+    for (int k = 2; k <= ncols_pad; k *= 2) {
+        for (int j = k / 2; j > 0; j /= 2) {
+            int ixj = col ^ j;
+            if (ixj > col) {
+                if ((col & k) == 0) {
+                    if (dst_row[col] >= ncols ||
+                        (dst_row[ixj] < ncols && (order == SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+                    ) {
+                        ggml_cuda_swap(dst_row[col], dst_row[ixj]);
+                    }
+                } else {
+                    if (dst_row[ixj] >= ncols ||
+                        (dst_row[col] < ncols && (order == SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+                    ) {
+                        ggml_cuda_swap(dst_row[col], dst_row[ixj]);
+                    }
+                }
+            }
+            __syncthreads();
+        }
+    }
+
+    // copy the result to dst without the padding
+    if (col < ncols) {
+        dst[row * ncols + col] = dst_row[col];
+    }
+}
+
+#define ASORT_OP(TYPENAME, RUST_NAME) \
+extern "C" __global__ void asort_asc_##RUST_NAME( \
+    const TYPENAME * x, uint32_t * dst, const int ncols, int ncols_pad \
+) { \
+    k_argsort<SORT_ORDER_ASC>(x, dst, ncols, ncols_pad); \
+} \
+extern "C" __global__ void asort_desc_##RUST_NAME( \
+    const TYPENAME * x, uint32_t * dst, const int ncols, int ncols_pad \
+) { \
+    k_argsort<SORT_ORDER_DESC>(x, dst, ncols, ncols_pad); \
+} \
+
+#if __CUDA_ARCH__ >= 800
+ASORT_OP(__nv_bfloat16, bf16)
+#endif
+
+#if __CUDA_ARCH__ >= 530
+ASORT_OP(__half, f16)
+#endif
+
+ASORT_OP(float, f32)
+ASORT_OP(double, f64)
+ASORT_OP(uint8_t, u8)
+ASORT_OP(uint32_t, u32)
+ASORT_OP(int64_t, i64)
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 10f942b4..8e075d5a 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -21,6 +21,7 @@ const REDUCE: &str = include_str!("reduce.metal");
 const RANDOM: &str = include_str!("random.metal");
 const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
 const QUANTIZED: &str = include_str!("quantized.metal");
+const SORT: &str = include_str!("sort.metal");

 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 pub enum Source {
@@ -35,6 +36,7 @@ pub enum Source {
     Conv,
     Random,
     Quantized,
+    Sort,
 }

 pub mod copy2d {
@@ -197,6 +199,7 @@ impl Kernels {
             Source::Conv => CONV,
             Source::Random => RANDOM,
             Source::Quantized => QUANTIZED,
+            Source::Sort => SORT,
             Source::Mfa => panic!("Invalid lib"),
         }
     }
@@ -2048,5 +2051,42 @@ pub fn call_conv_transpose2d(
     Ok(())
 }

+#[allow(clippy::too_many_arguments)]
+pub fn call_arg_sort(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    name: &'static str,
+    nrows: usize,
+    ncols: usize,
+    ncols_pad: usize,
+    src: BufferOffset,
+    dst: &Buffer,
+) -> Result<(), MetalKernelError> {
+    let pipeline = kernels.load_pipeline(device, Source::Sort, name)?;
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    set_params!(encoder, (&src, dst, ncols as i64, ncols_pad as i64));
+
+    let thread_group_count = MTLSize {
+        width: 1,
+        height: nrows as u64,
+        depth: 1,
+    };
+    let thread_group_size = MTLSize {
+        width: ncols_pad as u64,
+        height: 1,
+        depth: 1,
+    };
+
+    encoder.use_resource(src.buffer, metal::MTLResourceUsage::Read);
+    encoder.use_resource(dst, metal::MTLResourceUsage::Write);
+    encoder.set_threadgroup_memory_length(0, (ncols_pad * 4).max(16) as u64);
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.end_encoding();
+    Ok(())
+}
+
 #[cfg(test)]
 mod tests;
diff --git a/candle-metal-kernels/src/quantized.metal b/candle-metal-kernels/src/quantized.metal
index 9aa7b502..fef6ac54 100644
--- a/candle-metal-kernels/src/quantized.metal
+++ b/candle-metal-kernels/src/quantized.metal
@@ -1,3 +1,4 @@
+// Imported from https://github.com/ggerganov/llama.cpp/blob/master/ggml-metal.metal
 #include <metal_stdlib>
 using namespace metal;

diff --git a/candle-metal-kernels/src/sort.metal b/candle-metal-kernels/src/sort.metal
new file mode 100644
index 00000000..d71ab822
--- /dev/null
+++ b/candle-metal-kernels/src/sort.metal
@@ -0,0 +1,97 @@
+// Imported from https://github.com/ggerganov/llama.cpp/blob/master/ggml-metal.metal
+#include <metal_stdlib>
+using namespace metal;
+
+#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
+#define SORT_ASC 1
+#define SORT_DESC 0
+
+template<int order, typename T>
+METAL_FUNC void argsort(
+        device const T * x,
+        device uint32_t * dst,
+        constant int64_t & ncols,
+        constant int64_t & ncols_pad,
+        threadgroup uint32_t * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]]) {
+    int col = tpitg[0];
+    int row = tgpig[1];
+
+    if (col >= ncols_pad) return;
+
+    device const T * x_row = x + row * ncols;
+    threadgroup uint32_t * dst_row = shared_values;
+
+    // initialize indices
+    dst_row[col] = col;
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    for (int k = 2; k <= ncols_pad; k *= 2) {
+        for (int j = k / 2; j > 0; j /= 2) {
+            int ixj = col ^ j;
+            if (ixj > col) {
+                if ((col & k) == 0) {
+                    if (dst_row[col] >= ncols ||
+                        (dst_row[ixj] < ncols && (order == SORT_ASC ?
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+                    ) {
+                        SWAP(dst_row[col], dst_row[ixj]);
+                    }
+                } else {
+                    if (dst_row[ixj] >= ncols ||
+                        (dst_row[col] < ncols && (order == SORT_ASC ?
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+                    ) {
+                        SWAP(dst_row[col], dst_row[ixj]);
+                    }
+                }
+            }
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+        }
+    }
+
+    // copy the result to dst without the padding
+    if (col < ncols) {
+        dst[row * ncols + col] = dst_row[col];
+    }
+}
+
+#define ARGSORT(T, RUST_T) \
+kernel void asort_asc_##RUST_T( \
+    device const T * x, \
+    device uint32_t * dst, \
+    constant int64_t & ncols, \
+    constant int64_t & ncols_pad, \
+    threadgroup uint32_t * shared_values [[threadgroup(0)]], \
+    uint3 tgpig[[threadgroup_position_in_grid]], \
+    uint3 tpitg[[thread_position_in_threadgroup]] \
+) { \
+    argsort<SORT_ASC, T>(x, dst, ncols, ncols_pad, shared_values, tgpig, tpitg); \
+} \
+kernel void asort_desc_##RUST_T( \
+    device const T * x, \
+    device uint32_t * dst, \
+    constant int64_t & ncols, \
+    constant int64_t & ncols_pad, \
+    threadgroup uint32_t * shared_values [[threadgroup(0)]], \
+    uint3 tgpig[[threadgroup_position_in_grid]], \
+    uint3 tpitg[[thread_position_in_threadgroup]] \
+) { \
+    argsort<SORT_DESC, T>(x, dst, ncols, ncols_pad, shared_values, tgpig, tpitg); \
+} \
+
+ARGSORT(float, f32)
+ARGSORT(half, f16)
+ARGSORT(uint8_t, u8)
+ARGSORT(uint32_t, u32)
+
+#if __METAL_VERSION__ >= 220
+ARGSORT(int64_t, i64)
+#endif
+#if defined(__HAVE_BFLOAT__)
+ARGSORT(bfloat, bf16)
+#endif
diff --git a/candle-transformers/src/models/qwen2.rs b/candle-transformers/src/models/qwen2.rs
index 06f9069a..c9b5ae01 100644
--- a/candle-transformers/src/models/qwen2.rs
+++ b/candle-transformers/src/models/qwen2.rs
@@ -27,13 +27,6 @@ struct RotaryEmbedding {
     cos: Tensor,
 }

-fn rotate_half(xs: &Tensor) -> Result<Tensor> {
-    let last_dim = xs.dim(D::Minus1)?;
-    let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
-    let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
-    Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
-}
-
 impl RotaryEmbedding {
     fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
         let dim = cfg.hidden_size / cfg.num_attention_heads;
@@ -48,7 +41,6 @@ impl RotaryEmbedding {
             .to_dtype(dtype)?
             .reshape((max_seq_len, 1))?;
         let freqs = t.matmul(&inv_freq)?;
-        let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
         Ok(Self {
             sin: freqs.sin()?,
             cos: freqs.cos()?,
@@ -64,10 +56,8 @@ impl RotaryEmbedding {
         let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
         let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
         let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
-        let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
-        let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
-        let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
-        let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
+        let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
+        let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
         Ok((q_embed, k_embed))
     }
 }
diff --git a/candle-transformers/src/models/qwen2_moe.rs b/candle-transformers/src/models/qwen2_moe.rs
index 5650e350..8d1d2f70 100644
--- a/candle-transformers/src/models/qwen2_moe.rs
+++ b/candle-transformers/src/models/qwen2_moe.rs
@@ -33,13 +33,6 @@ struct RotaryEmbedding {
     cos: Tensor,
 }

-fn rotate_half(xs: &Tensor) -> Result<Tensor> {
-    let last_dim = xs.dim(D::Minus1)?;
-    let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
-    let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
-    Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
-}
-
 impl RotaryEmbedding {
     fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
         let dim = cfg.hidden_size / cfg.num_attention_heads;
@@ -54,7 +47,6 @@ impl RotaryEmbedding {
             .to_dtype(dtype)?
             .reshape((max_seq_len, 1))?;
         let freqs = t.matmul(&inv_freq)?;
-        let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
         Ok(Self {
             sin: freqs.sin()?,
             cos: freqs.cos()?,
@@ -70,10 +62,8 @@ impl RotaryEmbedding {
         let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
         let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
         let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
-        let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
-        let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
-        let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
-        let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
+        let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
+        let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
         Ok((q_embed, k_embed))
     }
 }
@@ -259,30 +249,28 @@ impl Module for SparseMoeBlock {

         // In order to extract topk, we extract the data from the tensor and manipulate it
         // directly. Maybe we will want to use some custom ops instead at some point.
-        let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
+        let experts_per_tok = routing_weights
+            .arg_sort_last_dim(false)?
+            .narrow(D::Minus1, 0, self.num_experts_per_tok)?
+            .contiguous()?;
+        let routing_weights = routing_weights.gather(&experts_per_tok, D::Minus1)?;

         // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
         // top_x contains the row indexes to evaluate for each expert.
+        let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
+        let experts_per_tok = experts_per_tok.to_vec2::<u32>()?;
         let mut top_x = vec![vec![]; self.experts.len()];
         let mut selected_experts = vec![vec![]; self.experts.len()];
-        for (row_idx, rw) in routing_weights.iter().enumerate() {
-            let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
-            dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
-            let mut sum_routing_weights = 0f32;
-            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
-                let expert_idx = expert_idx as usize;
-                let routing_weight = rw[expert_idx];
-                sum_routing_weights += routing_weight;
-                top_x[expert_idx].push(row_idx as u32);
-            }
-            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
-                let expert_idx = expert_idx as usize;
-                let routing_weight = if self.norm_topk_prob {
-                    rw[expert_idx] / sum_routing_weights
-                } else {
-                    rw[expert_idx]
-                };
-                selected_experts[expert_idx].push(routing_weight)
+        for (row_idx, (rw, expert_idxs)) in routing_weights
+            .iter()
+            .zip(experts_per_tok.iter())
+            .enumerate()
+        {
+            let sum_rw = rw.iter().sum::<f32>();
+            for (&rw, &expert_idx) in rw.iter().zip(expert_idxs.iter()) {
+                top_x[expert_idx as usize].push(row_idx as u32);
+                let rw = if self.norm_topk_prob { rw / sum_rw } else { rw };
+                selected_experts[expert_idx as usize].push(rw)
             }
         }
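The qwen2_moe rework above amounts to a device-side top-k built from the new op: argsort the routing weights in descending order, keep the first num_experts_per_tok indices, and gather the matching weights. The same pattern in isolation, as a hedged sketch (the topk_last_dim helper and the sample values are illustrative, not part of the commit):

use candle_core::{D, Device, Result, Tensor};

fn topk_last_dim(t: &Tensor, k: usize) -> Result<(Tensor, Tensor)> {
    // Indices of the k largest entries per row, via a descending argsort.
    let idx = t.arg_sort_last_dim(false)?.narrow(D::Minus1, 0, k)?.contiguous()?;
    // Gather the matching values so both tensors stay on the same device.
    let vals = t.gather(&idx, D::Minus1)?;
    Ok((vals, idx))
}

fn main() -> Result<()> {
    let t = Tensor::new(&[[0.1f32, 0.7, 0.2], [0.5, 0.4, 0.1]], &Device::Cpu)?;
    let (vals, idx) = topk_last_dim(&t, 2)?;
    assert_eq!(idx.to_vec2::<u32>()?, [[1, 2], [0, 1]]);
    assert_eq!(vals.to_vec2::<f32>()?, [[0.7, 0.2], [0.5, 0.4]]);
    Ok(())
}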