summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/examples/llama/main.rs2
-rw-r--r--candle-core/src/cuda_backend.rs695
-rw-r--r--candle-core/src/dummy_cuda_backend.rs2
-rw-r--r--candle-core/src/op.rs42
-rw-r--r--candle-core/src/storage.rs4
-rw-r--r--candle-core/src/tensor.rs2
6 files changed, 292 insertions, 455 deletions
diff --git a/candle-core/examples/llama/main.rs b/candle-core/examples/llama/main.rs
index eb681f4b..3fc893e3 100644
--- a/candle-core/examples/llama/main.rs
+++ b/candle-core/examples/llama/main.rs
@@ -487,6 +487,7 @@ fn main() -> Result<()> {
let mut rng = thread_rng();
let start_gen = std::time::Instant::now();
for index in 0..args.sample_len {
+ let start_gen = std::time::Instant::now();
let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
let input = Tensor::new(ctxt, &device)?;
let logits = llama.forward(&input, &freqs_cis)?;
@@ -496,6 +497,7 @@ fn main() -> Result<()> {
let next_token = distr.sample(&mut rng) as u32;
tokens.push(next_token);
new_tokens.push(next_token);
+ println!("> {:?}", start_gen.elapsed());
println!(
"{} token: {} '{}'",
index + 1,
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 9d9a5f99..40b7e67f 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1,7 +1,9 @@
-use crate::{CpuStorage, DType, Layout, Shape};
+use crate::{CpuStorage, DType, Layout, Shape, WithDType};
use candle_kernels as kernels;
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
-use cudarc::driver::{CudaFunction, CudaSlice, DeviceSlice, LaunchAsync, LaunchConfig};
+use cudarc::driver::{
+ CudaFunction, CudaSlice, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits,
+};
use half::{bf16, f16};
use std::sync::Arc;
@@ -242,6 +244,260 @@ enum CudaStorageSlice {
F32(CudaSlice<f32>),
F64(CudaSlice<f64>),
}
+type S = CudaStorageSlice;
+
+trait Map1 {
+ fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+ &self,
+ src: &CudaSlice<T>,
+ dev: &CudaDevice,
+ layout: &Layout,
+ ) -> Result<CudaSlice<T>>;
+
+ fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
+ let out = match s {
+ S::U32(s) => S::U32(self.f(s, d, l)?),
+ S::BF16(s) => S::BF16(self.f(s, d, l)?),
+ S::F16(s) => S::F16(self.f(s, d, l)?),
+ S::F32(s) => S::F32(self.f(s, d, l)?),
+ S::F64(s) => S::F64(self.f(s, d, l)?),
+ };
+ Ok(out)
+ }
+}
+
+trait Map2 {
+ fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+ &self,
+ src1: &CudaSlice<T>,
+ layout1: &Layout,
+ src2: &CudaSlice<T>,
+ layout2: &Layout,
+ dev: &CudaDevice,
+ ) -> Result<CudaSlice<T>>;
+
+ fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
+ let out = match (s1, s2) {
+ (S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
+ (S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
+ (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
+ (S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
+ (S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
+ _ => return Err(CudaError::InternalError("dtype mismatch in binary op")),
+ };
+ Ok(out)
+ }
+}
+
+struct Clone;
+impl Map1 for Clone {
+ fn f<T: DeviceRepr>(
+ &self,
+ s: &CudaSlice<T>,
+ _: &CudaDevice,
+ _: &Layout,
+ ) -> Result<CudaSlice<T>> {
+ Ok(s.try_clone()?)
+ }
+}
+
+fn kernel_name<T: WithDType>(root: &str) -> String {
+ let dtype = T::DTYPE.as_str();
+ format!("{root}_{dtype}")
+}
+
+struct Affine(f64, f64);
+impl Map1 for Affine {
+ fn f<T: DeviceRepr + WithDType>(
+ &self,
+ src: &CudaSlice<T>,
+ dev: &CudaDevice,
+ layout: &Layout,
+ ) -> Result<CudaSlice<T>> {
+ let shape = layout.shape();
+ let dims = shape.dims();
+ let el = shape.elem_count();
+ let cfg = LaunchConfig::for_num_elems(el as u32);
+ let ds = dev.htod_copy([dims, layout.stride()].concat())?;
+ let src = &src.slice(layout.start_offset()..);
+ let func = dev.get_or_load_func(&kernel_name::<T>("affine"), kernels::AFFINE)?;
+ // SAFETY: Set later by running the kernel.
+ let out = unsafe { dev.alloc::<T>(el) }?;
+ let params = (
+ el,
+ dims.len(),
+ &ds,
+ src,
+ &out,
+ T::from_f64(self.0),
+ T::from_f64(self.1),
+ );
+ // SAFETY: ffi.
+ unsafe { func.launch(cfg, params) }?;
+ Ok(out)
+ }
+}
+
+struct Sum<'a>(&'a [usize]);
+impl<'a> Map1 for Sum<'a> {
+ fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+ &self,
+ src: &CudaSlice<T>,
+ dev: &CudaDevice,
+ layout: &Layout,
+ ) -> Result<CudaSlice<T>> {
+ let shape = layout.shape();
+ let src_dims = shape.dims();
+ let el = shape.elem_count();
+ let mut dst_el = el;
+ for &sum_dim in self.0.iter() {
+ dst_el /= src_dims[sum_dim];
+ }
+ let mut sum_dims = self.0.to_vec();
+ // Sort the sum_dims as they have to be processed from left to right when converting the
+ // indexes.
+ sum_dims.sort();
+ let sum_dims_l: Vec<usize> = sum_dims.iter().map(|&d| src_dims[d]).collect();
+ let sum_dims_s: Vec<usize> = sum_dims
+ .iter()
+ .map(|&d| src_dims[d + 1..].iter().product::<usize>())
+ .collect();
+ let cfg = LaunchConfig::for_num_elems(el as u32);
+ let ds = dev.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())?;
+ let src = &src.slice(layout.start_offset()..);
+ let func = dev.get_or_load_func(&kernel_name::<T>("sum"), kernels::REDUCE)?;
+ let out = dev.alloc_zeros::<T>(dst_el)?;
+ let params = (el, src_dims.len(), sum_dims.len(), &ds, src, &out);
+ // SAFETY: ffi.
+ unsafe { func.launch(cfg, params) }?;
+ Ok(out)
+ }
+}
+
+impl<U: crate::op::UnaryOp> Map1 for U {
+ fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+ &self,
+ src: &CudaSlice<T>,
+ dev: &CudaDevice,
+ layout: &Layout,
+ ) -> Result<CudaSlice<T>> {
+ let shape = layout.shape();
+ let dims = shape.dims();
+ let el_count = shape.elem_count();
+ let cfg = LaunchConfig::for_num_elems(el_count as u32);
+ let ds = dev.htod_copy([dims, layout.stride()].concat())?;
+ let src = &src.slice(layout.start_offset()..);
+ let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::UNARY)?;
+ // SAFETY: Set later by running the kernel.
+ let out = unsafe { dev.alloc::<T>(el_count) }?;
+ let params = (el_count, dims.len(), &ds, src, &out);
+ // SAFETY: ffi.
+ unsafe { func.launch(cfg, params) }?;
+ Ok(out)
+ }
+}
+
+struct Embedding<'a>(&'a CudaStorage, &'a Layout);
+impl<'a> Map1 for Embedding<'a> {
+ fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+ &self,
+ rhs: &CudaSlice<T>,
+ dev: &CudaDevice,
+ rhs_l: &Layout,
+ ) -> Result<CudaSlice<T>> {
+ let ids_l = &self.1;
+ let ids = match &self.0.slice {
+ CudaStorageSlice::U32(slice) => slice.slice(ids_l.start_offset()..),
+ _ => Err(CudaError::UnexpectedDType {
+ msg: "embedding ids should be u32",
+ expected: DType::U32,
+ got: self.0.dtype(),
+ })?,
+ };
+ let ids = &ids;
+ let shape = ids_l.shape();
+ let (v_size, h_size) = rhs_l
+ .shape()
+ .r2()
+ .map_err(|e| CudaError::WrappedError(Box::new(e)))?;
+ let dims = shape.dims();
+ let el = shape.elem_count();
+ let cfg = LaunchConfig::for_num_elems(el as u32);
+ let ds = dev.htod_copy([dims, ids_l.stride()].concat())?;
+ let rhs = &rhs.slice(rhs_l.start_offset()..);
+ let func = dev.get_or_load_func(&kernel_name::<T>("emb"), kernels::EMBEDDINGS)?;
+ // SAFETY: Set later by running the kernel.
+ let out = unsafe { dev.alloc::<T>(el * h_size) }?;
+ let params = (el, dims.len(), &ds, ids, rhs, &out, h_size, v_size);
+ // SAFETY: ffi.
+ unsafe { func.launch(cfg, params) }?;
+ Ok(out)
+ }
+}
+
+struct WhereCond<'a>(&'a CudaStorage, &'a Layout);
+impl<'a> Map2 for WhereCond<'a> {
+ fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+ &self,
+ t: &CudaSlice<T>,
+ layout_t: &Layout,
+ f: &CudaSlice<T>,
+ layout_f: &Layout,
+ dev: &CudaDevice,
+ ) -> Result<CudaSlice<T>> {
+ let ids_l = &self.1;
+ let ids = match &self.0.slice {
+ CudaStorageSlice::U32(slice) => slice.slice(ids_l.start_offset()..),
+ _ => Err(CudaError::UnexpectedDType {
+ msg: "where conditions should be u32",
+ expected: DType::U32,
+ got: self.0.dtype(),
+ })?,
+ };
+ let ids = &ids;
+ let shape = ids_l.shape();
+ let dims = shape.dims();
+ let el = shape.elem_count();
+ let cfg = LaunchConfig::for_num_elems(el as u32);
+ let ds =
+ dev.htod_copy([dims, ids_l.stride(), layout_t.stride(), layout_f.stride()].concat())?;
+ let t = &t.slice(layout_t.start_offset()..);
+ let f = &f.slice(layout_f.start_offset()..);
+ let func = dev.get_or_load_func(&kernel_name::<T>("where"), kernels::TERNARY)?;
+ // SAFETY: Set later by running the kernel.
+ let out = unsafe { dev.alloc::<T>(el) }?;
+ let params = (el, dims.len(), &ds, ids, t, f, &out);
+ // SAFETY: ffi
+ unsafe { func.launch(cfg, params) }?;
+ Ok(out)
+ }
+}
+
+impl<U: crate::op::BinaryOp> Map2 for U {
+ fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+ &self,
+ lhs: &CudaSlice<T>,
+ lhs_l: &Layout,
+ rhs: &CudaSlice<T>,
+ rhs_l: &Layout,
+ dev: &CudaDevice,
+ ) -> Result<CudaSlice<T>> {
+ let shape = lhs_l.shape();
+ let dims = shape.dims();
+ let elem_count = shape.elem_count();
+ let cfg = LaunchConfig::for_num_elems(elem_count as u32);
+ let dims_and_strides = dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())?;
+ let lhs = &lhs.slice(lhs_l.start_offset()..);
+ let rhs = &rhs.slice(rhs_l.start_offset()..);
+ let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::BINARY)?;
+ // SAFETY: Set later by running the kernel.
+ let out = unsafe { dev.alloc::<T>(elem_count) }?;
+ let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
+ // SAFETY: ffi
+ unsafe { func.launch(cfg, params) }?;
+ Ok(out)
+ }
+}
fn slice_src_and_dst<'a, T>(
src: &'a CudaSlice<T>,
@@ -332,14 +588,8 @@ fn gemm_config<T>(
}
impl CudaStorage {
- pub fn try_clone(&self) -> Result<Self> {
- let slice = match &self.slice {
- CudaStorageSlice::U32(slice) => CudaStorageSlice::U32(slice.try_clone()?),
- CudaStorageSlice::BF16(slice) => CudaStorageSlice::BF16(slice.try_clone()?),
- CudaStorageSlice::F16(slice) => CudaStorageSlice::F16(slice.try_clone()?),
- CudaStorageSlice::F32(slice) => CudaStorageSlice::F32(slice.try_clone()?),
- CudaStorageSlice::F64(slice) => CudaStorageSlice::F64(slice.try_clone()?),
- };
+ pub fn try_clone(&self, layout: &Layout) -> Result<Self> {
+ let slice = Clone.map(&self.slice, self.device(), layout)?;
let device = self.device.clone();
Ok(Self { slice, device })
}
@@ -420,152 +670,14 @@ impl CudaStorage {
}
pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
- let shape = layout.shape();
- let dims = shape.dims();
- let el_count = shape.elem_count();
- let cfg = LaunchConfig::for_num_elems(el_count as u32);
- let dev = self.device();
- let ds = dev.htod_copy([dims, layout.stride()].concat())?;
- let slice = match &self.slice {
- CudaStorageSlice::U32(arg) => {
- let arg = &arg.slice(layout.start_offset()..);
- let func = dev.get_or_load_func("affine_u32", kernels::AFFINE)?;
- // SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<u32>(el_count) }?;
- let params = (el_count, dims.len(), &ds, arg, &out, mul as u32, add as u32);
- // SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::U32(out)
- }
- CudaStorageSlice::BF16(arg) => {
- let arg = &arg.slice(layout.start_offset()..);
- let func = dev.get_or_load_func("affine_bf16", kernels::AFFINE)?;
- // SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<bf16>(el_count) }?;
- let params = (
- el_count,
- dims.len(),
- &ds,
- arg,
- &out,
- bf16::from_f64(mul),
- bf16::from_f64(add),
- );
- // SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::BF16(out)
- }
- CudaStorageSlice::F16(arg) => {
- let arg = &arg.slice(layout.start_offset()..);
- let func = dev.get_or_load_func("affine_f16", kernels::AFFINE)?;
- // SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<f16>(el_count) }?;
- let params = (
- el_count,
- dims.len(),
- &ds,
- arg,
- &out,
- f16::from_f64(mul),
- f16::from_f64(add),
- );
- // SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::F16(out)
- }
- CudaStorageSlice::F32(arg) => {
- let arg = &arg.slice(layout.start_offset()..);
- let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?;
- // SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<f32>(el_count) }?;
- let params = (el_count, dims.len(), &ds, arg, &out, mul as f32, add as f32);
- // SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::F32(out)
- }
- CudaStorageSlice::F64(arg) => {
- let arg = &arg.slice(layout.start_offset()..);
- let func = dev.get_or_load_func("affine_f64", kernels::AFFINE)?;
- // SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<f64>(el_count) }?;
- let params = (el_count, dims.len(), &ds, arg, &out, mul, add);
- // SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::F64(out)
- }
- };
- let device = dev.clone();
+ let device = self.device().clone();
+ let slice = Affine(mul, add).map(&self.slice, &device, layout)?;
Ok(Self { slice, device })
}
pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
- let shape = layout.shape();
- let src_dims = shape.dims();
- let el = shape.elem_count();
- let mut dst_el = el;
- for &sum_dim in sum_dims.iter() {
- dst_el /= src_dims[sum_dim];
- }
- let mut sum_dims = sum_dims.to_vec();
- // Sort the sum_dims as they have to be processed from left to right when converting the
- // indexes.
- sum_dims.sort();
- let sum_dims_l: Vec<usize> = sum_dims.iter().map(|&d| src_dims[d]).collect();
- let sum_dims_s: Vec<usize> = sum_dims
- .iter()
- .map(|&d| src_dims[d + 1..].iter().product::<usize>())
- .collect();
- let cfg = LaunchConfig::for_num_elems(el as u32);
- let dev = self.device();
- let ds = dev.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())?;
- let slice = match &self.slice {
- CudaStorageSlice::U32(arg) => {
- let arg = &arg.slice(layout.start_offset()..);
- let func = dev.get_or_load_func("sum_u32", kernels::REDUCE)?;
- let out = dev.alloc_zeros::<u32>(dst_el)?;
- let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
- // SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::U32(out)
- }
- CudaStorageSlice::BF16(arg) => {
- let arg = &arg.slice(layout.start_offset()..);
- let func = dev.get_or_load_func("sum_bf16", kernels::REDUCE)?;
- let out = dev.alloc_zeros::<bf16>(dst_el)?;
- let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
- // SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::BF16(out)
- }
- CudaStorageSlice::F16(arg) => {
- let arg = &arg.slice(layout.start_offset()..);
- let func = dev.get_or_load_func("sum_f16", kernels::REDUCE)?;
- let out = dev.alloc_zeros::<f16>(dst_el)?;
- let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
- // SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::F16(out)
- }
- CudaStorageSlice::F32(arg) => {
- let arg = &arg.slice(layout.start_offset()..);
- let func = dev.get_or_load_func("sum_f32", kernels::REDUCE)?;
- let out = dev.alloc_zeros::<f32>(dst_el)?;
- let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
- // SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::F32(out)
- }
- CudaStorageSlice::F64(arg) => {
- let arg = &arg.slice(layout.start_offset()..);
- let func = dev.get_or_load_func("sum_f64", kernels::REDUCE)?;
- let out = dev.alloc_zeros::<f64>(dst_el)?;
- let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
- // SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::F64(out)
- }
- };
- let device = dev.clone();
+ let device = self.device().clone();
+ let slice = Sum(sum_dims).map(&self.slice, &device, layout)?;
Ok(Self { slice, device })
}
@@ -576,58 +688,8 @@ impl CudaStorage {
}
pub(crate) fn unary_impl<U: crate::op::UnaryOp>(&self, layout: &Layout) -> Result<Self> {
- let shape = layout.shape();
- let dims = shape.dims();
- let el_count = shape.elem_count();
- let cfg = LaunchConfig::for_num_elems(el_count as u32);
- let dev = &self.device;
- let ds = dev.htod_copy([dims, layout.stride()].concat())?;
- let slice = match &self.slice {
- CudaStorageSlice::U32(_arg) => {
- todo!("No unary kernels for u32");
- }
- CudaStorageSlice::BF16(arg) => {
- let arg = &arg.slice(layout.start_offset()..);
- let func = dev.get_or_load_func(U::KERNEL_BF16, kernels::UNARY)?;
- // SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<bf16>(el_count) }?;
- let params = (el_count, dims.len(), &ds, arg, &out);
- // SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::BF16(out)
- }
- CudaStorageSlice::F16(arg) => {
- let arg = &arg.slice(layout.start_offset()..);
- let func = dev.get_or_load_func(U::KERNEL_F16, kernels::UNARY)?;
- // SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<f16>(el_count) }?;
- let params = (el_count, dims.len(), &ds, arg, &out);
- // SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::F16(out)
- }
- CudaStorageSlice::F32(arg) => {
- let arg = &arg.slice(layout.start_offset()..);
- let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?;
- // SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<f32>(el_count) }?;
- let params = (el_count, dims.len(), &ds, arg, &out);
- // SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::F32(out)
- }
- CudaStorageSlice::F64(arg) => {
- let arg = &arg.slice(layout.start_offset()..);
- let func = dev.get_or_load_func(U::KERNEL_F64, kernels::UNARY)?;
- // SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<f64>(el_count) }?;
- let params = (el_count, dims.len(), &ds, arg, &out);
- // SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::F64(out)
- }
- };
- let device = dev.clone();
+ let device = self.device().clone();
+ let slice = U::V.map(&self.slice, &device, layout)?;
Ok(Self { slice, device })
}
@@ -637,72 +699,8 @@ impl CudaStorage {
lhs_l: &Layout,
rhs_l: &Layout,
) -> Result<Self> {
- let shape = lhs_l.shape();
- let dims = shape.dims();
- let elem_count = shape.elem_count();
- let cfg = LaunchConfig::for_num_elems(elem_count as u32);
- let dev = self.device();
- let dims_and_strides = dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())?;
- let slice = match (&self.slice, &rhs.slice) {
- (CudaStorageSlice::BF16(lhs), CudaStorageSlice::BF16(rhs)) => {
- let lhs = &lhs.slice(lhs_l.start_offset()..);
- let rhs = &rhs.slice(rhs_l.start_offset()..);
- let func = dev.get_or_load_func(B::KERNEL_BF16, kernels::BINARY)?;
- // SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<bf16>(elem_count) }?;
- let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
- // SAFETY: ffi
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::BF16(out)
- }
- (CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
- let lhs = &lhs.slice(lhs_l.start_offset()..);
- let rhs = &rhs.slice(rhs_l.start_offset()..);
- let func = dev.get_or_load_func(B::KERNEL_F16, kernels::BINARY)?;
- // SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<f16>(elem_count) }?;
- let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
- // SAFETY: ffi
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::F16(out)
- }
- (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
- let lhs = &lhs.slice(lhs_l.start_offset()..);
- let rhs = &rhs.slice(rhs_l.start_offset()..);
- let func = dev.get_or_load_func(B::KERNEL_F32, kernels::BINARY)?;
- // SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<f32>(elem_count) }?;
- let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
- // SAFETY: ffi
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::F32(out)
- }
- (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
- let lhs = &lhs.slice(lhs_l.start_offset()..);
- let rhs = &rhs.slice(rhs_l.start_offset()..);
- // SAFETY: Set later by running the kernel.
- let func = dev.get_or_load_func(B::KERNEL_F64, kernels::BINARY)?;
- let out = unsafe { dev.alloc::<f64>(elem_count) }?;
- let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
- // SAFETY: ffi
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::F64(out)
- }
- (CudaStorageSlice::U32(lhs), CudaStorageSlice::U32(rhs)) => {
- let lhs = &lhs.slice(lhs_l.start_offset()..);
- let rhs = &rhs.slice(rhs_l.start_offset()..);
- // SAFETY: Set later by running the kernel.
- let func = dev.get_or_load_func(B::KERNEL_U32, kernels::BINARY)?;
- let out = unsafe { dev.alloc::<u32>(elem_count) }?;
- let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
- // SAFETY: ffi
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::U32(out)
- }
- // The dtypes should have been checked at this point so this is an internal error.
- _ => return Err(CudaError::InternalError("dtype mismatch in binary op")),
- };
- let device = dev.clone();
+ let device = self.device().clone();
+ let slice = B::V.map(&self.slice, lhs_l, &rhs.slice, rhs_l, &device)?;
Ok(Self { slice, device })
}
@@ -740,163 +738,18 @@ impl CudaStorage {
&self,
layout: &Layout,
t: &Self,
- layout_t: &Layout,
+ t_l: &Layout,
f: &Self,
- layout_f: &Layout,
+ f_l: &Layout,
) -> Result<Self> {
- let ids = match &self.slice {
- CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..),
- _ => Err(CudaError::UnexpectedDType {
- msg: "where conditions should be u32",
- expected: DType::U32,
- got: self.dtype(),
- })?,
- };
- let ids = &ids;
- let shape = layout.shape();
- let dims = shape.dims();
- let el = shape.elem_count();
- let cfg = LaunchConfig::for_num_elems(el as u32);
- let dev = self.device();
- let ds =
- dev.htod_copy([dims, layout.stride(), layout_t.stride(), layout_f.stride()].concat())?;
- let slice = match (&t.slice, &f.slice) {
- (CudaStorageSlice::BF16(t), CudaStorageSlice::BF16(f)) => {
- let t = &t.slice(layout_t.start_offset()..);
- let f = &f.slice(layout_f.start_offset()..);
- let func = dev.get_or_load_func("where_bf16", kernels::TERNARY)?;
- // SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<bf16>(el) }?;
- let params = (el, dims.len(), &ds, ids, t, f, &out);
- // SAFETY: ffi
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::BF16(out)
- }
- (CudaStorageSlice::F16(t), CudaStorageSlice::F16(f)) => {
- let t = &t.slice(layout_t.start_offset()..);
- let f = &f.slice(layout_f.start_offset()..);
- let func = dev.get_or_load_func("where_f16", kernels::TERNARY)?;
- // SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<f16>(el) }?;
- let params = (el, dims.len(), &ds, ids, t, f, &out);
- // SAFETY: ffi
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::F16(out)
- }
- (CudaStorageSlice::F32(t), CudaStorageSlice::F32(f)) => {
- let t = &t.slice(layout_t.start_offset()..);
- let f = &f.slice(layout_f.start_offset()..);
- let func = dev.get_or_load_func("where_f32", kernels::TERNARY)?;
- // SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<f32>(el) }?;
- let params = (el, dims.len(), &ds, ids, t, f, &out);
- // SAFETY: ffi
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::F32(out)
- }
- (CudaStorageSlice::F64(t), CudaStorageSlice::F64(f)) => {
- let t = &t.slice(layout_t.start_offset()..);
- let f = &f.slice(layout_f.start_offset()..);
- // SAFETY: Set later by running the kernel.
- let func = dev.get_or_load_func("where_f64", kernels::TERNARY)?;
- let out = unsafe { dev.alloc::<f64>(el) }?;
- let params = (el, dims.len(), &ds, ids, t, f, &out);
- // SAFETY: ffi
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::F64(out)
- }
- (CudaStorageSlice::U32(t), CudaStorageSlice::U32(f)) => {
- let t = &t.slice(layout_t.start_offset()..);
- let f = &f.slice(layout_f.start_offset()..);
- // SAFETY: Set later by running the kernel.
- let func = dev.get_or_load_func("where_u32", kernels::TERNARY)?;
- let out = unsafe { dev.alloc::<u32>(el) }?;
- let params = (el, dims.len(), &ds, ids, t, f, &out);
- // SAFETY: ffi
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::U32(out)
- }
- // The dtypes should have been checked at this point so this is an internal error.
- _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
- };
- let device = dev.clone();
+ let device = self.device().clone();
+ let slice = WhereCond(self, layout).map(&t.slice, t_l, &f.slice, f_l, &device)?;
Ok(Self { slice, device })
}
pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
- let ids = match &self.slice {
- CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..),
- _ => Err(CudaError::UnexpectedDType {
- msg: "embedding ids should be u32",
- expected: DType::U32,
- got: self.dtype(),
- })?,
- };
- let ids = &ids;
- let shape = layout.shape();
- let (v_size, h_size) = rhs_l
- .shape()
- .r2()
- .map_err(|e| CudaError::WrappedError(Box::new(e)))?;
- let dims = shape.dims();
- let el = shape.elem_count();
- let cfg = LaunchConfig::for_num_elems(el as u32);
- let dev = self.device();
- let ds = dev.htod_copy([dims, layout.stride()].concat())?;
- let slice = match &rhs.slice {
- // The kernels below assume that rhs is contiguous.
- CudaStorageSlice::U32(arg) => {
- let arg = &arg.slice(rhs_l.start_offset()..);
- let func = dev.get_or_load_func("emb_u32", kernels::EMBEDDINGS)?;
- // SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<u32>(el * h_size) }?;
- let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
- // SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::U32(out)
- }
- CudaStorageSlice::BF16(arg) => {
- let arg = &arg.slice(rhs_l.start_offset()..);
- let func = dev.get_or_load_func("emb_bf16", kernels::EMBEDDINGS)?;
- // SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<bf16>(el * h_size) }?;
- let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
- // SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::BF16(out)
- }
- CudaStorageSlice::F16(arg) => {
- let arg = &arg.slice(rhs_l.start_offset()..);
- let func = dev.get_or_load_func("emb_f16", kernels::EMBEDDINGS)?;
- // SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<f16>(el * h_size) }?;
- let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
- // SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::F16(out)
- }
- CudaStorageSlice::F32(arg) => {
- let arg = &arg.slice(rhs_l.start_offset()..);
- let func = dev.get_or_load_func("emb_f32", kernels::EMBEDDINGS)?;
- // SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<f32>(el * h_size) }?;
- let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
- // SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::F32(out)
- }
- CudaStorageSlice::F64(arg) => {
- let arg = &arg.slice(rhs_l.start_offset()..);
- let func = dev.get_or_load_func("emb_f64", kernels::EMBEDDINGS)?;
- // SAFETY: Set later by running the kernel.
- let out = unsafe { dev.alloc::<f64>(el * h_size) }?;
- let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
- // SAFETY: ffi.
- unsafe { func.launch(cfg, params) }?;
- CudaStorageSlice::F64(out)
- }
- };
- let device = dev.clone();
+ let device = self.device().clone();
+ let slice = Embedding(self, layout).map(&rhs.slice, &device, rhs_l)?;
Ok(Self { slice, device })
}
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index 8193b1af..b025eeab 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -44,7 +44,7 @@ impl CudaDevice {
pub struct CudaStorage;
impl CudaStorage {
- pub fn try_clone(&self) -> Result<Self> {
+ pub fn try_clone(&self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 7b0e18fe..db6ef87f 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -43,11 +43,8 @@ pub(crate) enum Op {
pub(crate) trait UnaryOp {
const NAME: &'static str;
- const KERNEL_BF16: &'static str;
- const KERNEL_F16: &'static str;
- const KERNEL_F32: &'static str;
- const KERNEL_F64: &'static str;
- const KERNEL_U32: &'static str;
+ const KERNEL: &'static str;
+ const V: Self;
fn bf16(v1: bf16) -> bf16;
fn f16(v1: f16) -> f16;
fn f32(v1: f32) -> f32;
@@ -57,11 +54,8 @@ pub(crate) trait UnaryOp {
pub(crate) trait BinaryOp {
const NAME: &'static str;
- const KERNEL_BF16: &'static str;
- const KERNEL_F16: &'static str;
- const KERNEL_F32: &'static str;
- const KERNEL_F64: &'static str;
- const KERNEL_U32: &'static str;
+ const KERNEL: &'static str;
+ const V: Self;
fn bf16(v1: bf16, v2: bf16) -> bf16;
fn f16(v1: f16, v2: f16) -> f16;
fn f32(v1: f32, v2: f32) -> f32;
@@ -88,11 +82,8 @@ macro_rules! bin_op {
($op:ident, $name: literal, $e: expr) => {
impl BinaryOp for $op {
const NAME: &'static str = $name;
- const KERNEL_BF16: &'static str = concat!("b", $name, "_bf16");
- const KERNEL_F16: &'static str = concat!("b", $name, "_f16");
- const KERNEL_F32: &'static str = concat!("b", $name, "_f32");
- const KERNEL_F64: &'static str = concat!("b", $name, "_f64");
- const KERNEL_U32: &'static str = concat!("b", $name, "_u32");
+ const KERNEL: &'static str = concat!("b", $name);
+ const V: Self = $op;
fn bf16(v1: bf16, v2: bf16) -> bf16 {
$e(v1, v2)
}
@@ -121,11 +112,8 @@ macro_rules! unary_op {
($op: ident, $name: literal, $a: ident, $e: expr) => {
impl UnaryOp for $op {
const NAME: &'static str = $name;
- const KERNEL_BF16: &'static str = concat!("u", $name, "_bf16");
- const KERNEL_F16: &'static str = concat!("u", $name, "_f16");
- const KERNEL_F32: &'static str = concat!("u", $name, "_f32");
- const KERNEL_F64: &'static str = concat!("u", $name, "_f64");
- const KERNEL_U32: &'static str = concat!("u", $name, "_u32");
+ const KERNEL: &'static str = concat!("u", $name);
+ const V: Self = $op;
fn bf16($a: bf16) -> bf16 {
$e
}
@@ -158,6 +146,7 @@ unary_op!(Sqrt, "sqrt", v, v.sqrt());
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
impl UnaryOp for Gelu {
const NAME: &'static str = "gelu";
+ const V: Self = Gelu;
fn bf16(v: bf16) -> bf16 {
bf16::from_f32_const(0.5)
* v
@@ -191,20 +180,13 @@ impl UnaryOp for Gelu {
fn u32(_: u32) -> u32 {
0
}
- const KERNEL_BF16: &'static str = "ugelu_bf16";
- const KERNEL_F16: &'static str = "ugelu_f16";
- const KERNEL_F32: &'static str = "ugelu_f32";
- const KERNEL_F64: &'static str = "ugelu_f64";
- const KERNEL_U32: &'static str = "ugelu_u32";
+ const KERNEL: &'static str = "ugelu";
}
impl UnaryOp for Relu {
const NAME: &'static str = "relu";
- const KERNEL_BF16: &'static str = "urelu_bf16";
- const KERNEL_F16: &'static str = "urelu_f16";
- const KERNEL_F32: &'static str = "urelu_f32";
- const KERNEL_F64: &'static str = "urelu_f64";
- const KERNEL_U32: &'static str = "urelu_u32";
+ const KERNEL: &'static str = "urelu";
+ const V: Self = Relu;
fn bf16(v: bf16) -> bf16 {
v.max(bf16::ZERO)
}
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 7acf6dd0..4e630a58 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -9,11 +9,11 @@ pub enum Storage {
}
impl Storage {
- pub fn try_clone(&self) -> Result<Self> {
+ pub fn try_clone(&self, layout: &Layout) -> Result<Self> {
match self {
Self::Cpu(storage) => Ok(Self::Cpu(storage.clone())),
Self::Cuda(storage) => {
- let storage = storage.try_clone()?;
+ let storage = storage.try_clone(layout)?;
Ok(Self::Cuda(storage))
}
}
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index f64bd6f2..4b9b3306 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -709,7 +709,7 @@ impl Tensor {
pub fn copy(&self) -> Result<Tensor> {
let tensor_ = Tensor_ {
id: TensorId::new(),
- storage: Arc::new(self.storage.try_clone()?),
+ storage: Arc::new(self.storage.try_clone(self.layout())?),
layout: self.layout.clone(),
op: None, // TODO
is_variable: false,