From 122e334d0cf9c6b56adc2f6f287617141841f636 Mon Sep 17 00:00:00 2001 From: laurent Date: Thu, 29 Jun 2023 09:21:11 +0100 Subject: Simplify the pattern matching logic in the cuda backend. --- candle-core/src/cuda_backend.rs | 157 ++++++++++++++++++---------------------- 1 file changed, 72 insertions(+), 85 deletions(-) (limited to 'candle-core/src/cuda_backend.rs') diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 9d9a5f99..7dfbb468 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -1,7 +1,7 @@ -use crate::{CpuStorage, DType, Layout, Shape}; +use crate::{CpuStorage, DType, Layout, Shape, WithDType}; use candle_kernels as kernels; use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig}; -use cudarc::driver::{CudaFunction, CudaSlice, DeviceSlice, LaunchAsync, LaunchConfig}; +use cudarc::driver::{CudaFunction, CudaSlice, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig}; use half::{bf16, f16}; use std::sync::Arc; @@ -243,6 +243,72 @@ enum CudaStorageSlice { F64(CudaSlice), } +trait Map1 { + fn f( + &self, + src: &CudaSlice, + dev: &CudaDevice, + layout: &Layout, + ) -> Result>; + + fn map(&self, s: &CudaStorageSlice, d: &CudaDevice, l: &Layout) -> Result { + let out = match s { + CudaStorageSlice::U32(s) => CudaStorageSlice::U32(self.f(s, d, l)?), + CudaStorageSlice::BF16(s) => CudaStorageSlice::BF16(self.f(s, d, l)?), + CudaStorageSlice::F16(s) => CudaStorageSlice::F16(self.f(s, d, l)?), + CudaStorageSlice::F32(s) => CudaStorageSlice::F32(self.f(s, d, l)?), + CudaStorageSlice::F64(s) => CudaStorageSlice::F64(self.f(s, d, l)?), + }; + Ok(out) + } +} + +struct Clone; +impl Map1 for Clone { + fn f( + &self, + s: &CudaSlice, + _: &CudaDevice, + _: &Layout, + ) -> Result> { + Ok(s.try_clone()?) + } +} + +struct Affine(f64, f64); + +impl Map1 for Affine { + fn f( + &self, + src: &CudaSlice, + dev: &CudaDevice, + layout: &Layout, + ) -> Result> { + let shape = layout.shape(); + let dims = shape.dims(); + let el = shape.elem_count(); + let cfg = LaunchConfig::for_num_elems(el as u32); + let ds = dev.htod_copy([dims, layout.stride()].concat())?; + let src = &src.slice(layout.start_offset()..); + let kernel_name = format!("affine_{}", T::DTYPE.as_str()); + let func = dev.get_or_load_func(&kernel_name, kernels::AFFINE)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el) }?; + let params = ( + el, + dims.len(), + &ds, + src, + &out, + T::from_f64(self.0), + T::from_f64(self.1), + ); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }?; + Ok(out) + } +} + fn slice_src_and_dst<'a, T>( src: &'a CudaSlice, src_l: &Layout, @@ -332,14 +398,8 @@ fn gemm_config( } impl CudaStorage { - pub fn try_clone(&self) -> Result { - let slice = match &self.slice { - CudaStorageSlice::U32(slice) => CudaStorageSlice::U32(slice.try_clone()?), - CudaStorageSlice::BF16(slice) => CudaStorageSlice::BF16(slice.try_clone()?), - CudaStorageSlice::F16(slice) => CudaStorageSlice::F16(slice.try_clone()?), - CudaStorageSlice::F32(slice) => CudaStorageSlice::F32(slice.try_clone()?), - CudaStorageSlice::F64(slice) => CudaStorageSlice::F64(slice.try_clone()?), - }; + pub fn try_clone(&self, layout: &Layout) -> Result { + let slice = Clone.map(&self.slice, self.device(), layout)?; let device = self.device.clone(); Ok(Self { slice, device }) } @@ -420,81 +480,8 @@ impl CudaStorage { } pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result { - let shape = layout.shape(); - let dims = shape.dims(); - let el_count = shape.elem_count(); - let cfg = LaunchConfig::for_num_elems(el_count as u32); - let dev = self.device(); - let ds = dev.htod_copy([dims, layout.stride()].concat())?; - let slice = match &self.slice { - CudaStorageSlice::U32(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func("affine_u32", kernels::AFFINE)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el_count) }?; - let params = (el_count, dims.len(), &ds, arg, &out, mul as u32, add as u32); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::U32(out) - } - CudaStorageSlice::BF16(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func("affine_bf16", kernels::AFFINE)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el_count) }?; - let params = ( - el_count, - dims.len(), - &ds, - arg, - &out, - bf16::from_f64(mul), - bf16::from_f64(add), - ); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::BF16(out) - } - CudaStorageSlice::F16(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func("affine_f16", kernels::AFFINE)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el_count) }?; - let params = ( - el_count, - dims.len(), - &ds, - arg, - &out, - f16::from_f64(mul), - f16::from_f64(add), - ); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F16(out) - } - CudaStorageSlice::F32(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el_count) }?; - let params = (el_count, dims.len(), &ds, arg, &out, mul as f32, add as f32); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F32(out) - } - CudaStorageSlice::F64(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func("affine_f64", kernels::AFFINE)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el_count) }?; - let params = (el_count, dims.len(), &ds, arg, &out, mul, add); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F64(out) - } - }; - let device = dev.clone(); + let device = self.device().clone(); + let slice = Affine(mul, add).map(&self.slice, &device, layout)?; Ok(Self { slice, device }) } -- cgit v1.2.3 From d3c7b0d16812a7cb7d6266c42e7a57857dfccb86 Mon Sep 17 00:00:00 2001 From: laurent Date: Thu, 29 Jun 2023 09:27:07 +0100 Subject: Use Map1 for sum. --- candle-core/src/cuda_backend.rs | 113 +++++++++++++++------------------------- 1 file changed, 43 insertions(+), 70 deletions(-) (limited to 'candle-core/src/cuda_backend.rs') diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 7dfbb468..94abd37a 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -1,7 +1,9 @@ use crate::{CpuStorage, DType, Layout, Shape, WithDType}; use candle_kernels as kernels; use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig}; -use cudarc::driver::{CudaFunction, CudaSlice, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig}; +use cudarc::driver::{ + CudaFunction, CudaSlice, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits, +}; use half::{bf16, f16}; use std::sync::Arc; @@ -244,7 +246,7 @@ enum CudaStorageSlice { } trait Map1 { - fn f( + fn f( &self, src: &CudaSlice, dev: &CudaDevice, @@ -276,7 +278,6 @@ impl Map1 for Clone { } struct Affine(f64, f64); - impl Map1 for Affine { fn f( &self, @@ -309,6 +310,43 @@ impl Map1 for Affine { } } +struct Sum<'a>(&'a [usize]); +impl<'a> Map1 for Sum<'a> { + fn f( + &self, + src: &CudaSlice, + dev: &CudaDevice, + layout: &Layout, + ) -> Result> { + let shape = layout.shape(); + let src_dims = shape.dims(); + let el = shape.elem_count(); + let mut dst_el = el; + for &sum_dim in self.0.iter() { + dst_el /= src_dims[sum_dim]; + } + let mut sum_dims = self.0.to_vec(); + // Sort the sum_dims as they have to be processed from left to right when converting the + // indexes. + sum_dims.sort(); + let sum_dims_l: Vec = sum_dims.iter().map(|&d| src_dims[d]).collect(); + let sum_dims_s: Vec = sum_dims + .iter() + .map(|&d| src_dims[d + 1..].iter().product::()) + .collect(); + let cfg = LaunchConfig::for_num_elems(el as u32); + let ds = dev.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())?; + let src = &src.slice(layout.start_offset()..); + let kernel_name = format!("sum_{}", T::DTYPE.as_str()); + let func = dev.get_or_load_func(&kernel_name, kernels::REDUCE)?; + let out = dev.alloc_zeros::(dst_el)?; + let params = (el, src_dims.len(), sum_dims.len(), &ds, src, &out); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }?; + Ok(out) + } +} + fn slice_src_and_dst<'a, T>( src: &'a CudaSlice, src_l: &Layout, @@ -486,73 +524,8 @@ impl CudaStorage { } pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result { - let shape = layout.shape(); - let src_dims = shape.dims(); - let el = shape.elem_count(); - let mut dst_el = el; - for &sum_dim in sum_dims.iter() { - dst_el /= src_dims[sum_dim]; - } - let mut sum_dims = sum_dims.to_vec(); - // Sort the sum_dims as they have to be processed from left to right when converting the - // indexes. - sum_dims.sort(); - let sum_dims_l: Vec = sum_dims.iter().map(|&d| src_dims[d]).collect(); - let sum_dims_s: Vec = sum_dims - .iter() - .map(|&d| src_dims[d + 1..].iter().product::()) - .collect(); - let cfg = LaunchConfig::for_num_elems(el as u32); - let dev = self.device(); - let ds = dev.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())?; - let slice = match &self.slice { - CudaStorageSlice::U32(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func("sum_u32", kernels::REDUCE)?; - let out = dev.alloc_zeros::(dst_el)?; - let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::U32(out) - } - CudaStorageSlice::BF16(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func("sum_bf16", kernels::REDUCE)?; - let out = dev.alloc_zeros::(dst_el)?; - let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::BF16(out) - } - CudaStorageSlice::F16(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func("sum_f16", kernels::REDUCE)?; - let out = dev.alloc_zeros::(dst_el)?; - let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F16(out) - } - CudaStorageSlice::F32(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func("sum_f32", kernels::REDUCE)?; - let out = dev.alloc_zeros::(dst_el)?; - let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F32(out) - } - CudaStorageSlice::F64(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func("sum_f64", kernels::REDUCE)?; - let out = dev.alloc_zeros::(dst_el)?; - let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F64(out) - } - }; - let device = dev.clone(); + let device = self.device().clone(); + let slice = Sum(sum_dims).map(&self.slice, &device, layout)?; Ok(Self { slice, device }) } -- cgit v1.2.3 From fff13dbb4e3951da410ef5f1251f252af74f4d0c Mon Sep 17 00:00:00 2001 From: laurent Date: Thu, 29 Jun 2023 09:29:59 +0100 Subject: Factorize the kernel naming scheme. --- candle-core/src/cuda_backend.rs | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) (limited to 'candle-core/src/cuda_backend.rs') diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 94abd37a..6add6eb7 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -277,6 +277,11 @@ impl Map1 for Clone { } } +fn kernel_name(root: &str) -> String { + let dtype = T::DTYPE.as_str(); + format!("{root}_{dtype}") +} + struct Affine(f64, f64); impl Map1 for Affine { fn f( @@ -291,8 +296,7 @@ impl Map1 for Affine { let cfg = LaunchConfig::for_num_elems(el as u32); let ds = dev.htod_copy([dims, layout.stride()].concat())?; let src = &src.slice(layout.start_offset()..); - let kernel_name = format!("affine_{}", T::DTYPE.as_str()); - let func = dev.get_or_load_func(&kernel_name, kernels::AFFINE)?; + let func = dev.get_or_load_func(&kernel_name::("affine"), kernels::AFFINE)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el) }?; let params = ( @@ -337,8 +341,7 @@ impl<'a> Map1 for Sum<'a> { let cfg = LaunchConfig::for_num_elems(el as u32); let ds = dev.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())?; let src = &src.slice(layout.start_offset()..); - let kernel_name = format!("sum_{}", T::DTYPE.as_str()); - let func = dev.get_or_load_func(&kernel_name, kernels::REDUCE)?; + let func = dev.get_or_load_func(&kernel_name::("sum"), kernels::REDUCE)?; let out = dev.alloc_zeros::(dst_el)?; let params = (el, src_dims.len(), sum_dims.len(), &ds, src, &out); // SAFETY: ffi. -- cgit v1.2.3 From 8ad03a5fb674973064c4a2679140344e44f6d737 Mon Sep 17 00:00:00 2001 From: laurent Date: Thu, 29 Jun 2023 09:37:38 +0100 Subject: Use Map1 on unary ops. --- candle-core/src/cuda_backend.rs | 77 +++++++++++++---------------------------- candle-core/src/op.rs | 28 +++++---------- 2 files changed, 33 insertions(+), 72 deletions(-) (limited to 'candle-core/src/cuda_backend.rs') diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 6add6eb7..dc0d51bf 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -350,6 +350,29 @@ impl<'a> Map1 for Sum<'a> { } } +impl Map1 for U { + fn f( + &self, + src: &CudaSlice, + dev: &CudaDevice, + layout: &Layout, + ) -> Result> { + let shape = layout.shape(); + let dims = shape.dims(); + let el_count = shape.elem_count(); + let cfg = LaunchConfig::for_num_elems(el_count as u32); + let ds = dev.htod_copy([dims, layout.stride()].concat())?; + let src = &src.slice(layout.start_offset()..); + let func = dev.get_or_load_func(&kernel_name::(U::KERNEL), kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el_count) }?; + let params = (el_count, dims.len(), &ds, src, &out); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }?; + Ok(out) + } +} + fn slice_src_and_dst<'a, T>( src: &'a CudaSlice, src_l: &Layout, @@ -539,58 +562,8 @@ impl CudaStorage { } pub(crate) fn unary_impl(&self, layout: &Layout) -> Result { - let shape = layout.shape(); - let dims = shape.dims(); - let el_count = shape.elem_count(); - let cfg = LaunchConfig::for_num_elems(el_count as u32); - let dev = &self.device; - let ds = dev.htod_copy([dims, layout.stride()].concat())?; - let slice = match &self.slice { - CudaStorageSlice::U32(_arg) => { - todo!("No unary kernels for u32"); - } - CudaStorageSlice::BF16(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func(U::KERNEL_BF16, kernels::UNARY)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el_count) }?; - let params = (el_count, dims.len(), &ds, arg, &out); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::BF16(out) - } - CudaStorageSlice::F16(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func(U::KERNEL_F16, kernels::UNARY)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el_count) }?; - let params = (el_count, dims.len(), &ds, arg, &out); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F16(out) - } - CudaStorageSlice::F32(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el_count) }?; - let params = (el_count, dims.len(), &ds, arg, &out); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F32(out) - } - CudaStorageSlice::F64(arg) => { - let arg = &arg.slice(layout.start_offset()..); - let func = dev.get_or_load_func(U::KERNEL_F64, kernels::UNARY)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el_count) }?; - let params = (el_count, dims.len(), &ds, arg, &out); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F64(out) - } - }; - let device = dev.clone(); + let device = self.device().clone(); + let slice = U::V.map(&self.slice, &device, layout)?; Ok(Self { slice, device }) } diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 7b0e18fe..bbbd4bac 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -43,11 +43,8 @@ pub(crate) enum Op { pub(crate) trait UnaryOp { const NAME: &'static str; - const KERNEL_BF16: &'static str; - const KERNEL_F16: &'static str; - const KERNEL_F32: &'static str; - const KERNEL_F64: &'static str; - const KERNEL_U32: &'static str; + const KERNEL: &'static str; + const V: Self; fn bf16(v1: bf16) -> bf16; fn f16(v1: f16) -> f16; fn f32(v1: f32) -> f32; @@ -121,11 +118,8 @@ macro_rules! unary_op { ($op: ident, $name: literal, $a: ident, $e: expr) => { impl UnaryOp for $op { const NAME: &'static str = $name; - const KERNEL_BF16: &'static str = concat!("u", $name, "_bf16"); - const KERNEL_F16: &'static str = concat!("u", $name, "_f16"); - const KERNEL_F32: &'static str = concat!("u", $name, "_f32"); - const KERNEL_F64: &'static str = concat!("u", $name, "_f64"); - const KERNEL_U32: &'static str = concat!("u", $name, "_u32"); + const KERNEL: &'static str = concat!("u", $name); + const V: Self = $op; fn bf16($a: bf16) -> bf16 { $e } @@ -158,6 +152,7 @@ unary_op!(Sqrt, "sqrt", v, v.sqrt()); /// impl UnaryOp for Gelu { const NAME: &'static str = "gelu"; + const V: Self = Gelu; fn bf16(v: bf16) -> bf16 { bf16::from_f32_const(0.5) * v @@ -191,20 +186,13 @@ impl UnaryOp for Gelu { fn u32(_: u32) -> u32 { 0 } - const KERNEL_BF16: &'static str = "ugelu_bf16"; - const KERNEL_F16: &'static str = "ugelu_f16"; - const KERNEL_F32: &'static str = "ugelu_f32"; - const KERNEL_F64: &'static str = "ugelu_f64"; - const KERNEL_U32: &'static str = "ugelu_u32"; + const KERNEL: &'static str = "ugelu"; } impl UnaryOp for Relu { const NAME: &'static str = "relu"; - const KERNEL_BF16: &'static str = "urelu_bf16"; - const KERNEL_F16: &'static str = "urelu_f16"; - const KERNEL_F32: &'static str = "urelu_f32"; - const KERNEL_F64: &'static str = "urelu_f64"; - const KERNEL_U32: &'static str = "urelu_u32"; + const KERNEL: &'static str = "urelu"; + const V: Self = Relu; fn bf16(v: bf16) -> bf16 { v.max(bf16::ZERO) } -- cgit v1.2.3 From 367170da4527101c7e3aae8fbd7f0551fcddf5d0 Mon Sep 17 00:00:00 2001 From: laurent Date: Thu, 29 Jun 2023 09:45:27 +0100 Subject: Also use Map1 for embedding. --- candle-core/src/cuda_backend.rs | 113 ++++++++++++++-------------------------- 1 file changed, 40 insertions(+), 73 deletions(-) (limited to 'candle-core/src/cuda_backend.rs') diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index dc0d51bf..7d06dd72 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -373,6 +373,44 @@ impl Map1 for U { } } +struct Embedding<'a>(&'a CudaStorage, &'a Layout); +impl<'a> Map1 for Embedding<'a> { + fn f( + &self, + rhs: &CudaSlice, + dev: &CudaDevice, + rhs_l: &Layout, + ) -> Result> { + let ids_l = &self.1; + let ids = match &self.0.slice { + CudaStorageSlice::U32(slice) => slice.slice(ids_l.start_offset()..), + _ => Err(CudaError::UnexpectedDType { + msg: "embedding ids should be u32", + expected: DType::U32, + got: self.0.dtype(), + })?, + }; + let ids = &ids; + let shape = ids_l.shape(); + let (v_size, h_size) = rhs_l + .shape() + .r2() + .map_err(|e| CudaError::WrappedError(Box::new(e)))?; + let dims = shape.dims(); + let el = shape.elem_count(); + let cfg = LaunchConfig::for_num_elems(el as u32); + let ds = dev.htod_copy([dims, ids_l.stride()].concat())?; + let rhs = &rhs.slice(rhs_l.start_offset()..); + let func = dev.get_or_load_func(&kernel_name::("emb"), kernels::EMBEDDINGS)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el * h_size) }?; + let params = (el, dims.len(), &ds, ids, rhs, &out, h_size, v_size); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }?; + Ok(out) + } +} + fn slice_src_and_dst<'a, T>( src: &'a CudaSlice, src_l: &Layout, @@ -760,79 +798,8 @@ impl CudaStorage { } pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result { - let ids = match &self.slice { - CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..), - _ => Err(CudaError::UnexpectedDType { - msg: "embedding ids should be u32", - expected: DType::U32, - got: self.dtype(), - })?, - }; - let ids = &ids; - let shape = layout.shape(); - let (v_size, h_size) = rhs_l - .shape() - .r2() - .map_err(|e| CudaError::WrappedError(Box::new(e)))?; - let dims = shape.dims(); - let el = shape.elem_count(); - let cfg = LaunchConfig::for_num_elems(el as u32); - let dev = self.device(); - let ds = dev.htod_copy([dims, layout.stride()].concat())?; - let slice = match &rhs.slice { - // The kernels below assume that rhs is contiguous. - CudaStorageSlice::U32(arg) => { - let arg = &arg.slice(rhs_l.start_offset()..); - let func = dev.get_or_load_func("emb_u32", kernels::EMBEDDINGS)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el * h_size) }?; - let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::U32(out) - } - CudaStorageSlice::BF16(arg) => { - let arg = &arg.slice(rhs_l.start_offset()..); - let func = dev.get_or_load_func("emb_bf16", kernels::EMBEDDINGS)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el * h_size) }?; - let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::BF16(out) - } - CudaStorageSlice::F16(arg) => { - let arg = &arg.slice(rhs_l.start_offset()..); - let func = dev.get_or_load_func("emb_f16", kernels::EMBEDDINGS)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el * h_size) }?; - let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F16(out) - } - CudaStorageSlice::F32(arg) => { - let arg = &arg.slice(rhs_l.start_offset()..); - let func = dev.get_or_load_func("emb_f32", kernels::EMBEDDINGS)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el * h_size) }?; - let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F32(out) - } - CudaStorageSlice::F64(arg) => { - let arg = &arg.slice(rhs_l.start_offset()..); - let func = dev.get_or_load_func("emb_f64", kernels::EMBEDDINGS)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el * h_size) }?; - let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size); - // SAFETY: ffi. - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F64(out) - } - }; - let device = dev.clone(); + let device = self.device().clone(); + let slice = Embedding(self, layout).map(&rhs.slice, &device, rhs_l)?; Ok(Self { slice, device }) } -- cgit v1.2.3 From 83c7d660ca6268cbce4e573ec7d54a05f206b8f1 Mon Sep 17 00:00:00 2001 From: laurent Date: Thu, 29 Jun 2023 10:05:06 +0100 Subject: Add Map2. --- candle-core/src/cuda_backend.rs | 156 +++++++++++++++++++--------------------- 1 file changed, 72 insertions(+), 84 deletions(-) (limited to 'candle-core/src/cuda_backend.rs') diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 7d06dd72..0e9c11c8 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -244,6 +244,7 @@ enum CudaStorageSlice { F32(CudaSlice), F64(CudaSlice), } +type S = CudaStorageSlice; trait Map1 { fn f( @@ -253,13 +254,36 @@ trait Map1 { layout: &Layout, ) -> Result>; - fn map(&self, s: &CudaStorageSlice, d: &CudaDevice, l: &Layout) -> Result { + fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result { let out = match s { - CudaStorageSlice::U32(s) => CudaStorageSlice::U32(self.f(s, d, l)?), - CudaStorageSlice::BF16(s) => CudaStorageSlice::BF16(self.f(s, d, l)?), - CudaStorageSlice::F16(s) => CudaStorageSlice::F16(self.f(s, d, l)?), - CudaStorageSlice::F32(s) => CudaStorageSlice::F32(self.f(s, d, l)?), - CudaStorageSlice::F64(s) => CudaStorageSlice::F64(self.f(s, d, l)?), + S::U32(s) => S::U32(self.f(s, d, l)?), + S::BF16(s) => S::BF16(self.f(s, d, l)?), + S::F16(s) => S::F16(self.f(s, d, l)?), + S::F32(s) => S::F32(self.f(s, d, l)?), + S::F64(s) => S::F64(self.f(s, d, l)?), + }; + Ok(out) + } +} + +trait Map2 { + fn f( + &self, + src1: &CudaSlice, + layout1: &Layout, + src2: &CudaSlice, + layout2: &Layout, + dev: &CudaDevice, + ) -> Result>; + + fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result { + let out = match (s1, s2) { + (S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?), + (S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?), + (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?), + (S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?), + (S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?), + _ => return Err(CudaError::InternalError("dtype mismatch in binary op")), }; Ok(out) } @@ -411,6 +435,44 @@ impl<'a> Map1 for Embedding<'a> { } } +struct WhereCond<'a>(&'a CudaStorage, &'a Layout); +impl<'a> Map2 for WhereCond<'a> { + fn f( + &self, + t: &CudaSlice, + layout_t: &Layout, + f: &CudaSlice, + layout_f: &Layout, + dev: &CudaDevice, + ) -> Result> { + let ids_l = &self.1; + let ids = match &self.0.slice { + CudaStorageSlice::U32(slice) => slice.slice(ids_l.start_offset()..), + _ => Err(CudaError::UnexpectedDType { + msg: "where conditions should be u32", + expected: DType::U32, + got: self.0.dtype(), + })?, + }; + let ids = &ids; + let shape = ids_l.shape(); + let dims = shape.dims(); + let el = shape.elem_count(); + let cfg = LaunchConfig::for_num_elems(el as u32); + let ds = + dev.htod_copy([dims, ids_l.stride(), layout_t.stride(), layout_f.stride()].concat())?; + let t = &t.slice(layout_t.start_offset()..); + let f = &f.slice(layout_f.start_offset()..); + let func = dev.get_or_load_func(&kernel_name::("where"), kernels::TERNARY)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el) }?; + let params = (el, dims.len(), &ds, ids, t, f, &out); + // SAFETY: ffi + unsafe { func.launch(cfg, params) }?; + Ok(out) + } +} + fn slice_src_and_dst<'a, T>( src: &'a CudaSlice, src_l: &Layout, @@ -714,86 +776,12 @@ impl CudaStorage { &self, layout: &Layout, t: &Self, - layout_t: &Layout, + t_l: &Layout, f: &Self, - layout_f: &Layout, + f_l: &Layout, ) -> Result { - let ids = match &self.slice { - CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..), - _ => Err(CudaError::UnexpectedDType { - msg: "where conditions should be u32", - expected: DType::U32, - got: self.dtype(), - })?, - }; - let ids = &ids; - let shape = layout.shape(); - let dims = shape.dims(); - let el = shape.elem_count(); - let cfg = LaunchConfig::for_num_elems(el as u32); - let dev = self.device(); - let ds = - dev.htod_copy([dims, layout.stride(), layout_t.stride(), layout_f.stride()].concat())?; - let slice = match (&t.slice, &f.slice) { - (CudaStorageSlice::BF16(t), CudaStorageSlice::BF16(f)) => { - let t = &t.slice(layout_t.start_offset()..); - let f = &f.slice(layout_f.start_offset()..); - let func = dev.get_or_load_func("where_bf16", kernels::TERNARY)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el) }?; - let params = (el, dims.len(), &ds, ids, t, f, &out); - // SAFETY: ffi - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::BF16(out) - } - (CudaStorageSlice::F16(t), CudaStorageSlice::F16(f)) => { - let t = &t.slice(layout_t.start_offset()..); - let f = &f.slice(layout_f.start_offset()..); - let func = dev.get_or_load_func("where_f16", kernels::TERNARY)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el) }?; - let params = (el, dims.len(), &ds, ids, t, f, &out); - // SAFETY: ffi - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F16(out) - } - (CudaStorageSlice::F32(t), CudaStorageSlice::F32(f)) => { - let t = &t.slice(layout_t.start_offset()..); - let f = &f.slice(layout_f.start_offset()..); - let func = dev.get_or_load_func("where_f32", kernels::TERNARY)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(el) }?; - let params = (el, dims.len(), &ds, ids, t, f, &out); - // SAFETY: ffi - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F32(out) - } - (CudaStorageSlice::F64(t), CudaStorageSlice::F64(f)) => { - let t = &t.slice(layout_t.start_offset()..); - let f = &f.slice(layout_f.start_offset()..); - // SAFETY: Set later by running the kernel. - let func = dev.get_or_load_func("where_f64", kernels::TERNARY)?; - let out = unsafe { dev.alloc::(el) }?; - let params = (el, dims.len(), &ds, ids, t, f, &out); - // SAFETY: ffi - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F64(out) - } - (CudaStorageSlice::U32(t), CudaStorageSlice::U32(f)) => { - let t = &t.slice(layout_t.start_offset()..); - let f = &f.slice(layout_f.start_offset()..); - // SAFETY: Set later by running the kernel. - let func = dev.get_or_load_func("where_u32", kernels::TERNARY)?; - let out = unsafe { dev.alloc::(el) }?; - let params = (el, dims.len(), &ds, ids, t, f, &out); - // SAFETY: ffi - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::U32(out) - } - // The dtypes should have been checked at this point so this is an internal error. - _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?, - }; - let device = dev.clone(); + let device = self.device().clone(); + let slice = WhereCond(self, layout).map(&t.slice, t_l, &f.slice, f_l, &device)?; Ok(Self { slice, device }) } -- cgit v1.2.3 From c9c468e1aaf0ce071b145f15aba830e9600fd6e6 Mon Sep 17 00:00:00 2001 From: laurent Date: Thu, 29 Jun 2023 10:09:15 +0100 Subject: Use Map2 for binary ops. --- candle-core/src/cuda_backend.rs | 94 ++++++++++++----------------------------- candle-core/src/op.rs | 14 ++---- 2 files changed, 32 insertions(+), 76 deletions(-) (limited to 'candle-core/src/cuda_backend.rs') diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 0e9c11c8..40b7e67f 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -473,6 +473,32 @@ impl<'a> Map2 for WhereCond<'a> { } } +impl Map2 for U { + fn f( + &self, + lhs: &CudaSlice, + lhs_l: &Layout, + rhs: &CudaSlice, + rhs_l: &Layout, + dev: &CudaDevice, + ) -> Result> { + let shape = lhs_l.shape(); + let dims = shape.dims(); + let elem_count = shape.elem_count(); + let cfg = LaunchConfig::for_num_elems(elem_count as u32); + let dims_and_strides = dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())?; + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let func = dev.get_or_load_func(&kernel_name::(U::KERNEL), kernels::BINARY)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(elem_count) }?; + let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out); + // SAFETY: ffi + unsafe { func.launch(cfg, params) }?; + Ok(out) + } +} + fn slice_src_and_dst<'a, T>( src: &'a CudaSlice, src_l: &Layout, @@ -673,72 +699,8 @@ impl CudaStorage { lhs_l: &Layout, rhs_l: &Layout, ) -> Result { - let shape = lhs_l.shape(); - let dims = shape.dims(); - let elem_count = shape.elem_count(); - let cfg = LaunchConfig::for_num_elems(elem_count as u32); - let dev = self.device(); - let dims_and_strides = dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())?; - let slice = match (&self.slice, &rhs.slice) { - (CudaStorageSlice::BF16(lhs), CudaStorageSlice::BF16(rhs)) => { - let lhs = &lhs.slice(lhs_l.start_offset()..); - let rhs = &rhs.slice(rhs_l.start_offset()..); - let func = dev.get_or_load_func(B::KERNEL_BF16, kernels::BINARY)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(elem_count) }?; - let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out); - // SAFETY: ffi - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::BF16(out) - } - (CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => { - let lhs = &lhs.slice(lhs_l.start_offset()..); - let rhs = &rhs.slice(rhs_l.start_offset()..); - let func = dev.get_or_load_func(B::KERNEL_F16, kernels::BINARY)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(elem_count) }?; - let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out); - // SAFETY: ffi - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F16(out) - } - (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => { - let lhs = &lhs.slice(lhs_l.start_offset()..); - let rhs = &rhs.slice(rhs_l.start_offset()..); - let func = dev.get_or_load_func(B::KERNEL_F32, kernels::BINARY)?; - // SAFETY: Set later by running the kernel. - let out = unsafe { dev.alloc::(elem_count) }?; - let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out); - // SAFETY: ffi - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F32(out) - } - (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => { - let lhs = &lhs.slice(lhs_l.start_offset()..); - let rhs = &rhs.slice(rhs_l.start_offset()..); - // SAFETY: Set later by running the kernel. - let func = dev.get_or_load_func(B::KERNEL_F64, kernels::BINARY)?; - let out = unsafe { dev.alloc::(elem_count) }?; - let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out); - // SAFETY: ffi - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::F64(out) - } - (CudaStorageSlice::U32(lhs), CudaStorageSlice::U32(rhs)) => { - let lhs = &lhs.slice(lhs_l.start_offset()..); - let rhs = &rhs.slice(rhs_l.start_offset()..); - // SAFETY: Set later by running the kernel. - let func = dev.get_or_load_func(B::KERNEL_U32, kernels::BINARY)?; - let out = unsafe { dev.alloc::(elem_count) }?; - let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out); - // SAFETY: ffi - unsafe { func.launch(cfg, params) }?; - CudaStorageSlice::U32(out) - } - // The dtypes should have been checked at this point so this is an internal error. - _ => return Err(CudaError::InternalError("dtype mismatch in binary op")), - }; - let device = dev.clone(); + let device = self.device().clone(); + let slice = B::V.map(&self.slice, lhs_l, &rhs.slice, rhs_l, &device)?; Ok(Self { slice, device }) } diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index bbbd4bac..db6ef87f 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -54,11 +54,8 @@ pub(crate) trait UnaryOp { pub(crate) trait BinaryOp { const NAME: &'static str; - const KERNEL_BF16: &'static str; - const KERNEL_F16: &'static str; - const KERNEL_F32: &'static str; - const KERNEL_F64: &'static str; - const KERNEL_U32: &'static str; + const KERNEL: &'static str; + const V: Self; fn bf16(v1: bf16, v2: bf16) -> bf16; fn f16(v1: f16, v2: f16) -> f16; fn f32(v1: f32, v2: f32) -> f32; @@ -85,11 +82,8 @@ macro_rules! bin_op { ($op:ident, $name: literal, $e: expr) => { impl BinaryOp for $op { const NAME: &'static str = $name; - const KERNEL_BF16: &'static str = concat!("b", $name, "_bf16"); - const KERNEL_F16: &'static str = concat!("b", $name, "_f16"); - const KERNEL_F32: &'static str = concat!("b", $name, "_f32"); - const KERNEL_F64: &'static str = concat!("b", $name, "_f64"); - const KERNEL_U32: &'static str = concat!("b", $name, "_u32"); + const KERNEL: &'static str = concat!("b", $name); + const V: Self = $op; fn bf16(v1: bf16, v2: bf16) -> bf16 { $e(v1, v2) } -- cgit v1.2.3