summaryrefslogtreecommitdiff
path: root/candle-nn
diff options
context:
space:
mode:
author Laurent Mazare <laurent.mazare@gmail.com> 2024-03-24 22:48:52 +0100
committer GitHub <noreply@github.com> 2024-03-24 22:48:52 +0100
commit 1b98f84a2baa23192b97e36131011da658bfa1c2 (patch)
tree 92c4e9e8a263edfc8d3fedeab2cc02271d87d51e /candle-nn
parent cf7d7fcf2f20c24aae633483c3a107c1219a7f9a (diff)
download candle-1b98f84a2baa23192b97e36131011da658bfa1c2.tar.gz
candle-1b98f84a2baa23192b97e36131011da658bfa1c2.tar.bz2
candle-1b98f84a2baa23192b97e36131011da658bfa1c2.zip
Fast kernels for rotary embeddings. (#1928)
* Fast kernels for rotary embeddings.
* Add a test for the fast CPU kernel.
* Rope cuda bindings.
* Cuda kernel.
* Metal kernel (part 1).
* Cuda kernels.
* Finish the metal kernel.
* Use the new kernels in the quantized example.
* Fix warning.
Diffstat (limited to 'candle-nn')
-rw-r--r-- candle-nn/Cargo.toml 1
-rw-r--r-- candle-nn/src/lib.rs 1
-rw-r--r-- candle-nn/src/rotary_emb.rs 247
-rw-r--r-- candle-nn/tests/ops.rs 28
4 files changed, 277 insertions, 0 deletions
diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml
index 214e8a59..3408dae3 100644
--- a/candle-nn/Cargo.toml
+++ b/candle-nn/Cargo.toml
@@ -25,6 +25,7 @@ candle-metal-kernels = { workspace = true, optional = true }
[dev-dependencies]
anyhow = { workspace = true }
clap = { workspace = true }
+rand = { workspace = true }
[features]
default = []
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index 1bcb78d9..5c0fbb37 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -12,6 +12,7 @@ pub mod loss;
pub mod ops;
pub mod optim;
pub mod rnn;
+pub mod rotary_emb;
pub mod sequential;
pub mod var_builder;
pub mod var_map;
diff --git a/candle-nn/src/rotary_emb.rs b/candle-nn/src/rotary_emb.rs
new file mode 100644
index 00000000..20545b8d
--- /dev/null
+++ b/candle-nn/src/rotary_emb.rs
@@ -0,0 +1,247 @@
+use candle::{CpuStorage, Layout, Result, Shape, Tensor, D};
+use rayon::prelude::*;
+
+/// Interleaved variant of rotary embeddings.
+///
+/// The x0 and x1 values are interleaved on the n_embd (= head_dim) dimension, i.e. each
+/// consecutive pair of elements is rotated together.
+/// The resulting y0 and y1 are also interleaved with:
+///
+/// ```text
+/// y0 = x0*cos - x1*sin
+/// y1 = x0*sin + x1*cos
+/// ```
+///
+/// This is a stateless marker type: the per-backend computation lives in its
+/// `candle::CustomOp3` implementation.
+#[derive(Debug, Clone)]
+struct RotaryEmbI;
+
+impl candle::CustomOp3 for RotaryEmbI {
+ /// Returns the static name identifying this custom op.
+ fn name(&self) -> &'static str {
+ "rotary-emb-int"
+ }
+
+ /// CPU forward pass of the interleaved rotary embedding.
+ ///
+ /// `s1`/`l1` are the storage and layout of the source values, `s2`/`l2` those of the
+ /// cosines and `s3`/`l3` those of the sines. The source layout must be rank 4, read as
+ /// `(b, h, t, d)`; the same `cos`/`sin` values are reused for each of the `b * h` chunks
+ /// of `t * d` source elements.
+ ///
+ /// Returns a freshly allocated storage together with the `(b, h, t, d)` shape.
+ ///
+ /// # Errors
+ ///
+ /// Fails if any input is not contiguous, if the source is not rank 4, if `cos` or `sin`
+ /// hold fewer than `t * d / 2` values, or if the three storages do not share one of the
+ /// supported float dtypes (bf16, f16, f32, f64).
+ fn cpu_fwd(
+     &self,
+     s1: &CpuStorage,
+     l1: &Layout,
+     s2: &CpuStorage,
+     l2: &Layout,
+     s3: &CpuStorage,
+     l3: &Layout,
+ ) -> Result<(CpuStorage, Shape)> {
+     /// Dtype-generic kernel: rotates every `(x0, x1)` pair of `src` by the matching
+     /// `cos`/`sin` entry.
+     fn inner<T: candle::WithDType + num_traits::Float>(
+         src: &[T],
+         l_src: &Layout,
+         cos: &[T],
+         l_cos: &Layout,
+         sin: &[T],
+         l_sin: &Layout,
+     ) -> Result<(CpuStorage, Shape)> {
+         let src = match l_src.contiguous_offsets() {
+             None => candle::bail!("input src has to be contiguous"),
+             Some((o1, o2)) => &src[o1..o2],
+         };
+         let cos = match l_cos.contiguous_offsets() {
+             None => candle::bail!("input cos has to be contiguous"),
+             Some((o1, o2)) => &cos[o1..o2],
+         };
+         let sin = match l_sin.contiguous_offsets() {
+             None => candle::bail!("input sin has to be contiguous"),
+             Some((o1, o2)) => &sin[o1..o2],
+         };
+         let (b, h, t, d) = l_src.shape().dims4()?;
+         let el_count = b * h * t * d;
+         let chunk_len = t * d;
+         let n_pairs = chunk_len / 2;
+         // Report a too-short cos/sin as an error rather than panicking on an
+         // out-of-bounds index inside the parallel loop.
+         if cos.len() < n_pairs || sin.len() < n_pairs {
+             candle::bail!(
+                 "cos/sin are too short for rope, expected at least {} elements, got {} and {}",
+                 n_pairs,
+                 cos.len(),
+                 sin.len()
+             )
+         }
+         // Only the first `n_pairs` entries are ever read.
+         let (cos, sin) = (&cos[..n_pairs], &sin[..n_pairs]);
+         let mut dst = vec![T::zero(); el_count];
+         // rayon's `par_chunks` panics on a chunk size of zero; in that case the input is
+         // empty and there is nothing to rotate.
+         if chunk_len > 0 {
+             src.par_chunks(chunk_len)
+                 .zip(dst.par_chunks_mut(chunk_len))
+                 .for_each(|(src, dst)| {
+                     for i_over_2 in 0..n_pairs {
+                         let i = 2 * i_over_2;
+                         dst[i] = src[i] * cos[i_over_2] - src[i + 1] * sin[i_over_2];
+                         dst[i + 1] = src[i] * sin[i_over_2] + src[i + 1] * cos[i_over_2];
+                     }
+                 });
+         }
+         let storage = candle::WithDType::to_cpu_storage_owned(dst);
+         Ok((storage, (b, h, t, d).into()))
+     }
+
+     use candle::backend::BackendStorage;
+     use CpuStorage::{BF16, F16, F32, F64};
+     // All three inputs must share the same float dtype.
+     match (s1, s2, s3) {
+         (BF16(s1), BF16(s2), BF16(s3)) => inner(s1, l1, s2, l2, s3, l3),
+         (F16(s1), F16(s2), F16(s3)) => inner(s1, l1, s2, l2, s3, l3),
+         (F32(s1), F32(s2), F32(s3)) => inner(s1, l1, s2, l2, s3, l3),
+         (F64(s1), F64(s2), F64(s3)) => inner(s1, l1, s2, l2, s3, l3),
+         _ => candle::bail!(
+             "unsupported dtype for rope {:?} {:?} {:?}",
+             s1.dtype(),
+             s2.dtype(),
+             s3.dtype()
+         ),
+     }
+ }
+
+ /// CUDA forward pass of the interleaved rotary embedding.
+ ///
+ /// `s1`/`l1` are the source storage and layout, `s2`/`l2` the cosines and `s3`/`l3` the
+ /// sines. All three must be contiguous and share one of the bf16/f16/f32/f64 dtypes,
+ /// otherwise an error is returned. The result is a newly allocated storage on the same
+ /// device as `s1`, returned with the shape of `l1`.
+ #[cfg(feature = "cuda")]
+ fn cuda_fwd(
+ &self,
+ s1: &candle::CudaStorage,
+ l1: &Layout,
+ s2: &candle::CudaStorage,
+ l2: &Layout,
+ s3: &candle::CudaStorage,
+ l3: &Layout,
+ ) -> Result<(candle::CudaStorage, Shape)> {
+ use candle::cuda_backend::cudarc::driver::{
+ CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
+ };
+ use candle::cuda_backend::{kernel_name, kernels, WrapErr};
+ use candle::{CudaDevice, WithDType};
+
+ // Dtype-generic helper: launches the `rope_i` kernel for element type `T` and returns
+ // the output slice it allocated.
+ fn inner<T: DeviceRepr + WithDType>(
+ src: &CudaSlice<T>,
+ l_src: &Layout,
+ cos: &CudaSlice<T>,
+ l_cos: &Layout,
+ sin: &CudaSlice<T>,
+ l_sin: &Layout,
+ dev: &CudaDevice,
+ ) -> Result<CudaSlice<T>> {
+ // Restrict each input to the range covered by its (contiguous) layout.
+ let src = match l_src.contiguous_offsets() {
+ None => candle::bail!("src input has to be contiguous"),
+ Some((o1, o2)) => src.slice(o1..o2),
+ };
+ let cos = match l_cos.contiguous_offsets() {
+ None => candle::bail!("cos input has to be contiguous"),
+ Some((o1, o2)) => cos.slice(o1..o2),
+ };
+ let sin = match l_sin.contiguous_offsets() {
+ None => candle::bail!("sin input has to be contiguous"),
+ Some((o1, o2)) => sin.slice(o1..o2),
+ };
+ let (b, h, t, d) = l_src.shape().dims4()?;
+ let el = b * h * t * d;
+ // The launch is sized for half the element count; presumably one thread per
+ // (x0, x1) pair - confirm against the kernel source.
+ // NOTE(review): the `as u32` casts here and below truncate silently if the counts
+ // exceed u32::MAX.
+ let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
+ // The `rope_i` kernel for `T` is looked up in the `REDUCE` kernel module.
+ let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), kernels::REDUCE)?;
+ // SAFETY: Set later by running the kernel.
+ let dst = unsafe { dev.alloc::<T>(el) }.w()?;
+ // Kernel arguments: the three inputs, the output, then `b * h` and `t * d`.
+ let params = (&src, &cos, &sin, &dst, (b * h) as u32, (t * d) as u32);
+ // SAFETY: ffi.
+ unsafe { func.launch(cfg, params) }.w()?;
+ Ok(dst)
+ }
+
+ use candle::backend::BackendStorage;
+ use candle::cuda_backend::CudaStorageSlice::{BF16, F16, F32, F64};
+ let dev = s1.device();
+ // Dispatch on the dtype; the three inputs must all hold the same one.
+ let slice = match (&s1.slice, &s2.slice, &s3.slice) {
+ (BF16(s1), BF16(s2), BF16(s3)) => BF16(inner(s1, l1, s2, l2, s3, l3, dev)?),
+ (F16(s1), F16(s2), F16(s3)) => F16(inner(s1, l1, s2, l2, s3, l3, dev)?),
+ (F32(s1), F32(s2), F32(s3)) => F32(inner(s1, l1, s2, l2, s3, l3, dev)?),
+ (F64(s1), F64(s2), F64(s3)) => F64(inner(s1, l1, s2, l2, s3, l3, dev)?),
+ _ => candle::bail!(
+ "unsupported dtype for rope {:?} {:?} {:?}",
+ s1.dtype(),
+ s2.dtype(),
+ s3.dtype()
+ ),
+ };
+ let dst = candle::cuda_backend::CudaStorage {
+ slice,
+ device: dev.clone(),
+ };
+ Ok((dst, l1.shape().clone()))
+ }
+
+ /// Metal forward pass of the interleaved rotary embedding.
+ ///
+ /// `src`, `cos` and `sin` must share the same dtype (f32, f16 or bf16) and `l_src` must
+ /// be rank 4. The `rope_i_*` kernel matching the dtype is encoded on the device command
+ /// buffer and writes into a newly allocated output buffer, which is returned with the
+ /// shape of `l_src`.
+ #[cfg(feature = "metal")]
+ fn metal_fwd(
+     &self,
+     src: &candle::MetalStorage,
+     l_src: &Layout,
+     cos: &candle::MetalStorage,
+     l_cos: &Layout,
+     sin: &candle::MetalStorage,
+     l_sin: &Layout,
+ ) -> Result<(candle::MetalStorage, Shape)> {
+     use candle::backend::BackendStorage;
+
+     // Byte offset of the first element of a layout inside its backing buffer.
+     let offset_in_bytes = |layout: &Layout, storage: &candle::MetalStorage| {
+         layout.start_offset() * storage.dtype().size_in_bytes()
+     };
+
+     let device = src.device();
+     let command_buffer = device.command_buffer()?;
+     let kernels = device.kernels();
+
+     let dtype = src.dtype();
+     let same_dtype = [cos.dtype(), sin.dtype()].iter().all(|dt| *dt == dtype);
+     if !same_dtype {
+         candle::bail!(
+             "dtype mismatch in rope-i {:?} {:?} {:?}",
+             dtype,
+             cos.dtype(),
+             sin.dtype()
+         )
+     }
+     let kernel_name = match dtype {
+         candle::DType::F32 => "rope_i_f32",
+         candle::DType::F16 => "rope_i_f16",
+         candle::DType::BF16 => "rope_i_bf16",
+         dtype => candle::bail!("rope-i is not implemented for {dtype:?}"),
+     };
+
+     let (b_sz, n_head, seq_len, n_embd) = l_src.shape().dims4()?;
+     let el_count = b_sz * n_head * seq_len * n_embd;
+     let output = device.new_buffer(el_count, dtype, "rope-i")?;
+     let encoded = candle_metal_kernels::call_rope_i(
+         device.metal_device(),
+         &command_buffer,
+         kernels,
+         kernel_name,
+         b_sz * n_head,
+         seq_len * n_embd,
+         src.buffer(),
+         offset_in_bytes(l_src, src),
+         cos.buffer(),
+         offset_in_bytes(l_cos, cos),
+         sin.buffer(),
+         offset_in_bytes(l_sin, sin),
+         &output,
+     );
+     encoded.map_err(candle::Error::wrap)?;
+
+     let storage = candle::MetalStorage::new(output, device.clone(), el_count, dtype);
+     Ok((storage, l_src.shape().clone()))
+ }
+}
+
+/// Applies the interleaved rotary embedding to `xs` through the `RotaryEmbI` custom op.
+///
+/// # Arguments
+///
+/// * `xs` - contiguous rank-4 tensor, read as `(b_sz, n_head, seq_len, n_embd)`.
+/// * `cos`, `sin` - contiguous rank-2 tensors whose last dimension is `n_embd / 2` and
+///   whose first dimension is at least `seq_len`.
+///
+/// # Errors
+///
+/// Returns an error if the ranks are wrong, if the `cos`/`sin` shapes are inconsistent
+/// with `xs`, or if any of the three tensors is not contiguous.
+pub fn rope_i(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
+    let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
+    let (cos_seq_len, cos_n_embd) = cos.dims2()?;
+    // Read the dimensions of `sin` itself (not `cos`) so that a mis-shaped `sin` is
+    // rejected here instead of reaching the kernels unchecked.
+    let (sin_seq_len, sin_n_embd) = sin.dims2()?;
+    if cos_n_embd * 2 != n_embd
+        || sin_n_embd * 2 != n_embd
+        || seq_len > cos_seq_len
+        || seq_len > sin_seq_len
+    {
+        candle::bail!(
+            "inconsistent last dim size in rope {:?} {:?} {:?}",
+            xs.shape(),
+            cos.shape(),
+            sin.shape()
+        )
+    }
+    if !xs.is_contiguous() {
+        candle::bail!("xs has to be contiguous in rope")
+    }
+    if !cos.is_contiguous() {
+        candle::bail!("cos has to be contiguous in rope")
+    }
+    if !sin.is_contiguous() {
+        candle::bail!("sin has to be contiguous in rope")
+    }
+    xs.apply_op3_no_bwd(cos, sin, &RotaryEmbI)
+}
+
+/// Interleaved rotary embedding written with regular tensor ops only.
+///
+/// `x` is a rank-4 tensor read as `(b_sz, n_head, seq_len, n_embd)`. The first `seq_len`
+/// rows of `cos` and `sin` are used, each reshaped to `(seq_len, n_embd / 2, 1)` and
+/// broadcast against the `(x0, x1)` pairs of `x`. The output has the same shape as `x`.
+pub fn rope_i_slow(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
+    let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
+    let half = n_embd / 2;
+    let table_shape = (seq_len, half, 1);
+    let bcast_shape = (b_sz, 1, seq_len, half, 1);
+
+    let cos_b = cos
+        .narrow(0, 0, seq_len)?
+        .reshape(table_shape)?
+        .broadcast_as(bcast_shape)?;
+    let sin_b = sin
+        .narrow(0, 0, seq_len)?
+        .reshape(table_shape)?
+        .broadcast_as(bcast_shape)?;
+
+    // Expose the interleaved pairs on a trailing dimension of size 2.
+    let pairs = x.reshape((b_sz, n_head, seq_len, half, 2))?;
+    let x0 = pairs.narrow(D::Minus1, 0, 1)?;
+    let x1 = pairs.narrow(D::Minus1, 1, 1)?;
+
+    let y0 = (x0.broadcast_mul(&cos_b)? - x1.broadcast_mul(&sin_b)?)?;
+    let y1 = (x0.broadcast_mul(&sin_b)? + x1.broadcast_mul(&cos_b)?)?;
+
+    // Re-interleave y0/y1 and merge the trailing pair dimension back into n_embd.
+    Tensor::cat(&[y0, y1], D::Minus1)?.flatten_from(D::Minus2)
+}
diff --git a/candle-nn/tests/ops.rs b/candle-nn/tests/ops.rs
index c1e3031f..af883b85 100644
--- a/candle-nn/tests/ops.rs
+++ b/candle-nn/tests/ops.rs
@@ -86,5 +86,33 @@ fn softmax_numerical_stability() -> Result<()> {
Ok(())
}
+/// Checks that the fast `rope_i` kernel agrees with the reference `rope_i_slow` on
+/// seeded random f32 inputs.
+///
+/// The CPU result must match exactly; every other device is held to a small absolute
+/// tolerance on the summed difference.
+fn rope(device: &Device) -> Result<()> {
+    use rand::{rngs::StdRng, Rng, SeedableRng};
+
+    let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16);
+    let el_count = b_size * num_head * seq_len * head_dim;
+    // Fixed seed so that the test is reproducible.
+    let mut rng = StdRng::seed_from_u64(299792458);
+    let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
+    let cos: Vec<f32> = (0..seq_len * head_dim / 2)
+        .map(|_| rng.gen::<f32>())
+        .collect();
+    let sin: Vec<f32> = (0..seq_len * head_dim / 2)
+        .map(|_| rng.gen::<f32>())
+        .collect();
+    let src = Tensor::from_vec(src, (b_size, num_head, seq_len, head_dim), device)?;
+    let cos = Tensor::from_vec(cos, (seq_len, head_dim / 2), device)?;
+    let sin = Tensor::from_vec(sin, (seq_len, head_dim / 2), device)?;
+    let rope1 = candle_nn::rotary_emb::rope_i(&src, &cos, &sin)?;
+    let rope2 = candle_nn::rotary_emb::rope_i_slow(&src, &cos, &sin)?;
+    let sum_diff = (rope1 - rope2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+    if device.is_cpu() {
+        assert_eq!(sum_diff, 0.);
+    } else {
+        // Apply the tolerance to every non-CPU device so that the Metal variant of this
+        // test asserts something too instead of passing vacuously.
+        assert!(sum_diff < 1e-4);
+    }
+    Ok(())
+}
+
+test_device!(rope, rope_cpu, rope_gpu, rope_metal);
test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);