author:    Laurent Mazare <laurent.mazare@gmail.com>  2024-04-05 08:32:58 +0200
committer: GitHub <noreply@github.com>  2024-04-05 08:32:58 +0200
commit:    2ac302a5d170953a1d2fe850645563fc55d1567f (patch)
tree:      b35b32efe5c8eac25a9b5681fb0778ef84e57d0e /candle-nn
parent:    ace282e5c2ef24ca2fb90683babb852936d4df17 (diff)
download:  candle-2ac302a5d170953a1d2fe850645563fc55d1567f.tar.gz
           candle-2ac302a5d170953a1d2fe850645563fc55d1567f.tar.bz2
           candle-2ac302a5d170953a1d2fe850645563fc55d1567f.zip
Add the rope THD kernel. (#2014)
* Add the rope THD kernel.
* Cuda kernel for rope-thd.
* Add the metal kernels.
* Add a dedicated test.
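The new public entry point is `candle_nn::rotary_emb::rope_thd`, which expects `xs` in the T/H/D layout `(batch, seqlen, num-heads, head-dim)` and `cos`/`sin` tables of shape `(seqlen, head-dim / 2)`, all contiguous. A minimal usage sketch, not part of this commit: the shapes mirror the dedicated test below, and `Tensor::randn` is just a convenient way to build contiguous inputs.

```rust
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // THD layout: (batch, seqlen, num-heads, head-dim), contiguous.
    let (b, t, h, d) = (2, 10, 5, 16);
    let xs = Tensor::randn(0f32, 1., (b, t, h, d), &dev)?;
    // The cos/sin tables hold one angle per (position, rotary-pair):
    // head-dim / 2 pairs for each of the t positions.
    let cos = Tensor::randn(0f32, 1., (t, d / 2), &dev)?;
    let sin = Tensor::randn(0f32, 1., (t, d / 2), &dev)?;
    let ys = candle_nn::rotary_emb::rope_thd(&xs, &cos, &sin)?;
    assert_eq!(ys.dims(), xs.dims()); // rope preserves the input shape
    Ok(())
}
```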
Diffstat (limited to 'candle-nn')
-rw-r--r--  candle-nn/src/rotary_emb.rs | 231
-rw-r--r--  candle-nn/tests/ops.rs      |  31
2 files changed, 262 insertions(+), 0 deletions(-)
diff --git a/candle-nn/src/rotary_emb.rs b/candle-nn/src/rotary_emb.rs
index c2b41482..1084cfb5 100644
--- a/candle-nn/src/rotary_emb.rs
+++ b/candle-nn/src/rotary_emb.rs
@@ -497,3 +497,234 @@ pub fn rope_slow(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
     let sin = sin.unsqueeze(0)?.unsqueeze(0)?;
     x.broadcast_mul(&cos)? + rotate_half(x)?.broadcast_mul(&sin)?
 }
+
+/// T (seqlen)/H (num-heads)/D (head-dim) contiguous variant of rope embeddings.
+#[derive(Debug, Clone)]
+struct RotaryEmbThd;
+
+impl candle::CustomOp3 for RotaryEmbThd {
+    fn name(&self) -> &'static str {
+        "rotary-emb"
+    }
+
+    fn cpu_fwd(
+        &self,
+        s1: &CpuStorage,
+        l1: &Layout,
+        s2: &CpuStorage,
+        l2: &Layout,
+        s3: &CpuStorage,
+        l3: &Layout,
+    ) -> Result<(CpuStorage, Shape)> {
+        fn inner<T: candle::WithDType + num_traits::Float>(
+            src: &[T],
+            l_src: &Layout,
+            cos: &[T],
+            l_cos: &Layout,
+            sin: &[T],
+            l_sin: &Layout,
+        ) -> Result<(CpuStorage, Shape)> {
+            let src = match l_src.contiguous_offsets() {
+                None => candle::bail!("input src has to be contiguous"),
+                Some((o1, o2)) => &src[o1..o2],
+            };
+            let cos = match l_cos.contiguous_offsets() {
+                None => candle::bail!("input cos has to be contiguous"),
+                Some((o1, o2)) => &cos[o1..o2],
+            };
+            let sin = match l_sin.contiguous_offsets() {
+                None => candle::bail!("input sin has to be contiguous"),
+                Some((o1, o2)) => &sin[o1..o2],
+            };
+            let (b, t, h, d) = l_src.shape().dims4()?;
+            let el_count = b * h * t * d;
+            let mut dst = vec![T::zero(); el_count];
+            src.par_chunks(t * h * d)
+                .zip(dst.par_chunks_mut(t * h * d))
+                .for_each(|(src, dst)| {
+                    for i_t in 0..t {
+                        for i_d in 0..d / 2 {
+                            let i_cs = i_t * (d / 2) + i_d;
+                            for i_h in 0..h {
+                                let i1 = i_t * h * d + i_h * d + i_d;
+                                let i2 = i1 + d / 2;
+                                dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs];
+                                dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs];
+                            }
+                        }
+                    }
+                });
+            let storage = candle::WithDType::to_cpu_storage_owned(dst);
+            Ok((storage, (b, t, h, d).into()))
+        }
+
+        use candle::backend::BackendStorage;
+        use CpuStorage::{BF16, F16, F32, F64};
+        match (s1, s2, s3) {
+            (BF16(s1), BF16(s2), BF16(s3)) => inner(s1, l1, s2, l2, s3, l3),
+            (F16(s1), F16(s2), F16(s3)) => inner(s1, l1, s2, l2, s3, l3),
+            (F32(s1), F32(s2), F32(s3)) => inner(s1, l1, s2, l2, s3, l3),
+            (F64(s1), F64(s2), F64(s3)) => inner(s1, l1, s2, l2, s3, l3),
+            _ => candle::bail!(
+                "unsupported dtype for rope {:?} {:?} {:?}",
+                s1.dtype(),
+                s2.dtype(),
+                s3.dtype()
+            ),
+        }
+    }
+
+    #[cfg(feature = "cuda")]
+    fn cuda_fwd(
+        &self,
+        s1: &candle::CudaStorage,
+        l1: &Layout,
+        s2: &candle::CudaStorage,
+        l2: &Layout,
+        s3: &candle::CudaStorage,
+        l3: &Layout,
+    ) -> Result<(candle::CudaStorage, Shape)> {
+        use candle::cuda_backend::cudarc::driver::{
+            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
+        };
+        use candle::cuda_backend::{kernel_name, kernels, WrapErr};
+        use candle::{CudaDevice, WithDType};
+
+        fn inner<T: DeviceRepr + WithDType>(
+            src: &CudaSlice<T>,
+            l_src: &Layout,
+            cos: &CudaSlice<T>,
+            l_cos: &Layout,
+            sin: &CudaSlice<T>,
+            l_sin: &Layout,
+            dev: &CudaDevice,
+        ) -> Result<CudaSlice<T>> {
+            let src = match l_src.contiguous_offsets() {
+                None => candle::bail!("src input has to be contiguous"),
+                Some((o1, o2)) => src.slice(o1..o2),
+            };
+            let cos = match l_cos.contiguous_offsets() {
+                None => candle::bail!("cos input has to be contiguous"),
+                Some((o1, o2)) => cos.slice(o1..o2),
+            };
+            let sin = match l_sin.contiguous_offsets() {
+                None => candle::bail!("sin input has to be contiguous"),
+                Some((o1, o2)) => sin.slice(o1..o2),
+            };
+            let (b, t, h, d) = l_src.shape().dims4()?;
+            let el = b * h * t * d;
+            let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
+            let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), kernels::REDUCE)?;
+            // SAFETY: Set later by running the kernel.
+            let dst = unsafe { dev.alloc::<T>(el) }.w()?;
+            let params = (
+                &src, &cos, &sin, &dst, b as u32, t as u32, h as u32, d as u32,
+            );
+            // SAFETY: ffi.
+            unsafe { func.launch(cfg, params) }.w()?;
+            Ok(dst)
+        }
+
+        use candle::backend::BackendStorage;
+        use candle::cuda_backend::CudaStorageSlice::{BF16, F16, F32, F64};
+        let dev = s1.device();
+        let slice = match (&s1.slice, &s2.slice, &s3.slice) {
+            (BF16(s1), BF16(s2), BF16(s3)) => BF16(inner(s1, l1, s2, l2, s3, l3, dev)?),
+            (F16(s1), F16(s2), F16(s3)) => F16(inner(s1, l1, s2, l2, s3, l3, dev)?),
+            (F32(s1), F32(s2), F32(s3)) => F32(inner(s1, l1, s2, l2, s3, l3, dev)?),
+            (F64(s1), F64(s2), F64(s3)) => F64(inner(s1, l1, s2, l2, s3, l3, dev)?),
+            _ => candle::bail!(
+                "unsupported dtype for rope {:?} {:?} {:?}",
+                s1.dtype(),
+                s2.dtype(),
+                s3.dtype()
+            ),
+        };
+        let dst = candle::cuda_backend::CudaStorage {
+            slice,
+            device: dev.clone(),
+        };
+        Ok((dst, l1.shape().clone()))
+    }
+
+    #[cfg(feature = "metal")]
+    fn metal_fwd(
+        &self,
+        src: &candle::MetalStorage,
+        l_src: &Layout,
+        cos: &candle::MetalStorage,
+        l_cos: &Layout,
+        sin: &candle::MetalStorage,
+        l_sin: &Layout,
+    ) -> Result<(candle::MetalStorage, Shape)> {
+        use candle::backend::BackendStorage;
+        let device = src.device();
+        let command_buffer = device.command_buffer()?;
+        let kernels = device.kernels();
+        if cos.dtype() != src.dtype() || sin.dtype() != src.dtype() {
+            candle::bail!(
+                "dtype mismatch in rope {:?} {:?} {:?}",
+                src.dtype(),
+                cos.dtype(),
+                sin.dtype()
+            )
+        }
+        let name = match src.dtype() {
+            candle::DType::F32 => "rope_thd_f32",
+            candle::DType::F16 => "rope_thd_f16",
+            candle::DType::BF16 => "rope_thd_bf16",
+            dtype => candle::bail!("rope_thd is not implemented for {dtype:?}"),
+        };
+        let (b, t, h, d) = l_src.shape().dims4()?;
+        let el = b * h * t * d;
+        let output = device.new_buffer(el, src.dtype(), "rope-thd")?;
+        candle_metal_kernels::call_rope_thd(
+            device.metal_device(),
+            &command_buffer,
+            kernels,
+            name,
+            b,
+            t,
+            h,
+            d,
+            src.buffer(),
+            l_src.start_offset() * src.dtype().size_in_bytes(),
+            cos.buffer(),
+            l_cos.start_offset() * cos.dtype().size_in_bytes(),
+            sin.buffer(),
+            l_sin.start_offset() * sin.dtype().size_in_bytes(),
+            &output,
+        )
+        .map_err(candle::Error::wrap)?;
+        let out = candle::MetalStorage::new(output, device.clone(), el, src.dtype());
+        Ok((out, l_src.shape().clone()))
+    }
+}
+
+pub fn rope_thd(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
+    let (_b_sz, seq_len, _n_head, n_embd) = xs.dims4()?;
+    let (cos_seq_len, cos_n_embd) = cos.dims2()?;
+    let (sin_seq_len, sin_n_embd) = sin.dims2()?;
+    if cos_n_embd * 2 != n_embd
+        || sin_n_embd * 2 != n_embd
+        || seq_len > cos_seq_len
+        || seq_len > sin_seq_len
+    {
+        candle::bail!(
+            "inconsistent last dim size in rope {:?} {:?} {:?}",
+            xs.shape(),
+            cos.shape(),
+            sin.shape()
+        )
+    }
+    if !xs.is_contiguous() {
+        candle::bail!("xs has to be contiguous in rope")
+    }
+    if !cos.is_contiguous() {
+        candle::bail!("cos has to be contiguous in rope")
+    }
+    if !sin.is_contiguous() {
+        candle::bail!("sin has to be contiguous in rope")
+    }
+    xs.apply_op3_no_bwd(cos, sin, &RotaryEmbThd)
+}
diff --git a/candle-nn/tests/ops.rs b/candle-nn/tests/ops.rs
index 20a66e75..24a49d06 100644
--- a/candle-nn/tests/ops.rs
+++ b/candle-nn/tests/ops.rs
@@ -140,7 +140,38 @@ fn rope(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn rope_thd(device: &Device) -> Result<()> {
+    use rand::{rngs::StdRng, Rng, SeedableRng};
+
+    let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16);
+    let el_count = b_size * num_head * seq_len * head_dim;
+    let mut rng = StdRng::seed_from_u64(299792458);
+    let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
+    let cos: Vec<f32> = (0..seq_len * head_dim / 2)
+        .map(|_| rng.gen::<f32>())
+        .collect();
+    let sin: Vec<f32> = (0..seq_len * head_dim / 2)
+        .map(|_| rng.gen::<f32>())
+        .collect();
+    let src = Tensor::from_vec(src, (b_size, num_head, seq_len, head_dim), device)?;
+    let cos = Tensor::from_vec(cos, (seq_len, head_dim / 2), device)?;
+    let sin = Tensor::from_vec(sin, (seq_len, head_dim / 2), device)?;
+    let rope1 = {
+        let src = src.transpose(1, 2)?.contiguous()?;
+        candle_nn::rotary_emb::rope_thd(&src, &cos, &sin)?.transpose(1, 2)?
+    };
+    let rope2 = candle_nn::rotary_emb::rope_slow(&src, &cos, &sin)?;
+    let sum_diff = (rope1 - rope2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+    if device.is_cpu() {
+        assert_eq!(sum_diff, 0.);
+    } else {
+        assert!(sum_diff < 1e-4);
+    }
+    Ok(())
+}
+
 test_device!(ropei, ropei_cpu, ropei_gpu, ropei_metal);
 test_device!(rope, rope_cpu, rope_gpu, rope_metal);
+test_device!(rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal);
 test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
 test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);
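For readers skimming the kernels: all three backends implement the same pairwise rotation, pairing element `i_d` with `i_d + d/2` inside each head and applying a 2-D rotation by the per-(position, pair) angle stored in the cos/sin tables. A standalone scalar restatement of the CPU inner loop above, as a hypothetical helper for illustration only (single batch slice in (t, h, d) order, f32):

```rust
/// Reference rotation for one batch slice stored contiguously as (t, h, d),
/// mirroring the indexing of the cpu_fwd inner loop in the diff above.
fn rope_thd_ref(src: &[f32], cos: &[f32], sin: &[f32], t: usize, h: usize, d: usize) -> Vec<f32> {
    let mut dst = vec![0f32; t * h * d];
    for i_t in 0..t {
        for i_d in 0..d / 2 {
            // One precomputed angle per (position, rotary-pair).
            let i_cs = i_t * (d / 2) + i_d;
            for i_h in 0..h {
                let i1 = i_t * h * d + i_h * d + i_d;
                let i2 = i1 + d / 2;
                // Rotate the (src[i1], src[i2]) pair.
                dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs];
                dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs];
            }
        }
    }
    dst
}
```

Because the angle index `i_cs` does not depend on `i_h`, all heads at a given position share the same rotation, which is what lets the kernels read the cos/sin tables as a flat (seqlen, head-dim / 2) buffer.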