 candle-core/src/backend.rs             |  13
 candle-core/src/cpu_backend.rs         |  62
 candle-core/src/cuda_backend.rs        |  61
 candle-core/src/dummy_cuda_backend.rs  |  13
 candle-core/src/dummy_metal_backend.rs |  13
 candle-core/src/lib.rs                 |   1
 candle-core/src/metal_backend.rs       |  65
 candle-core/src/storage.rs             |  28
 candle-core/src/tensor.rs              | 148
 candle-core/src/tensor_cat.rs          | 240
 candle-core/tests/conv_tests.rs        | 128
 candle-core/tests/grad_tests.rs        |  18
 candle-core/tests/pool_tests.rs        |   9
 candle-core/tests/tensor_tests.rs      |  25
 candle-kernels/src/fill.cu             |  30
 candle-metal-kernels/src/affine.metal  |   2
 candle-metal-kernels/src/lib.rs        |  50
 candle-metal-kernels/src/unary.metal   |  27
 candle-nn/examples/cpu_benchmarks.rs   |  19
 19 files changed, 744 insertions(+), 208 deletions(-)
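The change set replaces the transpose-based fallback in `Tensor::cat` with a contiguous fast path: a new `copy2d` primitive on `BackendStorage` (implemented below for the CPU, CUDA, and Metal backends) copies `d1` rows of `d2` elements between buffers with independent row strides, and a new `cat_contiguous` helper issues one such copy per input tensor. A minimal sketch of the user-visible behavior, mirroring the doc test carried over into `tensor_cat.rs`:

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;

    // dim == 0 still goes through cat0 / copy_strided_src.
    let c = Tensor::cat(&[&a, &b], 0)?;
    assert_eq!(c.shape().dims(), &[4, 3]);

    // dim != 0 with all-contiguous inputs now takes cat_contiguous,
    // which issues one copy2d per input instead of transposing twice.
    let c = Tensor::cat(&[&a, &b], 1)?;
    assert_eq!(c.shape().dims(), &[2, 6]);
    Ok(())
}
```

This dim != 0 case is exactly what the new `Cat` benchmark in `cpu_benchmarks.rs` exercises.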
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs index 2125af69..ea1ac1a9 100644 --- a/candle-core/src/backend.rs +++ b/candle-core/src/backend.rs @@ -98,6 +98,19 @@ pub trait BackendStorage: Sized { ) -> Result<Self>; fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>; + + #[allow(clippy::too_many_arguments)] + // Similar to cudaMemcpy2D, though values are in elements and not in bytes. + fn copy2d( + &self, + _: &mut Self, + _d1: usize, + _d2: usize, + _src_stride1: usize, + _dst_stride1: usize, + _src_offset: usize, + _dst_offset: usize, + ) -> Result<()>; } pub trait BackendDevice: Sized + std::fmt::Debug + Clone { diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 181fbb61..1504d5b8 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -1023,6 +1023,26 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> { } } +#[allow(clippy::too_many_arguments)] +fn copy2d_<T: Copy>( + src: &[T], + dst: &mut [T], + d1: usize, + d2: usize, + src_stride1: usize, + dst_stride1: usize, + src_offset: usize, + dst_offset: usize, +) { + for i1 in 0..d1 { + let dst_idx = i1 * dst_stride1 + dst_offset; + let src_idx = i1 * src_stride1 + src_offset; + let dst = &mut dst[dst_idx..dst_idx + d2]; + let src = &src[src_idx..src_idx + d2]; + dst.copy_from_slice(src) + } +} + fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) { match src_l.strided_blocks() { crate::StridedBlocks::SingleBlock { start_offset, len } => { @@ -2452,6 +2472,48 @@ impl BackendStorage for CpuStorage { } } + fn copy2d( + &self, + dst: &mut Self, + d1: usize, + d2: usize, + src_s: usize, + dst_s: usize, + src_o: usize, + dst_o: usize, + ) -> Result<()> { + match (self, dst) { + (Self::U8(src), Self::U8(dst)) => copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o), + (Self::U32(src), Self::U32(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } + (Self::I64(src), Self::I64(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } + (Self::BF16(src), Self::BF16(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } + (Self::F16(src), Self::F16(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } + (Self::F32(src), Self::F32(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } + (Self::F64(src), Self::F64(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } + (_, dst) => { + return Err(Error::DTypeMismatchBinaryOp { + lhs: self.dtype(), + rhs: dst.dtype(), + op: "copy2d", + } + .bt()); + } + } + Ok(()) + } + fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { match (self, dst) { (Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index b7756fa6..52d1b558 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -2145,6 +2145,67 @@ impl BackendStorage for CudaStorage { Ok(Self { slice, device }) } + fn copy2d( + &self, + dst: &mut Self, + d1: usize, + d2: usize, + src_s: usize, + dst_s: usize, + src_o: usize, + dst_o: usize, + ) -> Result<()> { + let dev = &self.device; + let d1 = d1 as u32; + let d2 = d2 as u32; + let dst_s = dst_s as u32; + let src_s = src_s as u32; + let (src, dst, kname) = match (&self.slice, &mut dst.slice) { + (S::U8(s), S::U8(d)) => ( + *s.slice(src_o..).device_ptr(), + *d.slice(dst_o..).device_ptr(), + 
"copy2d_u8", + ), + (S::U32(s), S::U32(d)) => ( + *s.slice(src_o..).device_ptr(), + *d.slice(dst_o..).device_ptr(), + "copy2d_u32", + ), + (S::I64(s), S::I64(d)) => ( + *s.slice(src_o..).device_ptr(), + *d.slice(dst_o..).device_ptr(), + "copy2d_i64", + ), + (S::BF16(s), S::BF16(d)) => ( + *s.slice(src_o..).device_ptr(), + *d.slice(dst_o..).device_ptr(), + "copy2d_bf16", + ), + (S::F16(s), S::F16(d)) => ( + *s.slice(src_o..).device_ptr(), + *d.slice(dst_o..).device_ptr(), + "copy2d_f16", + ), + (S::F32(s), S::F32(d)) => ( + *s.slice(src_o..).device_ptr(), + *d.slice(dst_o..).device_ptr(), + "copy2d_f32", + ), + (S::F64(s), S::F64(d)) => ( + *s.slice(src_o..).device_ptr(), + *d.slice(dst_o..).device_ptr(), + "copy2d_f64", + ), + _ => Err(CudaError::InternalError("dtype mismatch in copy2d"))?, + }; + let func = dev.get_or_load_func(kname, kernels::FILL)?; + let cfg = LaunchConfig::for_num_elems(d1 * d2); + let params = (src, dst, d1, d2, src_s, dst_s); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(()) + } + fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { let src_shape = src_l.shape(); let dims = src_shape.dims(); diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index 34c5d97f..43d55fa4 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -154,6 +154,19 @@ impl crate::backend::BackendStorage for CudaStorage { Err(Error::NotCompiledWithCudaSupport) } + fn copy2d( + &self, + _: &mut Self, + _: usize, + _: usize, + _: usize, + _: usize, + _: usize, + _: usize, + ) -> Result<()> { + Err(Error::NotCompiledWithCudaSupport) + } + fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> { Err(Error::NotCompiledWithCudaSupport) } diff --git a/candle-core/src/dummy_metal_backend.rs b/candle-core/src/dummy_metal_backend.rs index e9d92331..791ec153 100644 --- a/candle-core/src/dummy_metal_backend.rs +++ b/candle-core/src/dummy_metal_backend.rs @@ -166,6 +166,19 @@ impl crate::backend::BackendStorage for MetalStorage { Err(Error::NotCompiledWithMetalSupport) } + fn copy2d( + &self, + _: &mut Self, + _: usize, + _: usize, + _: usize, + _: usize, + _: usize, + _: usize, + ) -> Result<()> { + Err(Error::NotCompiledWithMetalSupport) + } + fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> { Err(Error::NotCompiledWithMetalSupport) } diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index fcc17afc..31ef1169 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -67,6 +67,7 @@ pub mod shape; mod storage; mod strided_index; mod tensor; +mod tensor_cat; pub mod test_utils; pub mod utils; mod variable; diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index a17b87b8..2e07cce5 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -422,6 +422,7 @@ impl BackendStorage for MetalStorage { let name = match self.dtype { DType::F32 => "powf_f32", DType::F16 => "powf_f16", + DType::BF16 => "powf_bf16", dtype => crate::bail!("Metal contiguous powf {dtype:?} not implemented"), }; candle_metal_kernels::call_powf( @@ -439,6 +440,7 @@ impl BackendStorage for MetalStorage { let name = match self.dtype { DType::F32 => "powf_f32_strided", DType::F16 => "powf_f16_strided", + DType::BF16 => "powf_bf16_strided", dtype => crate::bail!("Metal strided powf {dtype:?} not implemented"), }; candle_metal_kernels::call_powf_strided( @@ 
-471,6 +473,7 @@ impl BackendStorage for MetalStorage { let name = match self.dtype { DType::F32 => "elu_f32", DType::F16 => "elu_f16", + DType::BF16 => "elu_bf16", dtype => crate::bail!("Metal contiguous elu {dtype:?} not implemented"), }; candle_metal_kernels::call_elu( @@ -488,6 +491,7 @@ impl BackendStorage for MetalStorage { let name = match self.dtype { DType::F32 => "elu_f32_strided", DType::F16 => "elu_f16_strided", + DType::BF16 => "elu_bf16_strided", dtype => crate::bail!("Metal strided elu {dtype:?} not implemented"), }; candle_metal_kernels::call_elu_strided( @@ -1292,6 +1296,67 @@ impl BackendStorage for MetalStorage { )) } + fn copy2d( + &self, + dst: &mut Self, + d1: usize, + d2: usize, + src_s: usize, + dst_s: usize, + src_o: usize, + dst_o: usize, + ) -> Result<()> { + if self.dtype() != dst.dtype() { + crate::bail!( + "copy2d with inconsistent dtypes {:?} {:?}", + self.dtype(), + dst.dtype() + ) + } + let command_buffer = self.device.command_buffer()?; + if src_s == d2 && dst_s == d2 { + command_buffer.set_label("copy2d_contiguous"); + let blit = command_buffer.new_blit_command_encoder(); + blit.set_label("copy2d_contiguous"); + let src_offset = (src_o * self.dtype.size_in_bytes()) as NSUInteger; + let length = (d1 * d2 * self.dtype.size_in_bytes()) as NSUInteger; + let dst_offset = (dst_o * dst.dtype().size_in_bytes()) as NSUInteger; + blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length); + blit.end_encoding(); + } else { + let el_count = d1 * d2; + if el_count == 0 { + return Ok(()); + } + let kernel_name = match self.dtype { + DType::F32 => candle_metal_kernels::copy2d::FLOAT, + DType::F16 => candle_metal_kernels::copy2d::HALF, + DType::BF16 => candle_metal_kernels::copy2d::BFLOAT, + DType::I64 => candle_metal_kernels::copy2d::I64, + DType::U32 => candle_metal_kernels::copy2d::U32, + DType::U8 => candle_metal_kernels::copy2d::U8, + dtype => crate::bail!("Metal copy2d {dtype:?} not implemented"), + }; + candle_metal_kernels::call_copy2d( + &self.device.device, + &command_buffer, + &self.device.kernels, + kernel_name, + &self.buffer, + &dst.buffer, + d1, + d2, + src_s, + dst_s, + src_o * self.dtype.size_in_bytes(), + dst_o * self.dtype.size_in_bytes(), + ) + .map_err(MetalError::from)?; + command_buffer.set_label("copy2d"); + } + Ok(()) + } + fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { let command_buffer = self.device.command_buffer()?; if src_l.is_contiguous() && self.dtype == dst.dtype() { diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index 65bcc6aa..3bd4b022 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -701,4 +701,32 @@ impl Storage { .bt()), } } + + #[allow(clippy::too_many_arguments)] + pub(crate) fn copy2d( + &self, + dst: &mut Self, + d1: usize, + d2: usize, + src_s: usize, + dst_s: usize, + src_o: usize, + dst_o: usize, + ) -> Result<()> { + match (self, dst) { + (Self::Cpu(src), Self::Cpu(dst)) => src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o), + (Self::Cuda(src), Self::Cuda(dst)) => { + Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?) + } + (Self::Metal(src), Self::Metal(dst)) => { + Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?) 
+ } + (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { + lhs: lhs.device().location(), + rhs: rhs.device().location(), + op: "copy2d", + } + .bt()), + } + } } diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 0e2c3e8f..22cd4950 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -666,7 +666,7 @@ impl Tensor { Ok(from_storage(storage, self.shape(), op, false)) } - fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> { + pub(crate) fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> { if dim >= self.dims().len() { Err(Error::DimOutOfRange { shape: self.shape().clone(), @@ -2149,152 +2149,6 @@ impl Tensor { Self::cat(&args, dim) } - /// Concatenates two or more tensors along a particular dimension. - /// - /// All tensors must of the same rank, and the output will have - /// the same rank - /// - /// ```rust - /// # use candle_core::{Tensor, DType, Device}; - /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; - /// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; - /// - /// let c = Tensor::cat(&[&a, &b], 0)?; - /// assert_eq!(c.shape().dims(), &[4, 3]); - /// - /// let c = Tensor::cat(&[&a, &b], 1)?; - /// assert_eq!(c.shape().dims(), &[2, 6]); - /// # Ok::<(), candle_core::Error>(()) - /// ``` - pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> { - if args.is_empty() { - Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())? - } - let arg0 = args[0].as_ref(); - if args.len() == 1 { - return Ok(arg0.clone()); - } - let dim = dim.to_index(arg0.shape(), "cat")?; - for arg in args { - arg.as_ref().check_dim(dim, "cat")?; - } - for (arg_idx, arg) in args.iter().enumerate() { - let arg = arg.as_ref(); - if arg0.rank() != arg.rank() { - Err(Error::UnexpectedNumberOfDims { - expected: arg0.rank(), - got: arg.rank(), - shape: arg.shape().clone(), - } - .bt())? - } - for (dim_idx, (v1, v2)) in arg0 - .shape() - .dims() - .iter() - .zip(arg.shape().dims().iter()) - .enumerate() - { - if dim_idx != dim && v1 != v2 { - Err(Error::ShapeMismatchCat { - dim: dim_idx, - first_shape: arg0.shape().clone(), - n: arg_idx + 1, - nth_shape: arg.shape().clone(), - } - .bt())? - } - } - } - if dim == 0 { - Self::cat0(args) - } else { - // TODO: Avoid these transpositions and have an implementation that works - // for dim != 0... - let args: Vec<Tensor> = args - .iter() - .map(|a| a.as_ref().transpose(0, dim)) - .collect::<Result<Vec<_>>>()?; - let cat = Self::cat0(&args)?; - cat.transpose(0, dim) - } - } - - fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> { - if args.is_empty() { - Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())? - } - let arg0 = args[0].as_ref(); - if args.len() == 1 { - return Ok(arg0.clone()); - } - let rank = arg0.rank(); - let device = arg0.device(); - let dtype = arg0.dtype(); - let first_dims = arg0.shape().dims(); - let mut cat_dims = first_dims.to_vec(); - cat_dims[0] = 0; - let mut offsets = vec![0usize]; - for (arg_idx, arg) in args.iter().enumerate() { - let arg = arg.as_ref(); - if arg.dtype() != dtype { - Err(Error::DTypeMismatchBinaryOp { - lhs: dtype, - rhs: arg.dtype(), - op: "cat", - } - .bt())? - } - if arg.device().location() != device.location() { - Err(Error::DeviceMismatchBinaryOp { - lhs: device.location(), - rhs: arg.device().location(), - op: "cat", - } - .bt())? - } - if rank != arg.rank() { - Err(Error::UnexpectedNumberOfDims { - expected: rank, - got: arg.rank(), - shape: arg.shape().clone(), - } - .bt())? 
-            }
-            for (dim_idx, (v1, v2)) in arg0
-                .shape()
-                .dims()
-                .iter()
-                .zip(arg.shape().dims().iter())
-                .enumerate()
-            {
-                if dim_idx == 0 {
-                    cat_dims[0] += v2;
-                }
-                if dim_idx != 0 && v1 != v2 {
-                    Err(Error::ShapeMismatchCat {
-                        dim: dim_idx,
-                        first_shape: arg0.shape().clone(),
-                        n: arg_idx + 1,
-                        nth_shape: arg.shape().clone(),
-                    }
-                    .bt())?
-                }
-            }
-            let next_offset = offsets.last().unwrap() + arg.elem_count();
-            offsets.push(next_offset);
-        }
-        let shape = Shape::from(cat_dims);
-        let op = BackpropOp::new(args, |args| Op::Cat(args, 0));
-        let mut storage = device.zeros(&shape, dtype)?;
-        for (arg, &offset) in args.iter().zip(offsets.iter()) {
-            let arg = arg.as_ref();
-            arg.storage()
-                .copy_strided_src(&mut storage, offset, arg.layout())?;
-        }
-        Ok(from_storage(storage, shape, op, false))
-    }
-
     /// Pad the input tensor using 0s along dimension `dim`. This adds `left` elements before the
     /// input tensor values and `right` elements after.
     pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
diff --git a/candle-core/src/tensor_cat.rs b/candle-core/src/tensor_cat.rs
new file mode 100644
index 00000000..25acc80e
--- /dev/null
+++ b/candle-core/src/tensor_cat.rs
@@ -0,0 +1,240 @@
+use crate::{shape::Dim, Error, Result, Shape, Tensor};
+
+impl Tensor {
+    /// Concatenates two or more tensors along a particular dimension.
+    ///
+    /// All tensors must have the same rank, and the output will have
+    /// the same rank.
+    ///
+    /// ```rust
+    /// # use candle_core::{Tensor, DType, Device};
+    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+    /// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+    ///
+    /// let c = Tensor::cat(&[&a, &b], 0)?;
+    /// assert_eq!(c.shape().dims(), &[4, 3]);
+    ///
+    /// let c = Tensor::cat(&[&a, &b], 1)?;
+    /// assert_eq!(c.shape().dims(), &[2, 6]);
+    /// # Ok::<(), candle_core::Error>(())
+    /// ```
+    pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
+        if args.is_empty() {
+            Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
+        }
+        let arg0 = args[0].as_ref();
+        if args.len() == 1 {
+            return Ok(arg0.clone());
+        }
+        let dim = dim.to_index(arg0.shape(), "cat")?;
+        for arg in args {
+            arg.as_ref().check_dim(dim, "cat")?;
+        }
+        for (arg_idx, arg) in args.iter().enumerate() {
+            let arg = arg.as_ref();
+            if arg0.rank() != arg.rank() {
+                Err(Error::UnexpectedNumberOfDims {
+                    expected: arg0.rank(),
+                    got: arg.rank(),
+                    shape: arg.shape().clone(),
+                }
+                .bt())?
+            }
+            for (dim_idx, (v1, v2)) in arg0
+                .shape()
+                .dims()
+                .iter()
+                .zip(arg.shape().dims().iter())
+                .enumerate()
+            {
+                if dim_idx != dim && v1 != v2 {
+                    Err(Error::ShapeMismatchCat {
+                        dim: dim_idx,
+                        first_shape: arg0.shape().clone(),
+                        n: arg_idx + 1,
+                        nth_shape: arg.shape().clone(),
+                    }
+                    .bt())?
+                }
+            }
+        }
+        if dim == 0 {
+            Self::cat0(args)
+        } else {
+            let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
+            if all_contiguous {
+                Self::cat_contiguous(args, dim)
+            } else {
+                let args: Vec<Tensor> = args
+                    .iter()
+                    .map(|a| a.as_ref().transpose(0, dim))
+                    .collect::<Result<Vec<_>>>()?;
+                let cat = Self::cat0(&args)?;
+                cat.transpose(0, dim)
+            }
+        }
+    }
+
+    fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
+        if args.is_empty() {
+            Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
+ } + let arg0 = args[0].as_ref(); + if args.len() == 1 { + return Ok(arg0.clone()); + } + let rank = arg0.rank(); + let device = arg0.device(); + let dtype = arg0.dtype(); + let first_dims = arg0.shape().dims(); + let mut cat_dims = first_dims.to_vec(); + cat_dims[0] = 0; + let mut offsets = vec![0usize]; + for (arg_idx, arg) in args.iter().enumerate() { + let arg = arg.as_ref(); + if arg.dtype() != dtype { + Err(Error::DTypeMismatchBinaryOp { + lhs: dtype, + rhs: arg.dtype(), + op: "cat", + } + .bt())? + } + if arg.device().location() != device.location() { + Err(Error::DeviceMismatchBinaryOp { + lhs: device.location(), + rhs: arg.device().location(), + op: "cat", + } + .bt())? + } + if rank != arg.rank() { + Err(Error::UnexpectedNumberOfDims { + expected: rank, + got: arg.rank(), + shape: arg.shape().clone(), + } + .bt())? + } + for (dim_idx, (v1, v2)) in arg0 + .shape() + .dims() + .iter() + .zip(arg.shape().dims().iter()) + .enumerate() + { + if dim_idx == 0 { + cat_dims[0] += v2; + } + if dim_idx != 0 && v1 != v2 { + Err(Error::ShapeMismatchCat { + dim: dim_idx, + first_shape: arg0.shape().clone(), + n: arg_idx + 1, + nth_shape: arg.shape().clone(), + } + .bt())? + } + } + let next_offset = offsets.last().unwrap() + arg.elem_count(); + offsets.push(next_offset); + } + let shape = Shape::from(cat_dims); + let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, 0)); + let mut storage = device.zeros(&shape, dtype)?; + for (arg, &offset) in args.iter().zip(offsets.iter()) { + let arg = arg.as_ref(); + arg.storage() + .copy_strided_src(&mut storage, offset, arg.layout())?; + } + Ok(crate::tensor::from_storage(storage, shape, op, false)) + } + + fn cat_contiguous<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> { + if args.is_empty() { + Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())? + } + let arg0 = args[0].as_ref(); + if args.len() == 1 { + return Ok(arg0.clone()); + } + let rank = arg0.rank(); + let device = arg0.device(); + let dtype = arg0.dtype(); + let first_dims = arg0.shape().dims(); + let mut cat_dims = first_dims.to_vec(); + cat_dims[dim] = 0; + for (arg_idx, arg) in args.iter().enumerate() { + let arg = arg.as_ref(); + if arg.dtype() != dtype { + Err(Error::DTypeMismatchBinaryOp { + lhs: dtype, + rhs: arg.dtype(), + op: "cat", + } + .bt())? + } + if arg.device().location() != device.location() { + Err(Error::DeviceMismatchBinaryOp { + lhs: device.location(), + rhs: arg.device().location(), + op: "cat", + } + .bt())? + } + if rank != arg.rank() { + Err(Error::UnexpectedNumberOfDims { + expected: rank, + got: arg.rank(), + shape: arg.shape().clone(), + } + .bt())? + } + for (dim_idx, (v1, v2)) in arg0 + .shape() + .dims() + .iter() + .zip(arg.shape().dims().iter()) + .enumerate() + { + if dim_idx == dim { + cat_dims[dim] += v2; + } + if dim_idx != dim && v1 != v2 { + Err(Error::ShapeMismatchCat { + dim: dim_idx, + first_shape: arg0.shape().clone(), + n: arg_idx + 1, + nth_shape: arg.shape().clone(), + } + .bt())? 
+ } + } + } + let cat_target_dim_len = cat_dims[dim]; + let block_size: usize = cat_dims.iter().skip(1 + dim).product(); + let shape = Shape::from(cat_dims); + let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, dim)); + let mut storage = device.zeros(&shape, dtype)?; + let mut dst_o = 0; + for arg in args.iter() { + let arg = arg.as_ref(); + let arg_dims = arg.shape().dims(); + let d1: usize = arg_dims.iter().take(dim).product(); + let d2 = block_size * arg_dims[dim]; + let dst_s = block_size * cat_target_dim_len; + let src_o = arg.layout().start_offset(); + arg.storage().copy2d( + &mut storage, + d1, + d2, + /* src_s */ d2, + dst_s, + src_o, + dst_o, + )?; + dst_o += d2; + } + Ok(crate::tensor::from_storage(storage, shape, op, false)) + } +} diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs index f0f1b7f2..ba60b778 100644 --- a/candle-core/tests/conv_tests.rs +++ b/candle-core/tests/conv_tests.rs @@ -53,6 +53,12 @@ fn conv1d(dev: &Device) -> Result<()> { test_utils::to_vec1_round(&res.flatten_all()?, 4)?, [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352] ); + + // conv-transposes are not implemented for metal. + if dev.is_metal() { + return Ok(()); + } + let w = w.transpose(0, 1)?; // The CPU kernels applied in the contiguous and non contiguous cases are different. for w in [w.clone(), w.contiguous()?] { @@ -162,31 +168,33 @@ fn conv2d(dev: &Device) -> Result<()> { 10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075 ] ); - let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?; - assert_eq!(res.dims(), [1, 2, 7, 7]); - assert_eq!( - test_utils::to_vec3_round(&res.i(0)?, 4)?, - [ - [ - [-1.9918, 2.6797, -0.4599, -1.6037, 1.4131, -2.4012, 2.9277], - [1.8016, -3.5361, 1.0757, 3.5395, -8.2168, -3.2023, 0.5375], - [0.8243, 1.8675, 7.8929, -4.0746, -6.4415, 5.1139, 1.6889], - [0.2722, 8.9679, 3.3477, 1.8514, -4.2896, -3.8228, -7.5632], - [-8.5412, -5.8142, -7.1587, -1.6095, 0.4651, 0.2748, -2.0985], - [2.0833, -0.6482, -12.1692, -4.1284, -2.9765, -0.0656, -4.5114], - [5.307, 2.6957, 2.3087, 1.0478, 0.7808, -1.1519, -0.9579] - ], + if !dev.is_metal() { + let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?; + assert_eq!(res.dims(), [1, 2, 7, 7]); + assert_eq!( + test_utils::to_vec3_round(&res.i(0)?, 4)?, [ - [1.089, 0.1872, -0.6408, -0.9897, 0.8503, 1.1019, -0.9211], - [-0.1741, -0.2915, 4.2472, 1.9417, 1.65, 0.6303, -4.7131], - [1.6555, 2.4026, -2.9293, 2.9953, 0.5328, 3.5873, -0.9621], - [-1.4289, -3.2787, 4.1747, -6.0341, -4.6341, -5.7945, 4.142], - [7.5973, 6.4431, 5.9872, 2.1639, -8.6566, 3.3143, -3.4059], - [-0.8775, -3.048, 11.6543, 0.6442, 2.3218, -0.4765, 1.1516], - [-5.5423, -2.5188, 1.0754, -0.0563, -2.9386, -1.1504, 1.0171] + [ + [-1.9918, 2.6797, -0.4599, -1.6037, 1.4131, -2.4012, 2.9277], + [1.8016, -3.5361, 1.0757, 3.5395, -8.2168, -3.2023, 0.5375], + [0.8243, 1.8675, 7.8929, -4.0746, -6.4415, 5.1139, 1.6889], + [0.2722, 8.9679, 3.3477, 1.8514, -4.2896, -3.8228, -7.5632], + [-8.5412, -5.8142, -7.1587, -1.6095, 0.4651, 0.2748, -2.0985], + [2.0833, -0.6482, -12.1692, -4.1284, -2.9765, -0.0656, -4.5114], + [5.307, 2.6957, 2.3087, 1.0478, 0.7808, -1.1519, -0.9579] + ], + [ + [1.089, 0.1872, -0.6408, -0.9897, 0.8503, 1.1019, -0.9211], + [-0.1741, -0.2915, 4.2472, 1.9417, 1.65, 0.6303, -4.7131], + [1.6555, 2.4026, -2.9293, 2.9953, 0.5328, 3.5873, -0.9621], + [-1.4289, -3.2787, 4.1747, -6.0341, -4.6341, -5.7945, 4.142], + [7.5973, 6.4431, 5.9872, 2.1639, -8.6566, 3.3143, 
-3.4059], + [-0.8775, -3.048, 11.6543, 0.6442, 2.3218, -0.4765, 1.1516], + [-5.5423, -2.5188, 1.0754, -0.0563, -2.9386, -1.1504, 1.0171] + ] ] - ] - ); + ); + } // Dilations. let res = t.conv2d(&w, 0, 1, 2, 1)?; assert_eq!(res.dims(), [1, 2, 1, 1]); @@ -195,36 +203,44 @@ fn conv2d(dev: &Device) -> Result<()> { [2.45, -2.3504], ); - // Transpose and dilations. - let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?; - assert_eq!(res.dims(), [1, 2, 9, 9]); - assert_eq!( - test_utils::to_vec3_round(&res.i(0)?, 4)?, - [ - [ - [-1.9918, 3.1652, -0.6778, -4.3442, 4.4351, 0.6652, -3.0124, -0.6031, 2.9277], - [2.7036, -1.7156, -0.3969, 1.0516, 1.6381, -2.8886, -0.205, 2.4682, -1.0499], - [-0.9459, 3.1631, 3.707, -4.8369, -8.5166, -1.4496, -2.7559, -3.2698, 1.4376], - [-0.2157, 3.7786, -2.0252, -4.2633, 3.6731, -1.5142, 5.9391, -0.2622, -0.141], - [-6.8121, -3.1744, 1.5945, 3.0637, -9.6088, 1.4446, 2.9489, -3.0082, -7.3822], - [0.2371, 3.3303, 0.3861, 2.2646, -4.6784, 4.1235, -0.0109, 0.3176, -0.03], - [-2.5339, -2.9564, -3.4518, -4.4594, -9.1873, -1.9709, -0.4676, 0.51, -3.5024], - [4.007, 0.3067, -2.2954, 1.1105, -0.1992, 1.6372, -2.9268, 0.2807, -1.2787], - [5.307, 1.1317, 1.3518, 0.9049, 3.8116, -0.4075, -0.8874, -0.2241, -0.9579] - ], + if !dev.is_metal() { + // Transpose and dilations. + let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?; + assert_eq!(res.dims(), [1, 2, 9, 9]); + assert_eq!( + test_utils::to_vec3_round(&res.i(0)?, 4)?, [ - [1.089, -0.6483, 0.0726, -0.4752, -1.3283, 1.7103, 1.0703, 0.1076, -0.9211], - [-0.8629, 0.1376, 0.3202, 2.0955, 0.9696, 2.8988, -1.0012, 1.5049, -0.1278], - [1.9286, -1.5255, -2.9563, 2.4589, 3.3611, -0.6951, 0.3525, -1.7724, -5.9861], - [1.1226, 2.1561, 3.6417, 4.7546, -0.692, 4.4126, -5.1902, 6.0805, 2.3185], - [1.0111, 0.3604, 0.6432, -3.6605, 7.9517, -9.2955, -5.2988, -3.7803, -2.0642], - [3.3172, -1.7967, -3.6576, -2.0942, 1.3158, 0.112, -1.7405, 2.9167, 0.7957], - [5.1001, 1.8995, -1.8639, 1.1262, 9.9629, 2.683, -3.6319, -1.1607, 0.5856], - [-4.8445, -0.5642, 4.2317, 0.0856, 1.2267, -0.5712, 1.736, 1.0997, 0.6908], - [-5.5423, -1.1831, -1.2176, 0.0843, 0.0446, -0.7545, -2.4798, -0.0827, 1.0171] + [ + [-1.9918, 3.1652, -0.6778, -4.3442, 4.4351, 0.6652, -3.0124, -0.6031, 2.9277], + [2.7036, -1.7156, -0.3969, 1.0516, 1.6381, -2.8886, -0.205, 2.4682, -1.0499], + [-0.9459, 3.1631, 3.707, -4.8369, -8.5166, -1.4496, -2.7559, -3.2698, 1.4376], + [-0.2157, 3.7786, -2.0252, -4.2633, 3.6731, -1.5142, 5.9391, -0.2622, -0.141], + [-6.8121, -3.1744, 1.5945, 3.0637, -9.6088, 1.4446, 2.9489, -3.0082, -7.3822], + [0.2371, 3.3303, 0.3861, 2.2646, -4.6784, 4.1235, -0.0109, 0.3176, -0.03], + [ + -2.5339, -2.9564, -3.4518, -4.4594, -9.1873, -1.9709, -0.4676, 0.51, + -3.5024 + ], + [4.007, 0.3067, -2.2954, 1.1105, -0.1992, 1.6372, -2.9268, 0.2807, -1.2787], + [5.307, 1.1317, 1.3518, 0.9049, 3.8116, -0.4075, -0.8874, -0.2241, -0.9579] + ], + [ + [1.089, -0.6483, 0.0726, -0.4752, -1.3283, 1.7103, 1.0703, 0.1076, -0.9211], + [-0.8629, 0.1376, 0.3202, 2.0955, 0.9696, 2.8988, -1.0012, 1.5049, -0.1278], + [1.9286, -1.5255, -2.9563, 2.4589, 3.3611, -0.6951, 0.3525, -1.7724, -5.9861], + [1.1226, 2.1561, 3.6417, 4.7546, -0.692, 4.4126, -5.1902, 6.0805, 2.3185], + [1.0111, 0.3604, 0.6432, -3.6605, 7.9517, -9.2955, -5.2988, -3.7803, -2.0642], + [3.3172, -1.7967, -3.6576, -2.0942, 1.3158, 0.112, -1.7405, 2.9167, 0.7957], + [5.1001, 1.8995, -1.8639, 1.1262, 9.9629, 2.683, -3.6319, -1.1607, 0.5856], + [-4.8445, -0.5642, 4.2317, 0.0856, 1.2267, -0.5712, 1.736, 
1.0997, 0.6908], + [ + -5.5423, -1.1831, -1.2176, 0.0843, 0.0446, -0.7545, -2.4798, -0.0827, + 1.0171 + ] + ] ] - ] - ); + ); + } Ok(()) } @@ -278,6 +294,12 @@ fn conv2d_small(dev: &Device) -> Result<()> { 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000 ] ); + + // conv-transposes are not implemented for metal + if dev.is_metal() { + return Ok(()); + } + let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?; assert_eq!(res.dims(), [1, 1, 3, 3]); assert_eq!( @@ -379,6 +401,10 @@ print(w.grad.shape) print(w.grad[0]) */ fn conv2d_grad(dev: &Device) -> Result<()> { + // conv-transposes are not implemented for metal + if dev.is_metal() { + return Ok(()); + } use candle_core::Var; let t = Var::from_slice( &[ diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs index a4d81618..b8b6be8d 100644 --- a/candle-core/tests/grad_tests.rs +++ b/candle-core/tests/grad_tests.rs @@ -1,3 +1,4 @@ +#![allow(clippy::approx_constant)] use anyhow::{Context, Result}; use candle_core::{test_device, test_utils, Device, Shape, Tensor, Var}; @@ -96,24 +97,24 @@ fn unary_grad(device: &Device) -> Result<()> { let grads = y.backward()?; let grad_x = grads.get(x).context("no grad for x")?; assert_eq!( - y.to_vec1::<f32>()?, - [20.085537, 2.7182817, 54.59815, 1.1618342] + test_utils::to_vec1_round(&y, 4)?, + [20.0855, 2.7183, 54.5982, 1.1618] ); assert_eq!( - grad_x.to_vec1::<f32>()?, - [20.085537, 2.7182817, 54.59815, 1.1618342] + test_utils::to_vec1_round(grad_x, 4)?, + [20.0855, 2.7183, 54.5982, 1.1618] ); let y = x.exp()?.sqr()?; let grads = y.backward()?; let grad_x = grads.get(x).context("no grad for x")?; assert_eq!( - y.to_vec1::<f32>()?, - [403.4288, 7.3890557, 2980.9578, 1.3498588] + test_utils::to_vec1_round(&y, 3)?, + [403.429, 7.389, 2980.958, 1.35] ); // exp(x)^2 = exp(2*x) assert_eq!( - grad_x.to_vec1::<f32>()?, - [806.8576, 14.778111, 5961.9155, 2.6997175] + test_utils::to_vec1_round(grad_x, 2)?, + [806.86, 14.78, 5961.92, 2.7] ); let y = x.sin()?; let grads = y.backward()?; @@ -261,6 +262,7 @@ fn unary_grad(device: &Device) -> Result<()> { let y = elu_x.elu(2.)?; let grads = y.backward()?; let grad_x = grads.get(&elu_x).context("no grad for x")?; + assert_eq!( test_utils::to_vec1_round(&y, 4)?, [-1.2642, 0.0000, -1.7293, 3.0000] diff --git a/candle-core/tests/pool_tests.rs b/candle-core/tests/pool_tests.rs index a3708ec4..a6530e03 100644 --- a/candle-core/tests/pool_tests.rs +++ b/candle-core/tests/pool_tests.rs @@ -2,6 +2,9 @@ use candle_core::{test_device, test_utils, Device, IndexOp, Result, Tensor}; // https://github.com/huggingface/candle/issues/364 fn avg_pool2d(dev: &Device) -> Result<()> { + if dev.is_metal() { + return Ok(()); + } let data: Vec<f32> = vec![ 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., ]; @@ -19,6 +22,9 @@ fn avg_pool2d(dev: &Device) -> Result<()> { } fn max_pool2d(dev: &Device) -> Result<()> { + if dev.is_metal() { + return Ok(()); + } let data: Vec<f32> = vec![ 1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1., ]; @@ -43,6 +49,9 @@ res = torch.nn.functional.avg_pool2d(t, 2) print(res) */ fn avg_pool2d_pytorch(dev: &Device) -> Result<()> { + if dev.is_metal() { + return Ok(()); + } let t = Tensor::new( &[ 0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616, diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index 31a27422..b2475adc 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -672,6 +672,31 
@@ fn cat(device: &Device) -> Result<()> { [2.0, 7.0, 1.0, 8.0, 2.0, 2.0, 7.0, 1.0, 8.0, 2.0] ] ); + + // 3D + let t1 = Tensor::arange(0, 48i64, device)?.reshape((2, 6, 4))?; + let t2 = Tensor::arange(100, 124i64, device)?.reshape((2, 3, 4))?; + let t3 = Tensor::arange(10000, 10032i64, device)?.reshape((2, 4, 4))?; + + let t_cat = Tensor::cat(&[&t1, &t2, &t3], 1)?; + + let t1 = t1.t()?.contiguous()?.t()?; + let t2 = t2.t()?.contiguous()?.t()?; + let t3 = t3.t()?.contiguous()?.t()?; + let t_cat2 = Tensor::cat(&[&t1, &t2, &t3], 1)?; + + let diff = t_cat.eq(&t_cat2)?.to_dtype(DType::F32)?.sum_all()?; + assert_eq!(diff.to_vec0::<f32>()?, 104.0); + assert_eq!(t_cat.i((0, 0, 0))?.to_vec0::<i64>()?, 0); + assert_eq!(t_cat.i((0, 4, 0))?.to_vec0::<i64>()?, 16); + assert_eq!(t_cat.i((0, 5, 0))?.to_vec0::<i64>()?, 20); + assert_eq!(t_cat.i((1, 5, 0))?.to_vec0::<i64>()?, 44); + assert_eq!(t_cat.i((0, 6, 0))?.to_vec0::<i64>()?, 100); + assert_eq!(t_cat.i((1, 6, 0))?.to_vec0::<i64>()?, 112); + assert_eq!(t_cat.i((0, 6, 1))?.to_vec0::<i64>()?, 101); + assert_eq!(t_cat.i((0, 7, 1))?.to_vec0::<i64>()?, 105); + assert_eq!(t_cat.i((0, 12, 1))?.to_vec0::<i64>()?, 10013); + assert_eq!(t_cat.i((1, 12, 3))?.to_vec0::<i64>()?, 10031); Ok(()) } diff --git a/candle-kernels/src/fill.cu b/candle-kernels/src/fill.cu index 883ca072..ca448d98 100644 --- a/candle-kernels/src/fill.cu +++ b/candle-kernels/src/fill.cu @@ -10,11 +10,39 @@ __device__ void fill_with(T *buf, T value, const size_t numel) { extern "C" __global__ void fill_u8(uint8_t *buf, uint8_t value, const size_t numel) { fill_with(buf, value, numel); } extern "C" __global__ void fill_u32(uint32_t *buf, uint32_t value, const size_t numel) { fill_with(buf, value, numel); } extern "C" __global__ void fill_i64(int64_t *buf, int64_t value, const size_t numel) { fill_with(buf, value, numel); } -extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); } extern "C" __global__ void fill_f32(float *buf, float value, const size_t numel) { fill_with(buf, value, numel); } extern "C" __global__ void fill_f64(double *buf, double value, const size_t numel) { fill_with(buf, value, numel); } +template<typename T> +__device__ void copy2d(const T *src, T *dst, uint32_t d1, uint32_t d2, uint32_t src_s, uint32_t dst_s) { + uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= d1 * d2) { + return; + } + uint32_t idx1 = idx / d2; + uint32_t idx2 = idx - d2 * idx1; + dst[idx1 * dst_s + idx2] = src[idx1 * src_s + idx2]; +} + +#define COPY2D_OP(TYPENAME, FNNAME) \ +extern "C" __global__ \ +void FNNAME(const TYPENAME *src, TYPENAME *dst, uint32_t d1, uint32_t d2, uint32_t src_s, uint32_t dst_s) { \ + copy2d(src, dst, d1, d2, src_s, dst_s); \ +} \ + +COPY2D_OP(float, copy2d_f32) +COPY2D_OP(double, copy2d_f64) +COPY2D_OP(uint8_t, copy2d_u8) +COPY2D_OP(uint32_t, copy2d_u32) +COPY2D_OP(int64_t, copy2d_i64) + +#if __CUDA_ARCH__ >= 530 +extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); } +COPY2D_OP(__half, copy2d_f16) +#endif + #if __CUDA_ARCH__ >= 800 #include <cuda_bf16.h> extern "C" __global__ void fill_bf16(__nv_bfloat16 *buf, __nv_bfloat16 value, const size_t numel) { fill_with(buf, value, numel); } +COPY2D_OP(__nv_bfloat16, copy2d_bf16) #endif diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal index a4484998..76c0365a 100644 --- a/candle-metal-kernels/src/affine.metal +++ b/candle-metal-kernels/src/affine.metal @@ 
-89,7 +89,7 @@ kernel void FN_NAME( \
         return; \
     } \
     const TYPENAME x = input[id]; \
-    output[id] = TYPENAME((x > 0)?x: mul * exp(x - 1)); \
+    output[id] = TYPENAME((x > 0)?x: mul * (exp(x) - 1)); \
 } \
 kernel void FN_NAME##_strided( \
     constant size_t &dim, \
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 47ce7e96..a879c86a 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -127,6 +127,16 @@ pub enum Source {
     Quantized,
 }
 
+pub mod copy2d {
+    pub struct Kernel(pub &'static str);
+    pub const FLOAT: Kernel = Kernel("copy2d_f32");
+    pub const HALF: Kernel = Kernel("copy2d_f16");
+    pub const BFLOAT: Kernel = Kernel("copy2d_bf16");
+    pub const I64: Kernel = Kernel("copy2d_i64");
+    pub const U32: Kernel = Kernel("copy2d_u32");
+    pub const U8: Kernel = Kernel("copy2d_u8");
+}
+
 macro_rules! ops{
 ($($name:ident),+) => {
 
@@ -366,6 +376,46 @@ pub fn call_unary_contiguous(
 }
 
 #[allow(clippy::too_many_arguments)]
+pub fn call_copy2d(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    name: copy2d::Kernel,
+    input: &Buffer,
+    output: &Buffer,
+    d1: usize,
+    d2: usize,
+    src_s: usize,
+    dst_s: usize,
+    src_o_in_bytes: usize,
+    dst_o_in_bytes: usize,
+) -> Result<(), MetalKernelError> {
+    let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+    set_params!(
+        encoder,
+        (
+            d1,
+            d2,
+            src_s,
+            dst_s,
+            (input, src_o_in_bytes),
+            (output, dst_o_in_bytes)
+        )
+    );
+
+    let width: usize = d1 * d2;
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
+
+    encoder.use_resource(input, metal::MTLResourceUsage::Read);
+    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.end_encoding();
+    Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
 pub fn call_unary_strided(
     device: &Device,
     command_buffer: &CommandBufferRef,
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index 1e0d5526..bdc13f9e 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -102,6 +102,30 @@ UNARY(NAME, half, NAME##_f16, NAME##_f16_strided);
 #define BFLOAT_UNARY_OP(NAME) \
 UNARY(NAME, bfloat, NAME##_bf16, NAME##_bf16_strided);
 
+#define COPY2D(FN_NAME, TYPENAME) \
+kernel void FN_NAME( \
+    constant size_t &d1, \
+    constant size_t &d2, \
+    constant size_t &src_s, \
+    constant size_t &dst_s, \
+    device const TYPENAME *input, \
+    device TYPENAME *output, \
+    uint tid [[ thread_position_in_grid ]] \
+) { \
+    if (tid >= d1 * d2) { \
+        return; \
+    } \
+    size_t idx1 = tid / d2; \
+    size_t idx2 = tid - idx1 * d2; \
+    size_t src_idx = idx1 * src_s + idx2; \
+    size_t dst_idx = idx1 * dst_s + idx2; \
+    output[dst_idx] = input[src_idx]; \
+}
+
+COPY2D(copy2d_f32, float)
+COPY2D(copy2d_f16, half)
+COPY2D(copy2d_u8, uint8_t)
+COPY2D(copy2d_u32, uint32_t)
 
 UNARY_OP(cos)
 UNARY_OP(sin)
@@ -128,6 +152,7 @@ UNARY(id, uint32_t, copy_u32, copy_u32_strided)
 
 #if __METAL_VERSION__ >= 220
 UNARY(id, int64_t, copy_i64, copy_i64_strided)
+COPY2D(copy2d_i64, int64_t)
 #endif
 
 #if defined(__HAVE_BFLOAT__)
@@ -151,4 +176,6 @@ BFLOAT_UNARY_OP(recip)
 BFLOAT_UNARY_OP(relu)
 
 UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
+
+COPY2D(copy2d_bf16, bfloat)
 #endif
diff --git a/candle-nn/examples/cpu_benchmarks.rs b/candle-nn/examples/cpu_benchmarks.rs
index
001be116..430316b8 100644
--- a/candle-nn/examples/cpu_benchmarks.rs
+++ b/candle-nn/examples/cpu_benchmarks.rs
@@ -238,6 +238,23 @@ impl Benchmark for QMatMul {
     const ITERS: usize = 100;
 }
 
+struct Cat;
+impl Benchmark for Cat {
+    type PreProcessData = (Tensor, Tensor);
+    type RunResult = Tensor;
+    fn preprocess() -> Result<Self::PreProcessData> {
+        let lhs = Tensor::randn(0f32, 1., (1, 32, 2000, 128), &Device::Cpu)?;
+        let rhs = Tensor::randn(0f32, 1., (1, 32, 1, 128), &Device::Cpu)?;
+        Ok((lhs, rhs))
+    }
+
+    fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
+        Tensor::cat(&[&d.0, &d.1], 2)
+    }
+
+    const ITERS: usize = 1000;
+}
+
 struct Softmax;
 impl Benchmark for Softmax {
     type PreProcessData = Tensor;
@@ -295,6 +312,7 @@ enum Task {
     Qmatmul,
     Softmax,
     SoftmaxLastDim,
+    Cat,
 }
 
 #[derive(Parser, Debug)]
@@ -319,6 +337,7 @@ fn main() -> Result<()> {
         Task::Softmax => run::<Softmax>(args.iters)?,
         Task::SoftmaxLastDim => run::<SoftmaxLastDim>(args.iters)?,
        Task::Qmatmul => run::<QMatMul>(args.iters)?,
+        Task::Cat => run::<Cat>(args.iters)?,
     }
     Ok(())
 }
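For reference, `copy2d` follows the `cudaMemcpy2D` convention noted in `backend.rs`, but with every quantity in elements rather than bytes: it moves `d1` rows of `d2` elements, using per-row strides `src_s`/`dst_s` and starting offsets `src_o`/`dst_o`. Below is a standalone restatement of the CPU backend's loop, together with a small hypothetical example (the 2x2 inputs and the driver `main` are illustrative, not part of the patch) showing how `cat_contiguous` chooses its arguments when concatenating along dim 1:

```rust
// Standalone restatement of the CPU backend's copy2d_ helper:
// copies d1 rows of d2 contiguous elements each; strides and offsets
// are expressed in elements, not bytes.
fn copy2d<T: Copy>(
    src: &[T],
    dst: &mut [T],
    d1: usize,
    d2: usize,
    src_stride1: usize,
    dst_stride1: usize,
    src_offset: usize,
    dst_offset: usize,
) {
    for i1 in 0..d1 {
        let src_idx = i1 * src_stride1 + src_offset;
        let dst_idx = i1 * dst_stride1 + dst_offset;
        dst[dst_idx..dst_idx + d2].copy_from_slice(&src[src_idx..src_idx + d2]);
    }
}

fn main() {
    // Concatenate two 2x2 matrices along dim 1 into a 2x4 destination,
    // the way cat_contiguous drives copy2d: d1 = 2 rows, d2 = 2 elements
    // per source row, destination row stride = 4 (the summed cat dim
    // times the block size), and dst_offset advancing by d2 per input.
    let a = [1, 2, 3, 4];
    let b = [5, 6, 7, 8];
    let mut out = [0; 8];
    copy2d(&a, &mut out, 2, 2, 2, 4, 0, 0);
    copy2d(&b, &mut out, 2, 2, 2, 4, 0, 2);
    assert_eq!(out, [1, 2, 5, 6, 3, 4, 7, 8]);
}
```

On the CPU each row becomes a single `copy_from_slice`; the CUDA and Metal kernels flatten the same loop into one thread per element, recovering the row/column pair as `idx / d2` and `idx % d2`.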