author     Laurent Mazare <laurent.mazare@gmail.com>  2024-04-05 08:32:58 +0200
committer  GitHub <noreply@github.com>                2024-04-05 08:32:58 +0200
commit     2ac302a5d170953a1d2fe850645563fc55d1567f (patch)
tree       b35b32efe5c8eac25a9b5681fb0778ef84e57d0e /candle-nn
parent     ace282e5c2ef24ca2fb90683babb852936d4df17 (diff)
download   candle-2ac302a5d170953a1d2fe850645563fc55d1567f.tar.gz
           candle-2ac302a5d170953a1d2fe850645563fc55d1567f.tar.bz2
           candle-2ac302a5d170953a1d2fe850645563fc55d1567f.zip
Add the rope THD kernel. (#2014)
* Add the rope THD kernel.
* Cuda kernel for rope-thd.
* Add the metal kernels.
* Add a dedicated test.
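
A minimal usage sketch of the new rope_thd entry point (hypothetical variable names; the shape requirements follow the checks in candle-nn/src/rotary_emb.rs below): the input is contiguous in (batch, seqlen, num-heads, head-dim) order and cos/sin have shape (seqlen, head-dim / 2).

    use candle::{Device, Result, Tensor};

    fn rope_thd_example() -> Result<Tensor> {
        let dev = Device::Cpu;
        let (b, t, h, d) = (1, 10, 4, 16);
        // xs must be contiguous in (batch, seqlen, num-heads, head-dim) order.
        let xs = Tensor::randn(0f32, 1., (b, t, h, d), &dev)?;
        // One cos/sin value per (position, rotation pair): shape (seqlen, head-dim / 2).
        let cos = Tensor::randn(0f32, 1., (t, d / 2), &dev)?;
        let sin = Tensor::randn(0f32, 1., (t, d / 2), &dev)?;
        candle_nn::rotary_emb::rope_thd(&xs, &cos, &sin)
    }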
Diffstat (limited to 'candle-nn')
-rw-r--r--  candle-nn/src/rotary_emb.rs  231
-rw-r--r--  candle-nn/tests/ops.rs        31
2 files changed, 262 insertions(+), 0 deletions(-)
diff --git a/candle-nn/src/rotary_emb.rs b/candle-nn/src/rotary_emb.rs
index c2b41482..1084cfb5 100644
--- a/candle-nn/src/rotary_emb.rs
+++ b/candle-nn/src/rotary_emb.rs
@@ -497,3 +497,234 @@ pub fn rope_slow(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
let sin = sin.unsqueeze(0)?.unsqueeze(0)?;
x.broadcast_mul(&cos)? + rotate_half(x)?.broadcast_mul(&sin)?
}
+
+/// T (seqlen)/H (num-heads)/D (head-dim) contiguous variant of rope embeddings.
+#[derive(Debug, Clone)]
+struct RotaryEmbThd;
+
+impl candle::CustomOp3 for RotaryEmbThd {
+ fn name(&self) -> &'static str {
+ "rotary-emb"
+ }
+
+ fn cpu_fwd(
+ &self,
+ s1: &CpuStorage,
+ l1: &Layout,
+ s2: &CpuStorage,
+ l2: &Layout,
+ s3: &CpuStorage,
+ l3: &Layout,
+ ) -> Result<(CpuStorage, Shape)> {
+ fn inner<T: candle::WithDType + num_traits::Float>(
+ src: &[T],
+ l_src: &Layout,
+ cos: &[T],
+ l_cos: &Layout,
+ sin: &[T],
+ l_sin: &Layout,
+ ) -> Result<(CpuStorage, Shape)> {
+ let src = match l_src.contiguous_offsets() {
+ None => candle::bail!("input src has to be contiguous"),
+ Some((o1, o2)) => &src[o1..o2],
+ };
+ let cos = match l_cos.contiguous_offsets() {
+ None => candle::bail!("input cos has to be contiguous"),
+ Some((o1, o2)) => &cos[o1..o2],
+ };
+ let sin = match l_sin.contiguous_offsets() {
+ None => candle::bail!("input sin has to be contiguous"),
+ Some((o1, o2)) => &sin[o1..o2],
+ };
+ let (b, t, h, d) = l_src.shape().dims4()?;
+ let el_count = b * h * t * d;
+ let mut dst = vec![T::zero(); el_count];
+ src.par_chunks(t * h * d)
+ .zip(dst.par_chunks_mut(t * h * d))
+ .for_each(|(src, dst)| {
+ for i_t in 0..t {
+ for i_d in 0..d / 2 {
+ let i_cs = i_t * (d / 2) + i_d;
+ for i_h in 0..h {
+ let i1 = i_t * h * d + i_h * d + i_d;
+ let i2 = i1 + d / 2;
+ dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs];
+ dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs];
+ }
+ }
+ }
+ });
+ let storage = candle::WithDType::to_cpu_storage_owned(dst);
+ Ok((storage, (b, t, h, d).into()))
+ }
+
+ use candle::backend::BackendStorage;
+ use CpuStorage::{BF16, F16, F32, F64};
+ match (s1, s2, s3) {
+ (BF16(s1), BF16(s2), BF16(s3)) => inner(s1, l1, s2, l2, s3, l3),
+ (F16(s1), F16(s2), F16(s3)) => inner(s1, l1, s2, l2, s3, l3),
+ (F32(s1), F32(s2), F32(s3)) => inner(s1, l1, s2, l2, s3, l3),
+ (F64(s1), F64(s2), F64(s3)) => inner(s1, l1, s2, l2, s3, l3),
+ _ => candle::bail!(
+ "unsupported dtype for rope {:?} {:?} {:?}",
+ s1.dtype(),
+ s2.dtype(),
+ s3.dtype()
+ ),
+ }
+ }
+
+ #[cfg(feature = "cuda")]
+ fn cuda_fwd(
+ &self,
+ s1: &candle::CudaStorage,
+ l1: &Layout,
+ s2: &candle::CudaStorage,
+ l2: &Layout,
+ s3: &candle::CudaStorage,
+ l3: &Layout,
+ ) -> Result<(candle::CudaStorage, Shape)> {
+ use candle::cuda_backend::cudarc::driver::{
+ CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
+ };
+ use candle::cuda_backend::{kernel_name, kernels, WrapErr};
+ use candle::{CudaDevice, WithDType};
+
+ fn inner<T: DeviceRepr + WithDType>(
+ src: &CudaSlice<T>,
+ l_src: &Layout,
+ cos: &CudaSlice<T>,
+ l_cos: &Layout,
+ sin: &CudaSlice<T>,
+ l_sin: &Layout,
+ dev: &CudaDevice,
+ ) -> Result<CudaSlice<T>> {
+ let src = match l_src.contiguous_offsets() {
+ None => candle::bail!("src input has to be contiguous"),
+ Some((o1, o2)) => src.slice(o1..o2),
+ };
+ let cos = match l_cos.contiguous_offsets() {
+ None => candle::bail!("cos input has to be contiguous"),
+ Some((o1, o2)) => cos.slice(o1..o2),
+ };
+ let sin = match l_sin.contiguous_offsets() {
+ None => candle::bail!("sin input has to be contiguous"),
+ Some((o1, o2)) => sin.slice(o1..o2),
+ };
+ let (b, t, h, d) = l_src.shape().dims4()?;
+ let el = b * h * t * d;
+ let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
+ let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), kernels::REDUCE)?;
+ // SAFETY: Set later by running the kernel.
+ let dst = unsafe { dev.alloc::<T>(el) }.w()?;
+ let params = (
+ &src, &cos, &sin, &dst, b as u32, t as u32, h as u32, d as u32,
+ );
+ // SAFETY: ffi.
+ unsafe { func.launch(cfg, params) }.w()?;
+ Ok(dst)
+ }
+
+ use candle::backend::BackendStorage;
+ use candle::cuda_backend::CudaStorageSlice::{BF16, F16, F32, F64};
+ let dev = s1.device();
+ let slice = match (&s1.slice, &s2.slice, &s3.slice) {
+ (BF16(s1), BF16(s2), BF16(s3)) => BF16(inner(s1, l1, s2, l2, s3, l3, dev)?),
+ (F16(s1), F16(s2), F16(s3)) => F16(inner(s1, l1, s2, l2, s3, l3, dev)?),
+ (F32(s1), F32(s2), F32(s3)) => F32(inner(s1, l1, s2, l2, s3, l3, dev)?),
+ (F64(s1), F64(s2), F64(s3)) => F64(inner(s1, l1, s2, l2, s3, l3, dev)?),
+ _ => candle::bail!(
+ "unsupported dtype for rope {:?} {:?} {:?}",
+ s1.dtype(),
+ s2.dtype(),
+ s3.dtype()
+ ),
+ };
+ let dst = candle::cuda_backend::CudaStorage {
+ slice,
+ device: dev.clone(),
+ };
+ Ok((dst, l1.shape().clone()))
+ }
+
+ #[cfg(feature = "metal")]
+ fn metal_fwd(
+ &self,
+ src: &candle::MetalStorage,
+ l_src: &Layout,
+ cos: &candle::MetalStorage,
+ l_cos: &Layout,
+ sin: &candle::MetalStorage,
+ l_sin: &Layout,
+ ) -> Result<(candle::MetalStorage, Shape)> {
+ use candle::backend::BackendStorage;
+ let device = src.device();
+ let command_buffer = device.command_buffer()?;
+ let kernels = device.kernels();
+ if cos.dtype() != src.dtype() || sin.dtype() != src.dtype() {
+ candle::bail!(
+ "dtype mismatch in rope {:?} {:?} {:?}",
+ src.dtype(),
+ cos.dtype(),
+ sin.dtype()
+ )
+ }
+ let name = match src.dtype() {
+ candle::DType::F32 => "rope_thd_f32",
+ candle::DType::F16 => "rope_thd_f16",
+ candle::DType::BF16 => "rope_thd_bf16",
+ dtype => candle::bail!("rope_thd is not implemented for {dtype:?}"),
+ };
+ let (b, t, h, d) = l_src.shape().dims4()?;
+ let el = b * h * t * d;
+ let output = device.new_buffer(el, src.dtype(), "rope-thd")?;
+ candle_metal_kernels::call_rope_thd(
+ device.metal_device(),
+ &command_buffer,
+ kernels,
+ name,
+ b,
+ t,
+ h,
+ d,
+ src.buffer(),
+ l_src.start_offset() * src.dtype().size_in_bytes(),
+ cos.buffer(),
+ l_cos.start_offset() * cos.dtype().size_in_bytes(),
+ sin.buffer(),
+ l_sin.start_offset() * sin.dtype().size_in_bytes(),
+ &output,
+ )
+ .map_err(candle::Error::wrap)?;
+ let out = candle::MetalStorage::new(output, device.clone(), el, src.dtype());
+ Ok((out, l_src.shape().clone()))
+ }
+}
+
+pub fn rope_thd(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
+ let (_b_sz, seq_len, _n_head, n_embd) = xs.dims4()?;
+ let (cos_seq_len, cos_n_embd) = cos.dims2()?;
+ let (sin_seq_len, sin_n_embd) = sin.dims2()?;
+ if cos_n_embd * 2 != n_embd
+ || sin_n_embd * 2 != n_embd
+ || seq_len > cos_seq_len
+ || seq_len > sin_seq_len
+ {
+ candle::bail!(
+ "inconsistent last dim size in rope {:?} {:?} {:?}",
+ xs.shape(),
+ cos.shape(),
+ sin.shape()
+ )
+ }
+ if !xs.is_contiguous() {
+ candle::bail!("xs has to be contiguous in rope")
+ }
+ if !cos.is_contiguous() {
+ candle::bail!("cos has to be contiguous in rope")
+ }
+ if !sin.is_contiguous() {
+ candle::bail!("sin has to be contiguous in rope")
+ }
+ xs.apply_op3_no_bwd(cos, sin, &RotaryEmbThd)
+}
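
As exercised by the test added to candle-nn/tests/ops.rs below, rope_thd applies the same rotation as the existing rope_slow helper, only on the (batch, seqlen, num-heads, head-dim) layout rather than (batch, num-heads, seqlen, head-dim). A rough sketch of that equivalence (hypothetical names, assuming xs_bhtd has shape (b, h, t, d) and cos/sin have shape (t, d / 2)):

    // Reorder to the T/H/D-contiguous layout expected by rope_thd.
    let xs_bthd = xs_bhtd.transpose(1, 2)?.contiguous()?;
    let y1 = candle_nn::rotary_emb::rope_thd(&xs_bthd, &cos, &sin)?.transpose(1, 2)?;
    let y2 = candle_nn::rotary_emb::rope_slow(&xs_bhtd, &cos, &sin)?;
    // y1 and y2 should agree: exactly on CPU, within ~1e-4 on CUDA/Metal per the test.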
diff --git a/candle-nn/tests/ops.rs b/candle-nn/tests/ops.rs
index 20a66e75..24a49d06 100644
--- a/candle-nn/tests/ops.rs
+++ b/candle-nn/tests/ops.rs
@@ -140,7 +140,38 @@ fn rope(device: &Device) -> Result<()> {
Ok(())
}
+fn rope_thd(device: &Device) -> Result<()> {
+ use rand::{rngs::StdRng, Rng, SeedableRng};
+
+ let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16);
+ let el_count = b_size * num_head * seq_len * head_dim;
+ let mut rng = StdRng::seed_from_u64(299792458);
+ let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
+ let cos: Vec<f32> = (0..seq_len * head_dim / 2)
+ .map(|_| rng.gen::<f32>())
+ .collect();
+ let sin: Vec<f32> = (0..seq_len * head_dim / 2)
+ .map(|_| rng.gen::<f32>())
+ .collect();
+ let src = Tensor::from_vec(src, (b_size, num_head, seq_len, head_dim), device)?;
+ let cos = Tensor::from_vec(cos, (seq_len, head_dim / 2), device)?;
+ let sin = Tensor::from_vec(sin, (seq_len, head_dim / 2), device)?;
+ let rope1 = {
+ let src = src.transpose(1, 2)?.contiguous()?;
+ candle_nn::rotary_emb::rope_thd(&src, &cos, &sin)?.transpose(1, 2)?
+ };
+ let rope2 = candle_nn::rotary_emb::rope_slow(&src, &cos, &sin)?;
+ let sum_diff = (rope1 - rope2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+ if device.is_cpu() {
+ assert_eq!(sum_diff, 0.);
+ } else {
+ assert!(sum_diff < 1e-4);
+ }
+ Ok(())
+}
+
test_device!(ropei, ropei_cpu, ropei_gpu, ropei_metal);
test_device!(rope, rope_cpu, rope_gpu, rope_metal);
+test_device!(rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal);
test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);
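
Not part of this diff, but for context: the cos/sin tables consumed by rope_thd are typically precomputed from the usual RoPE frequencies, which is why they carry one column per rotation pair, i.e. shape (seqlen, head-dim / 2). A sketch of one common way to build them (the helper name and theta value are illustrative only):

    use candle::{DType, Device, Result, Tensor};

    fn rope_tables(t: usize, d: usize, theta: f32, dev: &Device) -> Result<(Tensor, Tensor)> {
        // inv_freq[i] = 1 / theta^(2 i / d), one entry per rotation pair.
        let inv_freq: Vec<f32> = (0..d / 2)
            .map(|i| 1f32 / theta.powf(2. * i as f32 / d as f32))
            .collect();
        let inv_freq = Tensor::from_vec(inv_freq, (1, d / 2), dev)?;
        let positions = Tensor::arange(0u32, t as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((t, 1))?;
        // freqs[p, i] = p * inv_freq[i]; the tables are its element-wise cos and sin.
        let freqs = positions.broadcast_mul(&inv_freq)?;
        Ok((freqs.cos()?, freqs.sin()?))
    }

Used, for instance, as: let (cos, sin) = rope_tables(seq_len, head_dim, 10_000., &dev)?;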