author     Laurent Mazare <laurent.mazare@gmail.com>  2024-04-27 20:17:35 +0200
committer  GitHub <noreply@github.com>  2024-04-27 20:17:35 +0200
commit     96a48e5cc42b3c94d9d9687bb29987953df36db8 (patch)
tree       4f1f391e6e6a8c1b865c4ab40e67aaf84dd21499
parent     6cf82fd7a34641601264ad1e0256ecadb7222474 (diff)
Add argsort. (#2132)
* Add the argsort cuda kernels.
* CPU version of arg-sort.
* Hook the cuda kernel + rework the cpu bits.
* Add some dedicated test.
* Working cuda kernel.
* Metal kernel.
* Metal adjustments.
* Bugfix.
* Use the fast rope in qwen.
* Rework the expert selection in qwen.
-rw-r--r--  candle-core/src/lib.rs                         1
-rw-r--r--  candle-core/src/metal_backend/mod.rs           2
-rw-r--r--  candle-core/src/sort.rs                      222
-rw-r--r--  candle-core/tests/tensor_tests.rs             17
-rw-r--r--  candle-kernels/src/lib.rs                      1
-rw-r--r--  candle-kernels/src/sort.cu                    88
-rw-r--r--  candle-metal-kernels/src/lib.rs               40
-rw-r--r--  candle-metal-kernels/src/quantized.metal       1
-rw-r--r--  candle-metal-kernels/src/sort.metal           97
-rw-r--r--  candle-transformers/src/models/qwen2.rs       14
-rw-r--r--  candle-transformers/src/models/qwen2_moe.rs   50
11 files changed, 489 insertions, 44 deletions
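
The user-facing entry point is the new `Tensor::arg_sort_last_dim` method added in candle-core/src/sort.rs below. A minimal usage sketch on the CPU backend, assuming the crate is imported as `candle` as in this workspace (the same call dispatches to the CUDA or Metal kernels when those features are enabled); the expected indices match the dedicated test further down:

    use candle::{Device, Result, Tensor};

    fn main() -> Result<()> {
        let t = Tensor::new(&[[3f32, 1., 4., 1.1, 5.]], &Device::Cpu)?;
        // Indices that sort each row in ascending order.
        let asc = t.arg_sort_last_dim(true)?;
        assert_eq!(asc.to_vec2::<u32>()?, [[1, 3, 0, 2, 4]]);
        // Passing `false` sorts in descending order instead.
        let desc = t.arg_sort_last_dim(false)?;
        assert_eq!(desc.to_vec2::<u32>()?, [[4, 2, 0, 3, 1]]);
        Ok(())
    }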
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index bafad1b6..44788ddc 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -63,6 +63,7 @@ pub mod quantized;
pub mod safetensors;
pub mod scalar;
pub mod shape;
+mod sort;
mod storage;
mod strided_index;
mod tensor;
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs
index 1396899b..c0f6a844 100644
--- a/candle-core/src/metal_backend/mod.rs
+++ b/candle-core/src/metal_backend/mod.rs
@@ -11,7 +11,7 @@ use std::sync::{Arc, Mutex, RwLock, TryLockError};
mod device;
pub use device::{DeviceId, MetalDevice};
-fn buffer_o<'a>(buffer: &'a Buffer, l: &Layout, dtype: DType) -> BufferOffset<'a> {
+pub fn buffer_o<'a>(buffer: &'a Buffer, l: &Layout, dtype: DType) -> BufferOffset<'a> {
BufferOffset {
buffer,
offset_in_bytes: l.start_offset() * dtype.size_in_bytes(),
diff --git a/candle-core/src/sort.rs b/candle-core/src/sort.rs
new file mode 100644
index 00000000..bcd098e3
--- /dev/null
+++ b/candle-core/src/sort.rs
@@ -0,0 +1,222 @@
+use crate::{Result, Tensor};
+use rayon::prelude::*;
+
+#[derive(Debug, Clone, Copy)]
+struct ArgSort {
+ asc: bool,
+ last_dim: usize,
+}
+
+impl ArgSort {
+ fn asort<T: crate::WithDType>(&self, vs: &[T], layout: &crate::Layout) -> Vec<u32> {
+ #[allow(clippy::uninit_vec)]
+ // Safety: indexes are set later in the parallelized section.
+ let mut sort_indexes = unsafe {
+ let el_count = layout.shape().elem_count();
+ let mut v = Vec::with_capacity(el_count);
+ v.set_len(el_count);
+ v
+ };
+ if self.asc {
+ sort_indexes
+ .par_chunks_exact_mut(self.last_dim)
+ .zip(vs.par_chunks_exact(self.last_dim))
+ .for_each(|(indexes, vs)| {
+ indexes
+ .iter_mut()
+ .enumerate()
+ .for_each(|(i, v)| *v = i as u32);
+ indexes.sort_by(|&i, &j| {
+ vs[i as usize]
+ .partial_cmp(&vs[j as usize])
+ .unwrap_or(std::cmp::Ordering::Greater)
+ })
+ });
+ } else {
+ sort_indexes
+ .par_chunks_exact_mut(self.last_dim)
+ .zip(vs.par_chunks_exact(self.last_dim))
+ .for_each(|(indexes, vs)| {
+ indexes
+ .iter_mut()
+ .enumerate()
+ .for_each(|(i, v)| *v = i as u32);
+ indexes.sort_by(|&j, &i| {
+ vs[i as usize]
+ .partial_cmp(&vs[j as usize])
+ .unwrap_or(std::cmp::Ordering::Greater)
+ })
+ });
+ }
+ sort_indexes
+ }
+}
+
+impl crate::CustomOp1 for ArgSort {
+ fn name(&self) -> &'static str {
+ "argsort"
+ }
+
+ fn cpu_fwd(
+ &self,
+ storage: &crate::CpuStorage,
+ layout: &crate::Layout,
+ ) -> Result<(crate::CpuStorage, crate::Shape)> {
+ let sort_indexes = match storage {
+ crate::CpuStorage::U8(vs) => self.asort(vs, layout),
+ crate::CpuStorage::U32(vs) => self.asort(vs, layout),
+ crate::CpuStorage::I64(vs) => self.asort(vs, layout),
+ crate::CpuStorage::BF16(vs) => self.asort(vs, layout),
+ crate::CpuStorage::F16(vs) => self.asort(vs, layout),
+ crate::CpuStorage::F32(vs) => self.asort(vs, layout),
+ crate::CpuStorage::F64(vs) => self.asort(vs, layout),
+ };
+ let sort_indexes = crate::CpuStorage::U32(sort_indexes);
+ Ok((sort_indexes, layout.shape().into()))
+ }
+
+ #[cfg(feature = "cuda")]
+ fn cuda_fwd(
+ &self,
+ storage: &crate::CudaStorage,
+ layout: &crate::Layout,
+ ) -> Result<(crate::CudaStorage, crate::Shape)> {
+ use crate::cuda_backend::cudarc::driver::{
+ CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
+ };
+ use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, Map1Any, WrapErr};
+ use crate::{CudaDevice, WithDType};
+
+ impl Map1Any for ArgSort {
+ fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
+ &self,
+ src: &CudaSlice<T>,
+ dev: &CudaDevice,
+ layout: &crate::Layout,
+ _wrap: W,
+ ) -> Result<S> {
+ let slice = match layout.contiguous_offsets() {
+ None => crate::bail!("input has to be contiguous"),
+ Some((o1, o2)) => src.slice(o1..o2),
+ };
+ let elem_count = layout.shape().elem_count();
+ let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
+ let func = if self.asc {
+ dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
+ } else {
+ dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
+ };
+ let ncols = self.last_dim;
+ let nrows = elem_count / ncols;
+ let ncols_pad = next_power_of_2(ncols);
+ let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
+ let cfg = LaunchConfig {
+ grid_dim: (1, nrows as u32, 1),
+ block_dim: (ncols_pad as u32, 1, 1),
+ shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
+ };
+ unsafe { func.launch(cfg, params) }.w()?;
+ Ok(S::U32(dst))
+ }
+ }
+
+ use crate::backend::BackendStorage;
+ let dev = storage.device();
+ let slice = self.map(&storage.slice, dev, layout)?;
+ let dst = crate::cuda_backend::CudaStorage {
+ slice,
+ device: dev.clone(),
+ };
+ Ok((dst, layout.shape().clone()))
+ }
+
+ #[cfg(feature = "metal")]
+ fn metal_fwd(
+ &self,
+ storage: &crate::MetalStorage,
+ layout: &crate::Layout,
+ ) -> Result<(crate::MetalStorage, crate::Shape)> {
+ use crate::backend::BackendStorage;
+ use crate::DType;
+
+ let name = {
+ if self.asc {
+ match storage.dtype() {
+ DType::BF16 => "asort_asc_bf16",
+ DType::F16 => "asort_asc_f16",
+ DType::F32 => "asort_asc_f32",
+ DType::F64 => "asort_asc_f64",
+ DType::U8 => "asort_asc_u8",
+ DType::U32 => "asort_asc_u32",
+ DType::I64 => "asort_asc_i64",
+ }
+ } else {
+ match storage.dtype() {
+ DType::BF16 => "asort_desc_bf16",
+ DType::F16 => "asort_desc_f16",
+ DType::F32 => "asort_desc_f32",
+ DType::F64 => "asort_desc_f64",
+ DType::U8 => "asort_desc_u8",
+ DType::U32 => "asort_desc_u32",
+ DType::I64 => "asort_desc_i64",
+ }
+ }
+ };
+ let device = storage.device();
+ let kernels = device.kernels();
+ let command_buffer = device.command_buffer()?;
+ let el = layout.shape().elem_count();
+ let ncols = self.last_dim;
+ let nrows = el / ncols;
+ let src = crate::metal_backend::buffer_o(storage.buffer(), layout, storage.dtype());
+ let dst = device.new_buffer(el, DType::U32, "asort")?;
+ let mut ncols_pad = 1;
+ while ncols_pad < ncols {
+ ncols_pad *= 2;
+ }
+ candle_metal_kernels::call_arg_sort(
+ device.metal_device(),
+ &command_buffer,
+ kernels,
+ &name,
+ nrows,
+ ncols,
+ ncols_pad,
+ src,
+ &dst,
+ )
+ .map_err(crate::Error::wrap)?;
+ let dst = crate::MetalStorage::new(dst, device.clone(), el, DType::U32);
+ Ok((dst, layout.shape().clone()))
+ }
+}
+
+#[allow(unused)]
+fn next_power_of_2(x: usize) -> usize {
+ let mut n = 1;
+ while n < x {
+ n *= 2
+ }
+ n
+}
+
+impl Tensor {
+ /// Returns the indices that sort the tensor along the last dimension.
+ ///
+ /// If `asc` is `true`, sorting is in ascending order; otherwise it is performed in
+ /// descending order. The sort is unstable, so there are no guarantees on the final
+ /// order of ties.
+ pub fn arg_sort_last_dim(&self, asc: bool) -> Result<Tensor> {
+ if !self.is_contiguous() {
+ return Err(crate::Error::RequiresContiguous {
+ op: "arg_sort_last_dim",
+ });
+ }
+ let last_dim = match self.dims().last() {
+ None => crate::bail!("empty last-dim in arg-sort"),
+ Some(last_dim) => *last_dim,
+ };
+ // No need for a backward pass for arg sort.
+ self.apply_op1_no_bwd(&ArgSort { asc, last_dim })
+ }
+}
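
One detail worth noting from the implementation above: `arg_sort_last_dim` bails on non-contiguous input rather than handling arbitrary strides, so a strided view has to be materialized first. An illustrative snippet (not part of the diff):

    use candle::{Device, Result, Tensor};

    fn sort_transposed() -> Result<Tensor> {
        let t = Tensor::new(&[[3f32, 1.], [4., 2.]], &Device::Cpu)?;
        // A transposed view is not contiguous, so this returns
        // Error::RequiresContiguous...
        assert!(t.t()?.arg_sort_last_dim(true).is_err());
        // ...while materializing the view first works.
        t.t()?.contiguous()?.arg_sort_last_dim(true)
    }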
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index e9b1b367..4971f337 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -96,6 +96,22 @@ fn clamp(device: &Device) -> Result<()> {
Ok(())
}
+fn asort(device: &Device) -> Result<()> {
+ let data = &[[3f32, 1., 4., 1.1, 5.], [2.1, 1., 7., 8., 2.]];
+ let tensor = Tensor::new(data, device)?;
+ let indexes = tensor.arg_sort_last_dim(true)?;
+ assert_eq!(
+ indexes.to_vec2::<u32>()?,
+ [[1, 3, 0, 2, 4], [1, 4, 0, 2, 3]],
+ );
+ let indexes = tensor.arg_sort_last_dim(false)?;
+ assert_eq!(
+ indexes.to_vec2::<u32>()?,
+ [[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]],
+ );
+ Ok(())
+}
+
fn unary_op(device: &Device) -> Result<()> {
let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
let tensor = Tensor::new(data, device)?;
@@ -1151,6 +1167,7 @@ test_device!(
);
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
+test_device!(asort, asort_cpu, asort_gpu, asort_metal);
test_device!(var, var_cpu, var_gpu, var_metal);
test_device!(zero_dim, zero_dim_cpu, zero_dim_gpu, zero_dim_metal);
diff --git a/candle-kernels/src/lib.rs b/candle-kernels/src/lib.rs
index dc1195cb..1c73d6b7 100644
--- a/candle-kernels/src/lib.rs
+++ b/candle-kernels/src/lib.rs
@@ -6,5 +6,6 @@ pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
+pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx"));
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
diff --git a/candle-kernels/src/sort.cu b/candle-kernels/src/sort.cu
new file mode 100644
index 00000000..08f1f9fc
--- /dev/null
+++ b/candle-kernels/src/sort.cu
@@ -0,0 +1,88 @@
+// Adapted from https://github.com/ggerganov/llama.cpp/blob/master/ggml-cuda/argsort.cu
+#define SORT_ORDER_ASC 1
+#define SORT_ORDER_DESC 0
+#include "cuda_utils.cuh"
+#include <stdint.h>
+
+template<typename T>
+static inline __device__ void ggml_cuda_swap(T & a, T & b) {
+ T tmp = a;
+ a = b;
+ b = tmp;
+}
+
+template<int order, typename T>
+static __device__ void k_argsort(const T * x, uint32_t * dst, const int ncols, int ncols_pad) {
+ // bitonic sort
+ int col = threadIdx.x;
+ int row = blockIdx.y;
+
+ if (col >= ncols_pad) {
+ return;
+ }
+
+ const T * x_row = x + row * ncols;
+ extern __shared__ int dst_row[];
+
+ // initialize indices
+ dst_row[col] = col;
+
+ __syncthreads();
+
+ for (int k = 2; k <= ncols_pad; k *= 2) {
+ for (int j = k / 2; j > 0; j /= 2) {
+ int ixj = col ^ j;
+ if (ixj > col) {
+ if ((col & k) == 0) {
+ if (dst_row[col] >= ncols ||
+ (dst_row[ixj] < ncols && (order == SORT_ORDER_ASC ?
+ x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+ x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+ ) {
+ ggml_cuda_swap(dst_row[col], dst_row[ixj]);
+ }
+ } else {
+ if (dst_row[ixj] >= ncols ||
+ (dst_row[col] < ncols && (order == SORT_ORDER_ASC ?
+ x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+ x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+ ) {
+ ggml_cuda_swap(dst_row[col], dst_row[ixj]);
+ }
+ }
+ }
+ __syncthreads();
+ }
+ }
+
+ // copy the result to dst without the padding
+ if (col < ncols) {
+ dst[row * ncols + col] = dst_row[col];
+ }
+}
+
+#define ASORT_OP(TYPENAME, RUST_NAME) \
+extern "C" __global__ void asort_asc_##RUST_NAME( \
+ const TYPENAME * x, uint32_t * dst, const int ncols, int ncols_pad \
+) { \
+ k_argsort<SORT_ORDER_ASC>(x, dst, ncols, ncols_pad); \
+} \
+extern "C" __global__ void asort_desc_##RUST_NAME( \
+ const TYPENAME * x, uint32_t * dst, const int ncols, int ncols_pad \
+) { \
+ k_argsort<SORT_ORDER_DESC>(x, dst, ncols, ncols_pad); \
+} \
+
+#if __CUDA_ARCH__ >= 800
+ASORT_OP(__nv_bfloat16, bf16)
+#endif
+
+#if __CUDA_ARCH__ >= 530
+ASORT_OP(__half, f16)
+#endif
+
+ASORT_OP(float, f32)
+ASORT_OP(double, f64)
+ASORT_OP(uint8_t, u8)
+ASORT_OP(uint32_t, u32)
+ASORT_OP(int64_t, i64)
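
The CUDA kernel above (and the Metal kernel below) is a bitonic argsort: one thread per padded column, with pad slots (index >= ncols) always compared as if they held the largest value, so they sink past the real data. A sequential Rust sketch of the same compare-exchange network, for reference only (hypothetical helper, not part of the diff):

    // Sorts the indices 0..ncols of `row` ascending with a bitonic network,
    // padding the width up to the next power of two; pad indices always lose.
    fn bitonic_argsort_asc(row: &[f32]) -> Vec<u32> {
        let ncols = row.len();
        let ncols_pad = ncols.next_power_of_two();
        let mut idx: Vec<usize> = (0..ncols_pad).collect();
        let mut k = 2;
        while k <= ncols_pad {
            let mut j = k / 2;
            while j > 0 {
                // In the kernels, each `col` below is a separate thread; the
                // pairs (col, col ^ j) are disjoint, so a sequential loop
                // computes the same result.
                for col in 0..ncols_pad {
                    let ixj = col ^ j;
                    if ixj > col {
                        let up = (col & k) == 0;
                        let (a, b) = (idx[col], idx[ixj]);
                        // A pad index compares as +inf.
                        let swap = if up {
                            a >= ncols || (b < ncols && row[a] > row[b])
                        } else {
                            b >= ncols || (a < ncols && row[a] < row[b])
                        };
                        if swap {
                            idx.swap(col, ixj);
                        }
                    }
                }
                j /= 2;
            }
            k *= 2;
        }
        // Drop the padding, as the kernels do when writing to dst.
        idx.into_iter().take(ncols).map(|i| i as u32).collect()
    }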
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 10f942b4..8e075d5a 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -21,6 +21,7 @@ const REDUCE: &str = include_str!("reduce.metal");
const RANDOM: &str = include_str!("random.metal");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
const QUANTIZED: &str = include_str!("quantized.metal");
+const SORT: &str = include_str!("sort.metal");
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
@@ -35,6 +36,7 @@ pub enum Source {
Conv,
Random,
Quantized,
+ Sort,
}
pub mod copy2d {
@@ -197,6 +199,7 @@ impl Kernels {
Source::Conv => CONV,
Source::Random => RANDOM,
Source::Quantized => QUANTIZED,
+ Source::Sort => SORT,
Source::Mfa => panic!("Invalid lib"),
}
}
@@ -2048,5 +2051,42 @@ pub fn call_conv_transpose2d(
Ok(())
}
+#[allow(clippy::too_many_arguments)]
+pub fn call_arg_sort(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: &'static str,
+ nrows: usize,
+ ncols: usize,
+ ncols_pad: usize,
+ src: BufferOffset,
+ dst: &Buffer,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Sort, name)?;
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(encoder, (&src, dst, ncols as i64, ncols_pad as i64));
+
+ let thread_group_count = MTLSize {
+ width: 1,
+ height: nrows as u64,
+ depth: 1,
+ };
+ let thread_group_size = MTLSize {
+ width: ncols_pad as u64,
+ height: 1,
+ depth: 1,
+ };
+
+ encoder.use_resource(src.buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(dst, metal::MTLResourceUsage::Write);
+ encoder.set_threadgroup_memory_length(0, (ncols_pad * 4).max(16) as u64);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.end_encoding();
+ Ok(())
+}
+
#[cfg(test)]
mod tests;
diff --git a/candle-metal-kernels/src/quantized.metal b/candle-metal-kernels/src/quantized.metal
index 9aa7b502..fef6ac54 100644
--- a/candle-metal-kernels/src/quantized.metal
+++ b/candle-metal-kernels/src/quantized.metal
@@ -1,3 +1,4 @@
+// Imported from https://github.com/ggerganov/llama.cpp/blob/master/ggml-metal.metal
#include <metal_stdlib>
using namespace metal;
diff --git a/candle-metal-kernels/src/sort.metal b/candle-metal-kernels/src/sort.metal
new file mode 100644
index 00000000..d71ab822
--- /dev/null
+++ b/candle-metal-kernels/src/sort.metal
@@ -0,0 +1,97 @@
+// Imported from https://github.com/ggerganov/llama.cpp/blob/master/ggml-metal.metal
+#include <metal_stdlib>
+using namespace metal;
+
+#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
+#define SORT_ASC 1
+#define SORT_DESC 0
+
+template<int order, typename T>
+METAL_FUNC void argsort(
+ device const T * x,
+ device uint32_t * dst,
+ constant int64_t & ncols,
+ constant int64_t & ncols_pad,
+ threadgroup uint32_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]]) {
+ int col = tpitg[0];
+ int row = tgpig[1];
+
+ if (col >= ncols_pad) return;
+
+ device const T * x_row = x + row * ncols;
+ threadgroup uint32_t * dst_row = shared_values;
+
+ // initialize indices
+ dst_row[col] = col;
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (int k = 2; k <= ncols_pad; k *= 2) {
+ for (int j = k / 2; j > 0; j /= 2) {
+ int ixj = col ^ j;
+ if (ixj > col) {
+ if ((col & k) == 0) {
+ if (dst_row[col] >= ncols ||
+ (dst_row[ixj] < ncols && (order == SORT_ASC ?
+ x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+ x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+ ) {
+ SWAP(dst_row[col], dst_row[ixj]);
+ }
+ } else {
+ if (dst_row[ixj] >= ncols ||
+ (dst_row[col] < ncols && (order == SORT_ASC ?
+ x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+ x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+ ) {
+ SWAP(dst_row[col], dst_row[ixj]);
+ }
+ }
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+ }
+
+ // copy the result to dst without the padding
+ if (col < ncols) {
+ dst[row * ncols + col] = dst_row[col];
+ }
+}
+
+#define ARGSORT(T, RUST_T) \
+kernel void asort_asc_##RUST_T( \
+ device const T * x, \
+ device uint32_t * dst, \
+ constant int64_t & ncols, \
+ constant int64_t & ncols_pad, \
+ threadgroup uint32_t * shared_values [[threadgroup(0)]], \
+ uint3 tgpig[[threadgroup_position_in_grid]], \
+ uint3 tpitg[[thread_position_in_threadgroup]] \
+) { \
+ argsort<SORT_ASC, T>(x, dst, ncols, ncols_pad, shared_values, tgpig, tpitg); \
+} \
+kernel void asort_desc_##RUST_T( \
+ device const T * x, \
+ device uint32_t * dst, \
+ constant int64_t & ncols, \
+ constant int64_t & ncols_pad, \
+ threadgroup uint32_t * shared_values [[threadgroup(0)]], \
+ uint3 tgpig[[threadgroup_position_in_grid]], \
+ uint3 tpitg[[thread_position_in_threadgroup]] \
+) { \
+ argsort<SORT_DESC, T>(x, dst, ncols, ncols_pad, shared_values, tgpig, tpitg); \
+} \
+
+ARGSORT(float, f32)
+ARGSORT(half, f16)
+ARGSORT(uint8_t, u8)
+ARGSORT(uint32_t, u32)
+
+#if __METAL_VERSION__ >= 220
+ARGSORT(int64_t, i64)
+#endif
+#if defined(__HAVE_BFLOAT__)
+ARGSORT(bfloat, bf16)
+#endif
diff --git a/candle-transformers/src/models/qwen2.rs b/candle-transformers/src/models/qwen2.rs
index 06f9069a..c9b5ae01 100644
--- a/candle-transformers/src/models/qwen2.rs
+++ b/candle-transformers/src/models/qwen2.rs
@@ -27,13 +27,6 @@ struct RotaryEmbedding {
cos: Tensor,
}
-fn rotate_half(xs: &Tensor) -> Result<Tensor> {
- let last_dim = xs.dim(D::Minus1)?;
- let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
- let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
- Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
-}
-
impl RotaryEmbedding {
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let dim = cfg.hidden_size / cfg.num_attention_heads;
@@ -48,7 +41,6 @@ impl RotaryEmbedding {
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
- let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
@@ -64,10 +56,8 @@ impl RotaryEmbedding {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
- let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
- let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
- let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
- let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
+ let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
+ let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
}
}
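
The qwen2 change above drops the hand-rolled `rotate_half` path in favor of the fused `candle_nn::rotary_emb::rope`, which takes half-width cos/sin tables; that is also why the `Tensor::cat(&[&freqs, &freqs], ...)` duplication is removed. For reference, a hedged sketch of the rotation the fused op computes, assuming shapes q: (b, h, t, d) and cos/sin: (t, d/2):

    use candle::{Result, Tensor, D};

    // Hypothetical reference helper, equivalent to the removed formulation
    // x * cat(cos, cos) + rotate_half(x) * cat(sin, sin).
    fn rope_reference(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
        let d = x.dim(D::Minus1)?;
        let x1 = x.narrow(D::Minus1, 0, d / 2)?;
        let x2 = x.narrow(D::Minus1, d / 2, d / 2)?;
        let r1 = (x1.broadcast_mul(cos)? - x2.broadcast_mul(sin)?)?;
        let r2 = (x2.broadcast_mul(cos)? + x1.broadcast_mul(sin)?)?;
        Tensor::cat(&[&r1, &r2], D::Minus1)
    }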
diff --git a/candle-transformers/src/models/qwen2_moe.rs b/candle-transformers/src/models/qwen2_moe.rs
index 5650e350..8d1d2f70 100644
--- a/candle-transformers/src/models/qwen2_moe.rs
+++ b/candle-transformers/src/models/qwen2_moe.rs
@@ -33,13 +33,6 @@ struct RotaryEmbedding {
cos: Tensor,
}
-fn rotate_half(xs: &Tensor) -> Result<Tensor> {
- let last_dim = xs.dim(D::Minus1)?;
- let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
- let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
- Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
-}
-
impl RotaryEmbedding {
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let dim = cfg.hidden_size / cfg.num_attention_heads;
@@ -54,7 +47,6 @@ impl RotaryEmbedding {
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
- let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
@@ -70,10 +62,8 @@ impl RotaryEmbedding {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
- let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
- let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
- let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
- let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
+ let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
+ let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
}
}
@@ -259,30 +249,28 @@ impl Module for SparseMoeBlock {
// In order to extract topk, we extract the data from the tensor and manipulate it
// directly. Maybe we will want to use some custom ops instead at some point.
- let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
+ let experts_per_tok = routing_weights
+ .arg_sort_last_dim(false)?
+ .narrow(D::Minus1, 0, self.num_experts_per_tok)?
+ .contiguous()?;
+ let routing_weights = routing_weights.gather(&experts_per_tok, D::Minus1)?;
// routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
// top_x contains the row indexes to evaluate for each expert.
+ let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
+ let experts_per_tok = experts_per_tok.to_vec2::<u32>()?;
let mut top_x = vec![vec![]; self.experts.len()];
let mut selected_experts = vec![vec![]; self.experts.len()];
- for (row_idx, rw) in routing_weights.iter().enumerate() {
- let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
- dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
- let mut sum_routing_weights = 0f32;
- for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
- let expert_idx = expert_idx as usize;
- let routing_weight = rw[expert_idx];
- sum_routing_weights += routing_weight;
- top_x[expert_idx].push(row_idx as u32);
- }
- for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
- let expert_idx = expert_idx as usize;
- let routing_weight = if self.norm_topk_prob {
- rw[expert_idx] / sum_routing_weights
- } else {
- rw[expert_idx]
- };
- selected_experts[expert_idx].push(routing_weight)
+ for (row_idx, (rw, expert_idxs)) in routing_weights
+ .iter()
+ .zip(experts_per_tok.iter())
+ .enumerate()
+ {
+ let sum_rw = rw.iter().sum::<f32>();
+ for (&rw, &expert_idx) in rw.iter().zip(expert_idxs.iter()) {
+ top_x[expert_idx as usize].push(row_idx as u32);
+ let rw = if self.norm_topk_prob { rw / sum_rw } else { rw };
+ selected_experts[expert_idx as usize].push(rw)
}
}
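
The reworked expert selection above replaces the per-row `Vec` sort with a tensor-level top-k built from the new op: argsort descending, keep the first `num_experts_per_tok` columns, then gather the matching weights. The same pattern in isolation, as a minimal sketch (hypothetical helper):

    use candle::{Result, Tensor, D};

    // Values and indices of the k largest entries along the last dimension.
    fn topk_last_dim(xs: &Tensor, k: usize) -> Result<(Tensor, Tensor)> {
        let idx = xs
            .arg_sort_last_dim(false)?
            .narrow(D::Minus1, 0, k)?
            .contiguous()?;
        let vals = xs.gather(&idx, D::Minus1)?;
        Ok((vals, idx))
    }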