//! Implementation of Backend traits for CUDA device //! use crate::backend::{BackendDevice, BackendStorage}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType}; pub use candle_kernels as kernels; pub use cudarc; use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig}; use cudarc::driver::{ CudaSlice, DevicePtr, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits, }; use half::{bf16, f16}; #[cfg(feature = "cudnn")] pub mod cudnn; mod device; mod error; mod utils; pub use device::{CudaDevice, DeviceId}; pub use error::{CudaError, WrapErr}; pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, Map3, S}; pub enum SlicePtrOrNull { Ptr(CudaSlice), Null, } unsafe impl DeviceRepr for &SlicePtrOrNull { fn as_kernel_param(&self) -> *mut std::ffi::c_void { match self { SlicePtrOrNull::Ptr(slice) => slice.as_kernel_param(), SlicePtrOrNull::Null => 0usize.as_kernel_param(), } } } impl SlicePtrOrNull { pub fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result { let ds = if l.is_contiguous() { SlicePtrOrNull::Null } else { SlicePtrOrNull::Ptr(dev.htod_copy([l.dims(), l.stride()].concat()).w()?) }; Ok(ds) } } #[derive(Debug)] pub enum CudaStorageSlice { U8(CudaSlice), U32(CudaSlice), I64(CudaSlice), BF16(CudaSlice), F16(CudaSlice), F32(CudaSlice), F64(CudaSlice), } struct Clone; impl Map1 for Clone { fn f( &self, s: &CudaSlice, _: &CudaDevice, _: &Layout, ) -> Result> { s.try_clone().w() } } pub fn kernel_name(root: &str) -> String { let dtype = T::DTYPE.as_str(); format!("{root}_{dtype}") } struct Affine(f64, f64); impl Map1 for Affine { fn f( &self, src: &CudaSlice, dev: &CudaDevice, layout: &Layout, ) -> Result> { let shape = layout.shape(); let dims = shape.dims(); let el = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(el as u32); let ds = SlicePtrOrNull::params_from_layout(dev, layout)?; let src = &src.slice(layout.start_offset()..); let func = dev.get_or_load_func(&kernel_name::("affine"), kernels::AFFINE)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el) }.w()?; let params = ( el, dims.len(), &ds, src, &out, T::from_f64(self.0), T::from_f64(self.1), ); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()?; Ok(out) } } struct Elu(f64); impl Map1 for Elu { fn f( &self, src: &CudaSlice, dev: &CudaDevice, layout: &Layout, ) -> Result> { let shape = layout.shape(); let dims = shape.dims(); let el = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(el as u32); let ds = SlicePtrOrNull::params_from_layout(dev, layout)?; let src = &src.slice(layout.start_offset()..); let func = dev.get_or_load_func(&kernel_name::("uelu"), kernels::UNARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el) }.w()?; let params = (el, dims.len(), &ds, T::from_f64(self.0), src, &out); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()?; Ok(out) } } struct Im2Col1D { l_k: usize, stride: usize, dilation: usize, padding: usize, } impl Im2Col1D { fn l_out(&self, l: usize) -> usize { (l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1 } } impl Map1 for Im2Col1D { fn f( &self, src: &CudaSlice, dev: &CudaDevice, layout: &Layout, ) -> Result> { let shape = layout.shape(); let dims = shape.dims(); let l_out = self.l_out(dims[2]); let dst_el = dims[0] * l_out * dims[1] * self.l_k; let cfg = LaunchConfig::for_num_elems(dst_el as u32); let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?; let src = &src.slice(layout.start_offset()..); let func = dev.get_or_load_func(&kernel_name::("im2col1d"), kernels::CONV)?; // SAFETY: Set later by running the kernel. let dst = unsafe { dev.alloc::(dst_el) }.w()?; let params = ( dst_el, l_out, self.l_k, self.stride, self.padding, self.dilation, &ds, src, &dst, ); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()?; Ok(dst) } } #[allow(unused)] struct Im2Col { h_k: usize, w_k: usize, stride: usize, dilation: usize, padding: usize, } impl Im2Col { #[allow(unused)] fn hw_out(&self, h: usize, w: usize) -> (usize, usize) { let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1; let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1; (h_out, w_out) } } impl Map1 for Im2Col { fn f( &self, src: &CudaSlice, dev: &CudaDevice, layout: &Layout, ) -> Result> { let shape = layout.shape(); let dims = shape.dims(); let (h_out, w_out) = self.hw_out(dims[2], dims[3]); let dst_el = dims[0] * h_out * w_out * dims[1] * self.h_k * self.w_k; let cfg = LaunchConfig::for_num_elems(dst_el as u32); let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?; let src = &src.slice(layout.start_offset()..); let func = dev.get_or_load_func(&kernel_name::("im2col"), kernels::CONV)?; // SAFETY: Set later by running the kernel. let dst = unsafe { dev.alloc::(dst_el) }.w()?; let params = ( dst_el, h_out, w_out, self.h_k, self.w_k, self.stride, self.padding, self.dilation, &ds, src, &dst, ); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()?; Ok(dst) } } struct Powf(f64); impl Map1 for Powf { fn f( &self, src: &CudaSlice, dev: &CudaDevice, layout: &Layout, ) -> Result> { let shape = layout.shape(); let dims = shape.dims(); let el = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(el as u32); let ds = SlicePtrOrNull::params_from_layout(dev, layout)?; let src = &src.slice(layout.start_offset()..); let func = dev.get_or_load_func(&kernel_name::("upowf"), kernels::UNARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el) }.w()?; let params = (el, dims.len(), &ds, T::from_f64(self.0), src, &out); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()?; Ok(out) } } struct FastReduce<'a>(&'a [usize], ReduceOp); impl<'a> Map1Any for FastReduce<'a> { fn f) -> S>( &self, src: &CudaSlice, dev: &CudaDevice, layout: &Layout, wrap: W, ) -> Result { let src_stride = layout.stride(); let src_dims = layout.shape().dims(); let src_el: usize = src_dims.iter().product(); // Source dims and strides with the sum dims at the end. let mut dims = vec![]; let mut stride = vec![]; let mut dst_el: usize = 1; for (dim_idx, &d) in src_dims.iter().enumerate() { if !self.0.contains(&dim_idx) { dst_el *= d; dims.push(d); stride.push(src_stride[dim_idx]); } } for &dim_idx in self.0.iter() { dims.push(src_dims[dim_idx]); stride.push(src_stride[dim_idx]); } let el_to_sum_per_block = src_el / dst_el; // The reduction loop requires the shared array to be properly initialized and for // this we want the number of threads to be a power of two. let block_dim = usize::min(1024, el_to_sum_per_block).next_power_of_two(); let cfg = LaunchConfig { // TODO: Maybe use grid_y if the output is too large? // TODO: Specialized implementation when reducing on no or all dimensions or when // reducing only aggregate a small number of elements together. grid_dim: (dst_el as u32, 1, 1), block_dim: (block_dim as u32, 1, 1), shared_mem_bytes: 0, }; let ds = dev .htod_copy([dims.as_slice(), stride.as_slice()].concat()) .w()?; let src = &src.slice(layout.start_offset()..); let (name, check_empty, return_index) = match self.1 { ReduceOp::Sum => ("fast_sum", false, false), ReduceOp::Min => ("fast_min", true, false), ReduceOp::Max => ("fast_max", true, false), ReduceOp::ArgMin => ("fast_argmin", true, true), ReduceOp::ArgMax => ("fast_argmax", true, true), }; if check_empty && layout.shape().elem_count() == 0 { Err(crate::Error::EmptyTensor { op: "reduce" }.bt())? } let func = dev.get_or_load_func(&kernel_name::(name), kernels::REDUCE)?; if return_index { // SAFETY: filled in by the follow up kernel. let out = unsafe { dev.alloc::(dst_el) }.w()?; let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()?; Ok(S::U32(out)) } else { // SAFETY: filled in by the follow up kernel. let out = unsafe { dev.alloc::(dst_el) }.w()?; let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()?; Ok(wrap(out)) } } } impl Map1 for U { fn f( &self, src: &CudaSlice, dev: &CudaDevice, layout: &Layout, ) -> Result> { let shape = layout.shape(); let dims = shape.dims(); let el_count = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(el_count as u32); let ds = SlicePtrOrNull::params_from_layout(dev, layout)?; let src = &src.slice(layout.start_offset()..); let func = dev.get_or_load_func(&kernel_name::(U::KERNEL), kernels::UNARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el_count) }.w()?; let params = (el_count, dims.len(), &ds, src, &out); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()?; Ok(out) } } struct IndexSelect<'a>(&'a CudaStorage, &'a Layout, usize); impl<'a> Map1 for IndexSelect<'a> { fn f( &self, src: &CudaSlice, dev: &CudaDevice, src_l: &Layout, ) -> Result> { let ids_l = &self.1; let (name, ids) = match &self.0.slice { CudaStorageSlice::U32(slice) => { ("is_u32", *slice.slice(ids_l.start_offset()..).device_ptr()) } CudaStorageSlice::U8(slice) => { ("is_u8", *slice.slice(ids_l.start_offset()..).device_ptr()) } CudaStorageSlice::I64(slice) => { ("is_i64", *slice.slice(ids_l.start_offset()..).device_ptr()) } _ => Err(CudaError::UnexpectedDType { msg: "index_select ids should be u8 or u32", expected: DType::U32, got: self.0.dtype(), }) .w()?, }; let ids_shape = ids_l.shape(); let ids_dims = ids_shape.dims(); let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?; let src = match src_l.contiguous_offsets() { Some((o1, o2)) => src.slice(o1..o2), None => Err(crate::Error::RequiresContiguous { op: "index-select" }.bt())?, }; let left_size: usize = src_l.dims()[..self.2].iter().product(); let right_size: usize = src_l.dims()[self.2 + 1..].iter().product(); let src_dim_size = src_l.dims()[self.2]; let ids_dim_size = ids_shape.elem_count(); let dst_el = ids_shape.elem_count() * left_size * right_size; let cfg = LaunchConfig::for_num_elems(dst_el as u32); let func = dev.get_or_load_func(&kernel_name::(name), kernels::INDEXING)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(dst_el) }.w()?; let params = ( dst_el, ids_dims.len(), &ds, ids, &src, &out, left_size, src_dim_size, ids_dim_size, right_size, ); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()?; Ok(out) } } struct Gather<'a>(&'a CudaStorage, &'a Layout, usize); impl<'a> Map1 for Gather<'a> { fn f( &self, src: &CudaSlice, dev: &CudaDevice, src_l: &Layout, ) -> Result> { let ids = &self.0; let ids_l = &self.1; let dim = self.2; let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() { Some(o12) => o12, None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?, }; let (name, ids) = match &ids.slice { CudaStorageSlice::U32(slice) => { ("gather_u32", *slice.slice(ids_o1..ids_o2).device_ptr()) } CudaStorageSlice::U8(slice) => ("gather_u8", *slice.slice(ids_o1..ids_o2).device_ptr()), CudaStorageSlice::I64(slice) => { ("gather_i64", *slice.slice(ids_o1..ids_o2).device_ptr()) } _ => Err(CudaError::UnexpectedDType { msg: "gather ids should be u8/u32/i64", expected: DType::U32, got: ids.dtype(), })?, }; let el = ids_l.shape().elem_count(); let cfg = LaunchConfig::for_num_elems(el as u32); let src = match src_l.contiguous_offsets() { Some((o1, o2)) => src.slice(o1..o2), None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?, }; let left_sz: usize = src_l.dims()[..dim].iter().product(); let right_sz: usize = src_l.dims()[dim + 1..].iter().product(); let src_dim_sz = src_l.dims()[dim]; let ids_dim_sz = ids_l.dims()[dim]; let func = dev.get_or_load_func(&kernel_name::(name), kernels::INDEXING)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el) }.w()?; let params = ( el, ids, &src, &out, left_sz, src_dim_sz, ids_dim_sz, right_sz, ); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()?; Ok(out) } } struct IndexAdd<'a>(&'a CudaStorage, &'a Layout, usize); impl<'a> Map2InPlace for IndexAdd<'a> { fn f( &self, dst: &mut CudaSlice, dst_shape: &Shape, src: &CudaSlice, src_l: &Layout, dev: &CudaDevice, ) -> Result<()> { let ids = &self.0; let ids_l = &self.1; let dim = self.2; let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() { Some(o12) => o12, None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?, }; let (name, ids) = match &ids.slice { CudaStorageSlice::U32(slice) => ("ia_u32", *slice.slice(ids_o1..ids_o2).device_ptr()), CudaStorageSlice::I64(slice) => ("ia_i64", *slice.slice(ids_o1..ids_o2).device_ptr()), CudaStorageSlice::U8(slice) => ("ia_u8", *slice.slice(ids_o1..ids_o2).device_ptr()), _ => Err(CudaError::UnexpectedDType { msg: "index-add ids should be u8/u32/i64", expected: DType::U32, got: ids.dtype(), })?, }; let src = match src_l.contiguous_offsets() { Some((o1, o2)) => src.slice(o1..o2), None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?, }; let left_sz: usize = src_l.dims()[..dim].iter().product(); let right_sz: usize = src_l.dims()[dim + 1..].iter().product(); let src_dim_sz = src_l.dims()[dim]; let dst_dim_sz = dst_shape.dims()[dim]; let ids_dim_sz = ids_l.dims()[0]; let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32); let func = dev.get_or_load_func(&kernel_name::(name), kernels::INDEXING)?; // SAFETY: Set later by running the kernel. let params = ( ids, ids_dim_sz, &src, dst, left_sz, src_dim_sz, dst_dim_sz, right_sz, ); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()?; Ok(()) } } struct ScatterAdd<'a>(&'a CudaStorage, &'a Layout, usize); impl<'a> Map2InPlace for ScatterAdd<'a> { fn f( &self, dst: &mut CudaSlice, dst_shape: &Shape, src: &CudaSlice, src_l: &Layout, dev: &CudaDevice, ) -> Result<()> { let ids = &self.0; let ids_l = &self.1; let dim = self.2; let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() { Some(o12) => o12, None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?, }; let (name, ids) = match &ids.slice { CudaStorageSlice::U32(slice) => ("sa_u32", *slice.slice(ids_o1..ids_o2).device_ptr()), CudaStorageSlice::I64(slice) => ("sa_i64", *slice.slice(ids_o1..ids_o2).device_ptr()), CudaStorageSlice::U8(slice) => ("sa_u8", *slice.slice(ids_o1..ids_o2).device_ptr()), _ => Err(CudaError::UnexpectedDType { msg: "scatter-add ids should be u8/u32/i64", expected: DType::U32, got: ids.dtype(), })?, }; let src = match src_l.contiguous_offsets() { Some((o1, o2)) => src.slice(o1..o2), None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?, }; let left_sz: usize = src_l.dims()[..dim].iter().product(); let right_sz: usize = src_l.dims()[dim + 1..].iter().product(); let src_dim_sz = src_l.dims()[dim]; let dst_dim_sz = dst_shape.dims()[dim]; let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32); let func = dev.get_or_load_func(&kernel_name::(name), kernels::INDEXING)?; // SAFETY: Set later by running the kernel. let params = (ids, &src, dst, left_sz, src_dim_sz, dst_dim_sz, right_sz); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()?; Ok(()) } } struct Conv1D<'a>(&'a crate::conv::ParamsConv1D); impl<'a> Map2 for Conv1D<'a> { fn f( &self, inp: &CudaSlice, inp_l: &Layout, k: &CudaSlice, k_l: &Layout, dev: &CudaDevice, ) -> Result> { // Kernel shape: (c_out, c_in_k, k_size) // Input shape: (b_size, c_in, l_in) or (c_in, l_in) let p = &self.0; let inp = &inp.slice(inp_l.start_offset()..); let k = &k.slice(k_l.start_offset()..); let shape = inp_l.shape(); let dims = shape.dims(); let el = shape.elem_count(); let l_out = p.l_out(); let dst_el = p.c_out * l_out * p.b_size; let cfg = LaunchConfig::for_num_elems(dst_el as u32); let func = dev.get_or_load_func(&kernel_name::("conv1d"), kernels::CONV)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(dst_el) }.w()?; let ds = if dims.len() == 3 { [dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat() } else if dims.len() == 2 { [&[1], dims, &[1], inp_l.stride(), k_l.dims(), k_l.stride()].concat() } else { crate::bail!("unexpected input shape for conv1d {dims:?}") }; let ds = dev.htod_copy(ds).w()?; let params = ( el, l_out, p.stride, p.padding, p.dilation, &ds, inp, k, &out, ); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()?; Ok(out) } } struct Conv2D<'a>(&'a crate::conv::ParamsConv2D); impl<'a> Map2 for Conv2D<'a> { fn f( &self, inp: &CudaSlice, inp_l: &Layout, k: &CudaSlice, k_l: &Layout, dev: &CudaDevice, ) -> Result> { // Kernel shape: (c_out, c_in_k, h_k, w_k) // Input shape: (b_size, c_in, h_in, w_in) let p = &self.0; let (out_w, out_h) = (p.out_w(), p.out_h()); let dst_el = p.c_out * out_w * out_h * p.b_size; let inp = &inp.slice(inp_l.start_offset()..); let k = &k.slice(k_l.start_offset()..); let shape = inp_l.shape(); let dims = shape.dims(); let el = shape.elem_count(); // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(dst_el) }.w()?; let cfg = LaunchConfig::for_num_elems(dst_el as u32); let func = dev.get_or_load_func(&kernel_name::("conv2d"), kernels::CONV)?; let ds = if dims.len() == 4 { [dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat() } else { crate::bail!("unexpected input shape for conv2d {dims:?}") }; let ds = dev.htod_copy(ds).w()?; let params = ( el, out_w, out_h, p.stride, p.padding, p.dilation, &ds, inp, k, &out, ); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()?; Ok(out) } } struct Col2Im1D { stride: usize, } impl Map1 for Col2Im1D { fn f( &self, col: &CudaSlice, dev: &CudaDevice, l: &Layout, ) -> Result> { let (b_size, l_in, c_out, k_size) = l.shape().dims4()?; let stride = self.stride; let l_out = (l_in - 1) * stride + k_size; let dst_el = b_size * c_out * l_out; let mut im = unsafe { dev.alloc::(dst_el) }.w()?; let cfg = LaunchConfig::for_num_elems(dst_el as u32); let params = (dst_el, l_out, l_in, c_out, k_size, stride, col, &mut im); let func = dev.get_or_load_func(&kernel_name::("col2im1d"), kernels::CONV)?; unsafe { func.launch(cfg, params) }.w()?; Ok(im) } } struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D); impl<'a> Map2 for ConvTranspose1D<'a> { fn f( &self, inp: &CudaSlice, inp_l: &Layout, k: &CudaSlice, k_l: &Layout, dev: &CudaDevice, ) -> Result> { // Kernel shape: (c_in_k, c_out, l_k) // Input shape: (b_size, c_in, l_in) let p = &self.0; let l_out = p.l_out(); let dst_el = p.c_out * l_out * p.b_size; let inp = &inp.slice(inp_l.start_offset()..); let k = &k.slice(k_l.start_offset()..); let shape = inp_l.shape(); let dims = shape.dims(); let el = shape.elem_count(); // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(dst_el) }.w()?; let cfg = LaunchConfig::for_num_elems(dst_el as u32); let func = dev.get_or_load_func(&kernel_name::("conv_transpose1d"), kernels::CONV)?; let ds = if dims.len() == 3 { [dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat() } else { crate::bail!("unexpected input shape for conv_transpose1d {dims:?}") }; let ds = dev.htod_copy(ds).w()?; let params = ( el, l_out, p.stride, p.padding, p.output_padding, p.dilation, &ds, inp, k, &out, ); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()?; Ok(out) } } struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D); impl<'a> Map2 for ConvTranspose2D<'a> { fn f( &self, inp: &CudaSlice, inp_l: &Layout, k: &CudaSlice, k_l: &Layout, dev: &CudaDevice, ) -> Result> { // Kernel shape: (c_in_k, c_out, h_k, w_k) // Input shape: (b_size, c_in, h_in, w_in) let p = &self.0; let (out_w, out_h) = (p.out_w(), p.out_h()); let dst_el = p.c_out * out_w * out_h * p.b_size; let inp = &inp.slice(inp_l.start_offset()..); let k = &k.slice(k_l.start_offset()..); let shape = inp_l.shape(); let dims = shape.dims(); let el = shape.elem_count(); // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(dst_el) }.w()?; let cfg = LaunchConfig::for_num_elems(dst_el as u32); let func = dev.get_or_load_func(&kernel_name::("conv_transpose2d"), kernels::CONV)?; let ds = if dims.len() == 4 { [dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat() } else { crate::bail!("unexpected input shape for conv_transpose2d {dims:?}") }; let ds = dev.htod_copy(ds).w()?; let params = ( el, out_w, out_h, p.stride, p.padding, p.output_padding, p.dilation, &ds, inp, k, &out, ); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()?; Ok(out) } } enum PoolOp { Max, Avg, } struct Pool2D { w_k: usize, h_k: usize, w_stride: usize, h_stride: usize, op: PoolOp, } impl Map1 for Pool2D { fn f( &self, inp: &CudaSlice, dev: &CudaDevice, inp_l: &Layout, ) -> Result> { // Input shape: (b_size, c, h, w) let inp = &inp.slice(inp_l.start_offset()..); let shape = inp_l.shape(); let dims = shape.dims(); let ds = if dims.len() == 4 { [dims, inp_l.stride()].concat() } else { crate::bail!("unexpected input shape for pool {dims:?}") }; let el = shape.elem_count(); let out_w = (dims[2] - self.w_k) / self.w_stride + 1; let out_h = (dims[3] - self.h_k) / self.h_stride + 1; let dst_el = out_w * out_h * dims[0] * dims[1]; let cfg = LaunchConfig::for_num_elems(dst_el as u32); let kname = match self.op { PoolOp::Max => "max_pool2d", PoolOp::Avg => "avg_pool2d", }; let func = dev.get_or_load_func(&kernel_name::(kname), kernels::CONV)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(dst_el) }.w()?; let ds = dev.htod_copy(ds).w()?; let params = ( el, self.w_k, self.h_k, self.w_stride, self.h_stride, &ds, inp, &out, ); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()?; Ok(out) } } struct UpsampleNearest2D(usize, usize); impl Map1 for UpsampleNearest2D { fn f( &self, inp: &CudaSlice, dev: &CudaDevice, inp_l: &Layout, ) -> Result> { // Input shape: (b_size, c, h, w) let inp = &inp.slice(inp_l.start_offset()..); let shape = inp_l.shape(); let dims = shape.dims(); let ds = if dims.len() == 4 { [dims, inp_l.stride()].concat() } else { crate::bail!("unexpected input shape for upsample {dims:?}") }; let (out_w, out_h) = (self.0, self.1); let dst_el = out_w * out_h * dims[0] * dims[1]; let cfg = LaunchConfig::for_num_elems(dst_el as u32); let func = dev.get_or_load_func(&kernel_name::("upsample_nearest2d"), kernels::CONV)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(dst_el) }.w()?; let ds = dev.htod_copy(ds).w()?; let scale_w = dims[2] as f64 / out_w as f64; let scale_h = dims[3] as f64 / out_h as f64; let params = (out_w, out_h, scale_w, scale_h, &ds, inp, &out); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()?; Ok(out) } } struct WhereCond<'a>(&'a CudaStorage, &'a Layout); impl<'a> Map2 for WhereCond<'a> { fn f( &self, t: &CudaSlice, layout_t: &Layout, f: &CudaSlice, layout_f: &Layout, dev: &CudaDevice, ) -> Result> { let ids_l = &self.1; let (ids, name) = match &self.0.slice { CudaStorageSlice::U8(slice) => { let ptr = *slice.slice(ids_l.start_offset()..).device_ptr(); (ptr, "where_u8") } CudaStorageSlice::U32(slice) => { let ptr = *slice.slice(ids_l.start_offset()..).device_ptr(); (ptr, "where_u32") } CudaStorageSlice::I64(slice) => { let ptr = *slice.slice(ids_l.start_offset()..).device_ptr(); (ptr, "where_i64") } _ => Err(CudaError::UnexpectedDType { msg: "where conditions should be u8/u32/i64", expected: DType::U32, got: self.0.dtype(), }) .w()?, }; let shape = ids_l.shape(); let dims = shape.dims(); let el = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(el as u32); let ds = dev .htod_copy([dims, ids_l.stride(), layout_t.stride(), layout_f.stride()].concat()) .w()?; let t = &t.slice(layout_t.start_offset()..); let f = &f.slice(layout_f.start_offset()..); let func = dev.get_or_load_func(&kernel_name::(name), kernels::TERNARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el) }.w()?; let params = (el, dims.len(), &ds, ids, t, f, &out); // SAFETY: ffi unsafe { func.launch(cfg, params) }.w()?; Ok(out) } } impl Map2 for U { fn f( &self, lhs: &CudaSlice, lhs_l: &Layout, rhs: &CudaSlice, rhs_l: &Layout, dev: &CudaDevice, ) -> Result> { let shape = lhs_l.shape(); let dims = shape.dims(); let elem_count = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(elem_count as u32); let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() { SlicePtrOrNull::Null } else { SlicePtrOrNull::Ptr( dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat()) .w()?, ) }; let lhs = &lhs.slice(lhs_l.start_offset()..); let rhs = &rhs.slice(rhs_l.start_offset()..); let func = dev.get_or_load_func(&kernel_name::(U::KERNEL), kernels::BINARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(elem_count) }.w()?; let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out); // SAFETY: ffi unsafe { func.launch(cfg, params) }.w()?; Ok(out) } } struct Cmp(CmpOp); impl Map2Any for Cmp { fn f( &self, lhs: &CudaSlice, lhs_l: &Layout, rhs: &CudaSlice, rhs_l: &Layout, dev: &CudaDevice, ) -> Result { let shape = lhs_l.shape(); let dims = shape.dims(); let elem_count = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(elem_count as u32); let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() { SlicePtrOrNull::Null } else { SlicePtrOrNull::Ptr( dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat()) .w()?, ) }; let lhs = &lhs.slice(lhs_l.start_offset()..); let rhs = &rhs.slice(rhs_l.start_offset()..); let name = match self.0 { CmpOp::Eq => "eq", CmpOp::Ne => "ne", CmpOp::Lt => "lt", CmpOp::Le => "le", CmpOp::Gt => "gt", CmpOp::Ge => "ge", }; let func = dev.get_or_load_func(&kernel_name::(name), kernels::BINARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(elem_count) }.w()?; let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out); // SAFETY: ffi unsafe { func.launch(cfg, params) }.w()?; Ok(S::U8(out)) } } fn slice_src_and_dst<'a, T>( src: &'a CudaSlice, src_l: &Layout, dst: &'a mut CudaSlice, dst_offset: usize, ) -> ( cudarc::driver::CudaView<'a, T>, cudarc::driver::CudaViewMut<'a, T>, ) { let src_offset = src_l.start_offset(); let to_copy = dst .len() .saturating_sub(dst_offset) .min(src.len().saturating_sub(src_offset)); let src = src.slice(src_offset..src_offset + to_copy); let dst = dst.slice_mut(dst_offset..dst_offset + to_copy); (src, dst) } #[derive(Debug)] pub struct CudaStorage { pub slice: CudaStorageSlice, pub device: CudaDevice, } pub trait CudaDType: Sized { fn as_cuda_slice(s: &CudaStorage) -> Result<&CudaSlice>; fn wrap_cuda_slice(s: CudaSlice, dev: CudaDevice) -> CudaStorage; } macro_rules! cuda_dtype { ($ty:ty, $dtype:ident) => { impl CudaDType for $ty { fn as_cuda_slice(s: &CudaStorage) -> Result<&CudaSlice> { match &s.slice { CudaStorageSlice::$dtype(data) => Ok(&data), _ => Err(crate::Error::UnexpectedDType { expected: DType::$dtype, got: s.dtype(), msg: "unexpected dtype", } .bt()), } } fn wrap_cuda_slice(slice: CudaSlice, device: CudaDevice) -> CudaStorage { let slice = CudaStorageSlice::$dtype(slice); CudaStorage { slice, device } } } }; } cuda_dtype!(u8, U8); cuda_dtype!(u32, U32); cuda_dtype!(i64, I64); cuda_dtype!(f16, F16); cuda_dtype!(bf16, BF16); cuda_dtype!(f32, F32); cuda_dtype!(f64, F64); impl CudaStorage { pub fn wrap_cuda_slice(slice: CudaSlice, device: CudaDevice) -> CudaStorage { T::wrap_cuda_slice(slice, device) } pub fn as_cuda_slice(&self) -> Result<&CudaSlice> { T::as_cuda_slice(self) } } fn gemm_config( alpha: T, beta: T, (b, m, n, k): (usize, usize, usize, usize), lhs_l: &Layout, rhs_l: &Layout, ) -> Result> { // https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm use cudarc::cublas::sys::cublasOperation_t; let lhs_stride = lhs_l.stride(); let rhs_stride = rhs_l.stride(); let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; // The a tensor has dims batching, k, n (rhs) // We also allow for the case where the stride on the minor dimension is not as expected but // there is a single element. let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { (n as i32, cublasOperation_t::CUBLAS_OP_N) } else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) { (k as i32, cublasOperation_t::CUBLAS_OP_T) } else { Err(CudaError::MatMulNonContiguous { lhs_stride: lhs_l.clone(), rhs_stride: rhs_l.clone(), mnk: (m, n, k), })? }; // The b tensor has dims batching, m, k (lhs) // We also allow for the case where the stride on the minor dimension is not as expected but // there is a single element. let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { (k as i32, cublasOperation_t::CUBLAS_OP_N) } else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) { (m as i32, cublasOperation_t::CUBLAS_OP_T) } else { Err(CudaError::MatMulNonContiguous { lhs_stride: lhs_l.clone(), rhs_stride: rhs_l.clone(), mnk: (m, n, k), })? }; // The setup below was copied from: // https://github.com/lebedov/scikit-cuda/blob/7e7300474286019c917a6c8a4bca59405c64fbce/tests/test_cublas.py#L531 let gemm = GemmConfig { alpha, beta, m: n as i32, n: m as i32, k: k as i32, lda, ldb, ldc: n as i32, transa, transb, }; let stride_b: usize = match lhs_stride[..lhs_stride.len() - 2] { [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride, [_, stride] if lhs_l.dims()[0] == 1 => stride, [stride, _] if lhs_l.dims()[1] == 1 => stride, [stride] => stride, [] => m * k, _ => Err(CudaError::MatMulNonContiguous { lhs_stride: lhs_l.clone(), rhs_stride: rhs_l.clone(), mnk: (m, n, k), })?, }; let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] { [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride, [_, stride] if rhs_l.dims()[0] == 1 => stride, [stride, _] if rhs_l.dims()[1] == 1 => stride, [stride] => stride, [] => n * k, _ => Err(CudaError::MatMulNonContiguous { lhs_stride: lhs_l.clone(), rhs_stride: rhs_l.clone(), mnk: (m, n, k), })?, }; Ok(StridedBatchedConfig { batch_size: b as i32, gemm, stride_a: stride_a as i64, stride_b: stride_b as i64, stride_c: (m * n) as i64, }) } impl BackendStorage for CudaStorage { type Device = CudaDevice; fn try_clone(&self, layout: &Layout) -> Result { let slice = Clone.map(&self.slice, self.device(), layout)?; let device = self.device.clone(); Ok(Self { slice, device }) } fn dtype(&self) -> DType { match self.slice { CudaStorageSlice::U8(_) => DType::U8, CudaStorageSlice::U32(_) => DType::U32, CudaStorageSlice::I64(_) => DType::I64, CudaStorageSlice::BF16(_) => DType::BF16, CudaStorageSlice::F16(_) => DType::F16, CudaStorageSlice::F32(_) => DType::F32, CudaStorageSlice::F64(_) => DType::F64, } } fn device(&self) -> &CudaDevice { &self.device } fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result { let shape = layout.shape(); let dims = shape.dims(); let el = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(el as u32); let dev = self.device(); let ds = SlicePtrOrNull::params_from_layout(dev, layout)?; let start_o = layout.start_offset(); // This returns an i64 rather than a &i64, this is useful to get around some temporary // lifetime issue and is safe as long as self.slice does not go out of scope before inp // is used. let inp = match &self.slice { CudaStorageSlice::U8(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::U32(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::I64(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::BF16(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::F16(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::F32(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::F64(inp) => *inp.slice(start_o..).device_ptr(), }; let inp = &inp; let kernel_name = format!("cast_{}_{}", self.dtype().as_str(), dtype.as_str()); let func = dev.get_or_load_func(&kernel_name, kernels::CAST)?; let slice = match dtype { DType::U8 => { let out = unsafe { dev.alloc::(el) }.w()?; let params = (el, dims.len(), &ds, *inp, &out); unsafe { func.launch(cfg, params) }.w()?; CudaStorageSlice::U8(out) } DType::U32 => { let out = unsafe { dev.alloc::(el) }.w()?; let params = (el, dims.len(), &ds, *inp, &out); unsafe { func.launch(cfg, params) }.w()?; CudaStorageSlice::U32(out) } DType::I64 => { let out = unsafe { dev.alloc::(el) }.w()?; let params = (el, dims.len(), &ds, *inp, &out); unsafe { func.launch(cfg, params) }.w()?; CudaStorageSlice::I64(out) } DType::BF16 => { let out = unsafe { dev.alloc::(el) }.w()?; let params = (el, dims.len(), &ds, *inp, &out); unsafe { func.launch(cfg, params) }.w()?; CudaStorageSlice::BF16(out) } DType::F16 => { let out = unsafe { dev.alloc::(el) }.w()?; let params = (el, dims.len(), &ds, *inp, &out); unsafe { func.launch(cfg, params) }.w()?; CudaStorageSlice::F16(out) } DType::F32 => { let out = unsafe { dev.alloc::(el) }.w()?; let params = (el, dims.len(), &ds, *inp, &out); unsafe { func.launch(cfg, params) }.w()?; CudaStorageSlice::F32(out) } DType::F64 => { let out = unsafe { dev.alloc::(el) }.w()?; let params = (el, dims.len(), &ds, *inp, &out); unsafe { func.launch(cfg, params) }.w()?; CudaStorageSlice::F64(out) } }; Ok(Self { slice, device: dev.clone(), }) } fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result { let device = self.device().clone(); let slice = Affine(mul, add).map(&self.slice, &device, layout)?; Ok(Self { slice, device }) } fn powf(&self, layout: &Layout, e: f64) -> Result { let device = self.device().clone(); let slice = Powf(e).map(&self.slice, &device, layout)?; Ok(Self { slice, device }) } fn elu(&self, layout: &Layout, alpha: f64) -> Result { let device = self.device().clone(); let slice = Elu(alpha).map(&self.slice, &device, layout)?; Ok(Self { slice, device }) } fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result { let device = self.device().clone(); let slice = FastReduce(sum_dims, op).map(&self.slice, &device, layout)?; Ok(Self { slice, device }) } fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result { let device = self.device().clone(); let slice = Cmp(op).map(&self.slice, lhs_l, &rhs.slice, rhs_l, &device)?; Ok(Self { slice, device }) } fn unary_impl(&self, layout: &Layout) -> Result { let device = self.device().clone(); let slice = U::V.map(&self.slice, &device, layout)?; Ok(Self { slice, device }) } fn binary_impl( &self, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout, ) -> Result { let device = self.device().clone(); let slice = B::V.map(&self.slice, lhs_l, &rhs.slice, rhs_l, &device)?; Ok(Self { slice, device }) } fn to_cpu_storage(&self) -> Result { match &self.slice { CudaStorageSlice::U8(slice) => { let dev = slice.device(); let cpu_storage = dev.dtoh_sync_copy(slice).w()?; Ok(CpuStorage::U8(cpu_storage)) } CudaStorageSlice::U32(slice) => { let dev = slice.device(); let cpu_storage = dev.dtoh_sync_copy(slice).w()?; Ok(CpuStorage::U32(cpu_storage)) } CudaStorageSlice::I64(slice) => { let dev = slice.device(); let cpu_storage = dev.dtoh_sync_copy(slice).w()?; Ok(CpuStorage::I64(cpu_storage)) } CudaStorageSlice::BF16(slice) => { let dev = slice.device(); let cpu_storage = dev.dtoh_sync_copy(slice).w()?; Ok(CpuStorage::BF16(cpu_storage)) } CudaStorageSlice::F16(slice) => { let dev = slice.device(); let cpu_storage = dev.dtoh_sync_copy(slice).w()?; Ok(CpuStorage::F16(cpu_storage)) } CudaStorageSlice::F32(slice) => { let dev = slice.device(); let cpu_storage = dev.dtoh_sync_copy(slice).w()?; Ok(CpuStorage::F32(cpu_storage)) } CudaStorageSlice::F64(slice) => { let dev = slice.device(); let cpu_storage = dev.dtoh_sync_copy(slice).w()?; Ok(CpuStorage::F64(cpu_storage)) } } } fn where_cond( &self, layout: &Layout, t: &Self, t_l: &Layout, f: &Self, f_l: &Layout, ) -> Result { let device = self.device().clone(); let slice = WhereCond(self, layout).map(&t.slice, t_l, &f.slice, f_l, &device)?; Ok(Self { slice, device }) } fn conv1d( &self, l: &Layout, kernel: &Self, kernel_l: &Layout, params: &crate::conv::ParamsConv1D, ) -> Result { const USE_IM2COL_CONV1D: bool = true; let device = self.device().clone(); if !USE_IM2COL_CONV1D { let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?; return Ok(Self { slice, device }); } let col = Im2Col1D { l_k: params.k_size, stride: params.stride, dilation: params.dilation, padding: params.padding, } .map(&self.slice, &device, l)?; let col = Self { slice: col, device }; let l_out = params.l_out(); let b = params.b_size; let n = params.c_out; let k = params.k_size * params.c_in; let m = l_out; let col_l = Layout::contiguous((b, m, k)); let res = if kernel_l.is_contiguous() { let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? .broadcast_as((b, k, n))?; col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? } else { // Make the kernel contiguous if not already the case. let mut kernel_c = unsafe { self.device() .alloc_uninit(kernel_l.shape(), kernel.dtype())? }; kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?; let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? .broadcast_as((b, k, n))?; col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? }; let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?; let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? }; res.copy_strided_src(&mut res_t, 0, &res_l)?; Ok(res_t) } fn conv_transpose1d( &self, l: &Layout, kernel: &Self, kernel_l: &Layout, params: &crate::conv::ParamsConvTranspose1D, ) -> Result { const USE_COL2IM_CONV1D_TR: bool = true; let device = self.device().clone(); let can_use_col2im = kernel_l.is_contiguous() && params.dilation == 1 && params.padding == 0 && params.output_padding == 0; let slice = if USE_COL2IM_CONV1D_TR && can_use_col2im { let (b_size, c_in, l_in) = l.shape().dims3()?; let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?; if !kernel_l.is_contiguous() { crate::bail!( "convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}" ) } if c_in != c_in2 { crate::bail!( "convtr1d: shape mismatch on c_in {:?} {:?}", l.shape(), kernel_l.shape() ) } let col = { // This merges the last two dimensions of the kernel together. let kernel_l_mm = Layout::new( (b_size, c_in, k_size * c_out).into(), vec![0, k_size * c_out, 1], kernel_l.start_offset(), ); self.matmul( kernel, ( b_size, /* m */ l_in, /* n */ c_out * k_size, /* k */ c_in, ), &l.transpose(1, 2)?, &kernel_l_mm, )? }; let col_l = Layout::contiguous((b_size, l_in, c_out, k_size)); Col2Im1D { stride: params.stride, } .map(&col.slice, &device, &col_l)? } else { ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)? }; Ok(Self { slice, device }) } #[cfg(not(feature = "cudnn"))] fn conv2d( &self, l: &Layout, kernel: &Self, kernel_l: &Layout, params: &crate::conv::ParamsConv2D, ) -> Result { const USE_IM2COL_CONV2D: bool = true; let device = self.device().clone(); if !USE_IM2COL_CONV2D { let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?; return Ok(Self { slice, device }); } let col = Im2Col { h_k: params.k_h, w_k: params.k_w, stride: params.stride, dilation: params.dilation, padding: params.padding, } .map(&self.slice, &device, l)?; let col = Self { slice: col, device }; let h_out = params.out_h(); let w_out = params.out_w(); let b = params.b_size; let n = params.c_out; let k = params.k_h * params.k_w * params.c_in; let m = h_out * w_out; let col_l = Layout::contiguous((b, m, k)); let res = if kernel_l.is_contiguous() { let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? .broadcast_as((b, k, n))?; col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? } else { // Make the kernel contiguous if not already the case. let mut kernel_c = unsafe { self.device() .alloc_uninit(kernel_l.shape(), kernel.dtype())? }; kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?; let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset()) .transpose(1, 2)? .broadcast_as((b, k, n))?; col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)? }; let res_l = Layout::contiguous((b, h_out, w_out, n)) .transpose(1, 2)? .transpose(1, 3)?; let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? }; res.copy_strided_src(&mut res_t, 0, &res_l)?; Ok(res_t) } #[cfg(feature = "cudnn")] fn conv2d( &self, inp_l: &Layout, kernel: &Self, kernel_l: &Layout, params: &crate::conv::ParamsConv2D, ) -> Result { let device = self.device().clone(); if !kernel_l.is_contiguous() { let slice = Conv2D(params).map(&self.slice, inp_l, &kernel.slice, kernel_l, &device)?; return Ok(Self { slice, device }); } let (out_w, out_h) = (params.out_w(), params.out_h()); let dst_el = params.c_out * out_w * out_h * params.b_size; let slice = match (&self.slice, &kernel.slice) { (S::U8(inp), S::U8(k)) => { let inp = &inp.slice(inp_l.start_offset()..); let k = &k.slice(kernel_l.start_offset()..); let mut out = unsafe { device.alloc::(dst_el) }.w()?; crate::cudnn::launch_conv2d::(inp, inp_l, k, &mut out, params, &device) .map_err(crate::Error::wrap)?; S::U8(out) } (S::BF16(inp), S::BF16(k)) => { let inp = &inp.slice(inp_l.start_offset()..); let k = &k.slice(kernel_l.start_offset()..); let mut out = unsafe { device.alloc::(dst_el) }.w()?; // Only PSEUDO_BFLOAT16_CONFIG is supported in cudnn, there is no "true bfloat16" // version. // https://docs.nvidia.com/deeplearning/cudnn/latest/api/cudnn-cnn-library.html#id88 crate::cudnn::launch_conv2d::(inp, inp_l, k, &mut out, params, &device) .map_err(crate::Error::wrap)?; S::BF16(out) } (S::F16(inp), S::F16(k)) => { let inp = &inp.slice(inp_l.start_offset()..); let k = &k.slice(kernel_l.start_offset()..); let mut out = unsafe { device.alloc::(dst_el) }.w()?; crate::cudnn::launch_conv2d::(inp, inp_l, k, &mut out, params, &device) .map_err(crate::Error::wrap)?; S::F16(out) } (S::F32(inp), S::F32(k)) => { let inp = &inp.slice(inp_l.start_offset()..); let k = &k.slice(kernel_l.start_offset()..); let mut out = unsafe { device.alloc::(dst_el) }.w()?; crate::cudnn::launch_conv2d::(inp, inp_l, k, &mut out, params, &device) .map_err(crate::Error::wrap)?; S::F32(out) } (S::F64(inp), S::F64(k)) => { let inp = &inp.slice(inp_l.start_offset()..); let k = &k.slice(kernel_l.start_offset()..); let mut out = unsafe { device.alloc::(dst_el) }.w()?; crate::cudnn::launch_conv2d::(inp, inp_l, k, &mut out, params, &device) .map_err(crate::Error::wrap)?; S::F64(out) } (S::U32(_), S::U32(_)) => Err(CudaError::InternalError("conv2d does not support u32"))?, (S::I64(_), S::I64(_)) => Err(CudaError::InternalError("conv2d does not support i64"))?, _ => Err(CudaError::InternalError("dtype mismatch in conv2d"))?, }; Ok(Self { slice, device }) } fn conv_transpose2d( &self, l: &Layout, kernel: &Self, kernel_l: &Layout, params: &crate::conv::ParamsConvTranspose2D, ) -> Result { let device = self.device().clone(); let slice = ConvTranspose2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?; Ok(Self { slice, device }) } fn avg_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize)) -> Result { let device = self.device().clone(); let slice = Pool2D { w_k: k.0, h_k: k.1, w_stride: stride.0, h_stride: stride.1, op: PoolOp::Avg, } .map(&self.slice, &device, l)?; Ok(Self { slice, device }) } fn max_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize)) -> Result { let device = self.device().clone(); let slice = Pool2D { w_k: k.0, h_k: k.1, w_stride: stride.0, h_stride: stride.1, op: PoolOp::Max, } .map(&self.slice, &device, l)?; Ok(Self { slice, device }) } fn upsample_nearest1d(&self, _: &Layout, _out_sz: usize) -> Result { crate::bail!("upsample-nearest1d is not supported on cuda") } fn upsample_nearest2d(&self, l: &Layout, out_w: usize, out_h: usize) -> Result { let device = self.device().clone(); let slice = UpsampleNearest2D(out_w, out_h).map(&self.slice, &device, l)?; Ok(Self { slice, device }) } fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result { let device = self.device().clone(); let slice = IndexSelect(ids, ids_l, dim).map(&self.slice, &device, l)?; Ok(Self { slice, device }) } fn gather(&self, l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result { let device = self.device().clone(); let slice = Gather(ids, ids_l, dim).map(&self.slice, &device, l)?; Ok(Self { slice, device }) } fn scatter_add( &self, l: &Layout, ids: &Self, ids_l: &Layout, src: &Self, src_l: &Layout, dim: usize, ) -> Result { let device = self.device().clone(); let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? }; self.copy_strided_src(&mut acc, 0, l)?; ScatterAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?; Ok(acc) } fn index_add( &self, l: &Layout, ids: &Self, ids_l: &Layout, src: &Self, src_l: &Layout, dim: usize, ) -> Result { let device = self.device().clone(); let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? }; self.copy_strided_src(&mut acc, 0, l)?; IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?; Ok(acc) } fn matmul( &self, rhs: &Self, (b, m, n, k): (usize, usize, usize, usize), lhs_l: &Layout, rhs_l: &Layout, ) -> Result { let elem_count = b * m * n; let dev = &self.device; let slice = match (&self.slice, &rhs.slice) { (CudaStorageSlice::BF16(lhs), CudaStorageSlice::BF16(rhs)) => { let lhs = &lhs.slice(lhs_l.start_offset()..); let rhs = &rhs.slice(rhs_l.start_offset()..); let cfg = gemm_config(bf16::ONE, bf16::ZERO, (b, m, n, k), lhs_l, rhs_l)?; let mut out = unsafe { dev.alloc::(elem_count) }.w()?; unsafe { gemm_strided_batched_bf16(&self.device.blas, cfg, rhs, lhs, &mut out) } .w()?; CudaStorageSlice::BF16(out) } (CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => { let lhs = &lhs.slice(lhs_l.start_offset()..); let rhs = &rhs.slice(rhs_l.start_offset()..); let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_l, rhs_l)?; let mut out = unsafe { dev.alloc::(elem_count) }.w()?; unsafe { gemm_strided_batched_f16(&self.device.blas, cfg, rhs, lhs, &mut out) } .w()?; CudaStorageSlice::F16(out) } (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => { let lhs = &lhs.slice(lhs_l.start_offset()..); let rhs = &rhs.slice(rhs_l.start_offset()..); let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?; let mut out = unsafe { dev.alloc::(elem_count) }.w()?; unsafe { gemm_strided_batched_f32(&self.device.blas, cfg, rhs, lhs, &mut out) } .w()?; CudaStorageSlice::F32(out) } (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => { let lhs = &lhs.slice(lhs_l.start_offset()..); let rhs = &rhs.slice(rhs_l.start_offset()..); let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?; let mut out = unsafe { dev.alloc::(elem_count) }.w()?; unsafe { self.device .blas .gemm_strided_batched(cfg, rhs, lhs, &mut out) } .w()?; CudaStorageSlice::F64(out) } _ => Err(CudaError::InternalError("dtype mismatch in matmul op"))?, }; let device = dev.clone(); Ok(Self { slice, device }) } fn copy2d( &self, dst: &mut Self, d1: usize, d2: usize, src_s: usize, dst_s: usize, src_o: usize, dst_o: usize, ) -> Result<()> { let dev = &self.device; let d1 = d1 as u32; let d2 = d2 as u32; // Nothing to copy so we exit early to avoid launching a kernel and some potential invalid // argument with a null pointer. if d1 == 0 || d2 == 0 { return Ok(()); } let dst_s = dst_s as u32; let src_s = src_s as u32; let (src, dst, kname) = match (&self.slice, &mut dst.slice) { (S::U8(s), S::U8(d)) => ( *s.slice(src_o..).device_ptr(), *d.slice(dst_o..).device_ptr(), "copy2d_u8", ), (S::U32(s), S::U32(d)) => ( *s.slice(src_o..).device_ptr(), *d.slice(dst_o..).device_ptr(), "copy2d_u32", ), (S::I64(s), S::I64(d)) => ( *s.slice(src_o..).device_ptr(), *d.slice(dst_o..).device_ptr(), "copy2d_i64", ), (S::BF16(s), S::BF16(d)) => ( *s.slice(src_o..).device_ptr(), *d.slice(dst_o..).device_ptr(), "copy2d_bf16", ), (S::F16(s), S::F16(d)) => ( *s.slice(src_o..).device_ptr(), *d.slice(dst_o..).device_ptr(), "copy2d_f16", ), (S::F32(s), S::F32(d)) => ( *s.slice(src_o..).device_ptr(), *d.slice(dst_o..).device_ptr(), "copy2d_f32", ), (S::F64(s), S::F64(d)) => ( *s.slice(src_o..).device_ptr(), *d.slice(dst_o..).device_ptr(), "copy2d_f64", ), _ => Err(CudaError::InternalError("dtype mismatch in copy2d"))?, }; let func = dev.get_or_load_func(kname, kernels::FILL)?; let cfg = LaunchConfig::for_num_elems(d1 * d2); let params = (src, dst, d1, d2, src_s, dst_s); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()?; Ok(()) } fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { let src_shape = src_l.shape(); let dims = src_shape.dims(); let el_count = src_shape.elem_count(); if el_count == 0 { return Ok(()); } let cfg = LaunchConfig::for_num_elems(el_count as u32); let dev = &self.device; let ds = SlicePtrOrNull::params_from_layout(dev, src_l)?; match (&self.slice, &mut dst.slice) { (CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => { let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); if src_l.is_contiguous() { dev.dtod_copy(&src, &mut dst).w()? } else { let func = dev.get_or_load_func("ucopy_bf16", kernels::UNARY)?; // SAFETY: Set later by running the kernel. let params = (el_count, dims.len(), &ds, &src, &mut dst); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()? } } (CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => { let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); if src_l.is_contiguous() { dev.dtod_copy(&src, &mut dst).w()? } else { let func = dev.get_or_load_func("ucopy_f16", kernels::UNARY)?; // SAFETY: Set later by running the kernel. let params = (el_count, dims.len(), &ds, &src, &mut dst); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()? } } (CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => { let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); if src_l.is_contiguous() { dev.dtod_copy(&src, &mut dst).w()? } else { let func = dev.get_or_load_func("ucopy_f32", kernels::UNARY)?; // SAFETY: Set later by running the kernel. let params = (el_count, dims.len(), &ds, &src, &mut dst); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()? } } (CudaStorageSlice::U8(src), CudaStorageSlice::U8(dst)) => { let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); if src_l.is_contiguous() { dev.dtod_copy(&src, &mut dst).w()? } else { let func = dev.get_or_load_func("ucopy_u8", kernels::UNARY)?; // SAFETY: Set later by running the kernel. let params = (el_count, dims.len(), &ds, &src, &mut dst); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()? } } (CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => { let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); if src_l.is_contiguous() { dev.dtod_copy(&src, &mut dst).w()? } else { let func = dev.get_or_load_func("ucopy_u32", kernels::UNARY)?; // SAFETY: Set later by running the kernel. let params = (el_count, dims.len(), &ds, &src, &mut dst); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()? } } (CudaStorageSlice::I64(src), CudaStorageSlice::I64(dst)) => { let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); if src_l.is_contiguous() { dev.dtod_copy(&src, &mut dst).w()? } else { let func = dev.get_or_load_func("ucopy_i64", kernels::UNARY)?; // SAFETY: Set later by running the kernel. let params = (el_count, dims.len(), &ds, &src, &mut dst); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()? } } (CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => { let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); if src_l.is_contiguous() { dev.dtod_copy(&src, &mut dst).w()? } else { let func = dev.get_or_load_func("ucopy_f64", kernels::UNARY)?; // SAFETY: Set later by running the kernel. let params = (el_count, dims.len(), &ds, &src, &mut dst); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()?; } } _ => Err(CudaError::InternalError( "dtype mismatch in copy_strided op", ))?, } Ok(()) } } // Default for the reduced precision setting is false, similar to pytorch. // https://github.com/pytorch/pytorch/issues/123157 static MM_F16_REDUCED_PRECISION: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false); static MM_BF16_REDUCED_PRECISION: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false); static MM_F32_REDUCED_PRECISION: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false); /// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are /// allowed with f32 GEMMs. pub fn gemm_reduced_precision_f32() -> bool { MM_F32_REDUCED_PRECISION.load(std::sync::atomic::Ordering::Relaxed) } /// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are /// allowed with f32 GEMMs. pub fn set_gemm_reduced_precision_f32(b: bool) { MM_F32_REDUCED_PRECISION.store(b, std::sync::atomic::Ordering::Relaxed) } /// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are /// allowed with f16 GEMMs. pub fn gemm_reduced_precision_f16() -> bool { MM_F16_REDUCED_PRECISION.load(std::sync::atomic::Ordering::Relaxed) } /// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are /// allowed with f16 GEMMs. pub fn set_gemm_reduced_precision_f16(b: bool) { MM_F16_REDUCED_PRECISION.store(b, std::sync::atomic::Ordering::Relaxed) } /// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are /// allowed with bf16 GEMMs. pub fn gemm_reduced_precision_bf16() -> bool { MM_BF16_REDUCED_PRECISION.load(std::sync::atomic::Ordering::Relaxed) } /// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are /// allowed with bf16 GEMMs. pub fn set_gemm_reduced_precision_bf16(b: bool) { MM_BF16_REDUCED_PRECISION.store(b, std::sync::atomic::Ordering::Relaxed) } unsafe fn gemm_strided_batched_f32( cublas: &cudarc::cublas::CudaBlas, cfg: StridedBatchedConfig, a: &cudarc::driver::CudaView, b: &cudarc::driver::CudaView, c: &mut CudaSlice, ) -> std::result::Result<(), cudarc::cublas::result::CublasError> { use cudarc::cublas::sys; use cudarc::driver::DevicePtrMut; let compute_type = if gemm_reduced_precision_f32() { sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_TF32 } else { sys::cublasComputeType_t::CUBLAS_COMPUTE_32F }; let alpha = &cfg.gemm.alpha as *const f32 as *const _; let beta = &cfg.gemm.beta as *const f32 as *const _; cudarc::cublas::result::gemm_strided_batched_ex( *cublas.handle(), cfg.gemm.transa, cfg.gemm.transb, cfg.gemm.m, cfg.gemm.n, cfg.gemm.k, alpha, *a.device_ptr() as *const _, sys::cudaDataType_t::CUDA_R_32F, cfg.gemm.lda, cfg.stride_a, *b.device_ptr() as *const _, sys::cudaDataType_t::CUDA_R_32F, cfg.gemm.ldb, cfg.stride_b, beta, *c.device_ptr_mut() as *mut _, sys::cudaDataType_t::CUDA_R_32F, cfg.gemm.ldc, cfg.stride_c, cfg.batch_size, compute_type, sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP, ) } unsafe fn gemm_strided_batched_f16( cublas: &cudarc::cublas::CudaBlas, cfg: StridedBatchedConfig, a: &cudarc::driver::CudaView, b: &cudarc::driver::CudaView, c: &mut CudaSlice, ) -> std::result::Result<(), cudarc::cublas::result::CublasError> { use cudarc::cublas::sys; use cudarc::driver::DevicePtrMut; let alpha = cfg.gemm.alpha; let beta = cfg.gemm.beta; let alpha_f32: f32 = cfg.gemm.alpha.to_f32(); let beta_f32: f32 = cfg.gemm.beta.to_f32(); let (compute_type, alpha, beta) = if gemm_reduced_precision_f16() { ( sys::cublasComputeType_t::CUBLAS_COMPUTE_16F, (&alpha) as *const f16 as *const _, (&beta) as *const f16 as *const _, ) } else { ( sys::cublasComputeType_t::CUBLAS_COMPUTE_32F, (&alpha_f32) as *const f32 as *const _, (&beta_f32) as *const f32 as *const _, ) }; cudarc::cublas::result::gemm_strided_batched_ex( *cublas.handle(), cfg.gemm.transa, cfg.gemm.transb, cfg.gemm.m, cfg.gemm.n, cfg.gemm.k, alpha, *a.device_ptr() as *const _, sys::cudaDataType_t::CUDA_R_16F, cfg.gemm.lda, cfg.stride_a, *b.device_ptr() as *const _, sys::cudaDataType_t::CUDA_R_16F, cfg.gemm.ldb, cfg.stride_b, beta, *c.device_ptr_mut() as *mut _, sys::cudaDataType_t::CUDA_R_16F, cfg.gemm.ldc, cfg.stride_c, cfg.batch_size, compute_type, sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP, ) } unsafe fn gemm_strided_batched_bf16( cublas: &cudarc::cublas::CudaBlas, cfg: StridedBatchedConfig, a: &cudarc::driver::CudaView, b: &cudarc::driver::CudaView, c: &mut CudaSlice, ) -> std::result::Result<(), cudarc::cublas::result::CublasError> { use cudarc::cublas::sys; use cudarc::driver::DevicePtrMut; let alpha_f32: f32 = cfg.gemm.alpha.to_f32(); let beta_f32: f32 = cfg.gemm.beta.to_f32(); // The type for alpha and beta depends on the computeType. // https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmstridedbatchedex let (compute_type, alpha, beta) = if gemm_reduced_precision_bf16() { ( sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_16BF, (&alpha_f32) as *const f32 as *const _, (&beta_f32) as *const f32 as *const _, ) } else { ( sys::cublasComputeType_t::CUBLAS_COMPUTE_32F, (&alpha_f32) as *const f32 as *const _, (&beta_f32) as *const f32 as *const _, ) }; cudarc::cublas::result::gemm_strided_batched_ex( *cublas.handle(), cfg.gemm.transa, cfg.gemm.transb, cfg.gemm.m, cfg.gemm.n, cfg.gemm.k, alpha, *a.device_ptr() as *const _, sys::cudaDataType_t::CUDA_R_16BF, cfg.gemm.lda, cfg.stride_a, *b.device_ptr() as *const _, sys::cudaDataType_t::CUDA_R_16BF, cfg.gemm.ldb, cfg.stride_b, beta, *c.device_ptr_mut() as *mut _, sys::cudaDataType_t::CUDA_R_16BF, cfg.gemm.ldc, cfg.stride_c, cfg.batch_size, compute_type, sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP, ) }