-rw-r--r--  candle-core/src/backend.rs              |  13
-rw-r--r--  candle-core/src/cpu_backend.rs          |  62
-rw-r--r--  candle-core/src/cuda_backend.rs         |  61
-rw-r--r--  candle-core/src/dummy_cuda_backend.rs   |  13
-rw-r--r--  candle-core/src/dummy_metal_backend.rs  |  13
-rw-r--r--  candle-core/src/lib.rs                  |   1
-rw-r--r--  candle-core/src/metal_backend.rs        |  65
-rw-r--r--  candle-core/src/storage.rs              |  28
-rw-r--r--  candle-core/src/tensor.rs               | 148
-rw-r--r--  candle-core/src/tensor_cat.rs           | 240
-rw-r--r--  candle-core/tests/conv_tests.rs         | 128
-rw-r--r--  candle-core/tests/grad_tests.rs         |  18
-rw-r--r--  candle-core/tests/pool_tests.rs         |   9
-rw-r--r--  candle-core/tests/tensor_tests.rs       |  25
-rw-r--r--  candle-kernels/src/fill.cu              |  30
-rw-r--r--  candle-metal-kernels/src/affine.metal   |   2
-rw-r--r--  candle-metal-kernels/src/lib.rs         |  50
-rw-r--r--  candle-metal-kernels/src/unary.metal    |  27
-rw-r--r--  candle-nn/examples/cpu_benchmarks.rs    |  19
19 files changed, 744 insertions(+), 208 deletions(-)
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index 2125af69..ea1ac1a9 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -98,6 +98,19 @@ pub trait BackendStorage: Sized {
) -> Result<Self>;
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>;
+
+ #[allow(clippy::too_many_arguments)]
+ // Similar to cudaMemcpy2D, though values are in elements and not in bytes.
+ fn copy2d(
+ &self,
+ _: &mut Self,
+ _d1: usize,
+ _d2: usize,
+ _src_stride1: usize,
+ _dst_stride1: usize,
+ _src_offset: usize,
+ _dst_offset: usize,
+ ) -> Result<()>;
}
pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
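For reference, the new copy2d hook copies a d1 x d2 block of elements between two buffers whose rows are src_stride1 and dst_stride1 elements apart; offsets and strides are counted in elements, as the cudaMemcpy2D comparison above notes. A standalone sketch of the intended semantics (illustrative only, mirroring the CPU implementation added below):

    // Reference model of the copy2d contract: copy d1 rows of d2 elements each,
    // with independent row strides and start offsets on either side.
    fn copy2d_ref<T: Copy>(
        src: &[T], dst: &mut [T],
        d1: usize, d2: usize,
        src_stride1: usize, dst_stride1: usize,
        src_offset: usize, dst_offset: usize,
    ) {
        for i1 in 0..d1 {
            let s = src_offset + i1 * src_stride1;
            let d = dst_offset + i1 * dst_stride1;
            dst[d..d + d2].copy_from_slice(&src[s..s + d2]);
        }
    }

    fn main() {
        // Two rows of three elements land in a destination whose rows are 4 apart.
        let src = [1, 2, 3, 4, 5, 6];
        let mut dst = [0; 8];
        copy2d_ref(&src, &mut dst, 2, 3, 3, 4, 0, 0);
        assert_eq!(dst, [1, 2, 3, 0, 4, 5, 6, 0]);
    }

Each backend below provides a native implementation of this contract: slice copies on CPU, a small kernel in fill.cu on CUDA, and either a blit or a compute kernel on Metal.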
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 181fbb61..1504d5b8 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1023,6 +1023,26 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
}
}
+#[allow(clippy::too_many_arguments)]
+fn copy2d_<T: Copy>(
+ src: &[T],
+ dst: &mut [T],
+ d1: usize,
+ d2: usize,
+ src_stride1: usize,
+ dst_stride1: usize,
+ src_offset: usize,
+ dst_offset: usize,
+) {
+ for i1 in 0..d1 {
+ let dst_idx = i1 * dst_stride1 + dst_offset;
+ let src_idx = i1 * src_stride1 + src_offset;
+ let dst = &mut dst[dst_idx..dst_idx + d2];
+ let src = &src[src_idx..src_idx + d2];
+ dst.copy_from_slice(src)
+ }
+}
+
fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) {
match src_l.strided_blocks() {
crate::StridedBlocks::SingleBlock { start_offset, len } => {
@@ -2452,6 +2472,48 @@ impl BackendStorage for CpuStorage {
}
}
+ fn copy2d(
+ &self,
+ dst: &mut Self,
+ d1: usize,
+ d2: usize,
+ src_s: usize,
+ dst_s: usize,
+ src_o: usize,
+ dst_o: usize,
+ ) -> Result<()> {
+ match (self, dst) {
+ (Self::U8(src), Self::U8(dst)) => copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o),
+ (Self::U32(src), Self::U32(dst)) => {
+ copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
+ }
+ (Self::I64(src), Self::I64(dst)) => {
+ copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
+ }
+ (Self::BF16(src), Self::BF16(dst)) => {
+ copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
+ }
+ (Self::F16(src), Self::F16(dst)) => {
+ copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
+ }
+ (Self::F32(src), Self::F32(dst)) => {
+ copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
+ }
+ (Self::F64(src), Self::F64(dst)) => {
+ copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
+ }
+ (_, dst) => {
+ return Err(Error::DTypeMismatchBinaryOp {
+ lhs: self.dtype(),
+ rhs: dst.dtype(),
+ op: "copy2d",
+ }
+ .bt());
+ }
+ }
+ Ok(())
+ }
+
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
match (self, dst) {
(Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index b7756fa6..52d1b558 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -2145,6 +2145,67 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}
+ fn copy2d(
+ &self,
+ dst: &mut Self,
+ d1: usize,
+ d2: usize,
+ src_s: usize,
+ dst_s: usize,
+ src_o: usize,
+ dst_o: usize,
+ ) -> Result<()> {
+ let dev = &self.device;
+ let d1 = d1 as u32;
+ let d2 = d2 as u32;
+ let dst_s = dst_s as u32;
+ let src_s = src_s as u32;
+ let (src, dst, kname) = match (&self.slice, &mut dst.slice) {
+ (S::U8(s), S::U8(d)) => (
+ *s.slice(src_o..).device_ptr(),
+ *d.slice(dst_o..).device_ptr(),
+ "copy2d_u8",
+ ),
+ (S::U32(s), S::U32(d)) => (
+ *s.slice(src_o..).device_ptr(),
+ *d.slice(dst_o..).device_ptr(),
+ "copy2d_u32",
+ ),
+ (S::I64(s), S::I64(d)) => (
+ *s.slice(src_o..).device_ptr(),
+ *d.slice(dst_o..).device_ptr(),
+ "copy2d_i64",
+ ),
+ (S::BF16(s), S::BF16(d)) => (
+ *s.slice(src_o..).device_ptr(),
+ *d.slice(dst_o..).device_ptr(),
+ "copy2d_bf16",
+ ),
+ (S::F16(s), S::F16(d)) => (
+ *s.slice(src_o..).device_ptr(),
+ *d.slice(dst_o..).device_ptr(),
+ "copy2d_f16",
+ ),
+ (S::F32(s), S::F32(d)) => (
+ *s.slice(src_o..).device_ptr(),
+ *d.slice(dst_o..).device_ptr(),
+ "copy2d_f32",
+ ),
+ (S::F64(s), S::F64(d)) => (
+ *s.slice(src_o..).device_ptr(),
+ *d.slice(dst_o..).device_ptr(),
+ "copy2d_f64",
+ ),
+ _ => Err(CudaError::InternalError("dtype mismatch in copy2d"))?,
+ };
+ let func = dev.get_or_load_func(kname, kernels::FILL)?;
+ let cfg = LaunchConfig::for_num_elems(d1 * d2);
+ let params = (src, dst, d1, d2, src_s, dst_s);
+ // SAFETY: ffi.
+ unsafe { func.launch(cfg, params) }.w()?;
+ Ok(())
+ }
+
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
let src_shape = src_l.shape();
let dims = src_shape.dims();
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index 34c5d97f..43d55fa4 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -154,6 +154,19 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
+ fn copy2d(
+ &self,
+ _: &mut Self,
+ _: usize,
+ _: usize,
+ _: usize,
+ _: usize,
+ _: usize,
+ _: usize,
+ ) -> Result<()> {
+ Err(Error::NotCompiledWithCudaSupport)
+ }
+
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
diff --git a/candle-core/src/dummy_metal_backend.rs b/candle-core/src/dummy_metal_backend.rs
index e9d92331..791ec153 100644
--- a/candle-core/src/dummy_metal_backend.rs
+++ b/candle-core/src/dummy_metal_backend.rs
@@ -166,6 +166,19 @@ impl crate::backend::BackendStorage for MetalStorage {
Err(Error::NotCompiledWithMetalSupport)
}
+ fn copy2d(
+ &self,
+ _: &mut Self,
+ _: usize,
+ _: usize,
+ _: usize,
+ _: usize,
+ _: usize,
+ _: usize,
+ ) -> Result<()> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index fcc17afc..31ef1169 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -67,6 +67,7 @@ pub mod shape;
mod storage;
mod strided_index;
mod tensor;
+mod tensor_cat;
pub mod test_utils;
pub mod utils;
mod variable;
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index a17b87b8..2e07cce5 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -422,6 +422,7 @@ impl BackendStorage for MetalStorage {
let name = match self.dtype {
DType::F32 => "powf_f32",
DType::F16 => "powf_f16",
+ DType::BF16 => "powf_bf16",
dtype => crate::bail!("Metal contiguous powf {dtype:?} not implemented"),
};
candle_metal_kernels::call_powf(
@@ -439,6 +440,7 @@ impl BackendStorage for MetalStorage {
let name = match self.dtype {
DType::F32 => "powf_f32_strided",
DType::F16 => "powf_f16_strided",
+ DType::BF16 => "powf_bf16_strided",
dtype => crate::bail!("Metal strided powf {dtype:?} not implemented"),
};
candle_metal_kernels::call_powf_strided(
@@ -471,6 +473,7 @@ impl BackendStorage for MetalStorage {
let name = match self.dtype {
DType::F32 => "elu_f32",
DType::F16 => "elu_f16",
+ DType::BF16 => "elu_bf16",
dtype => crate::bail!("Metal contiguous elu {dtype:?} not implemented"),
};
candle_metal_kernels::call_elu(
@@ -488,6 +491,7 @@ impl BackendStorage for MetalStorage {
let name = match self.dtype {
DType::F32 => "elu_f32_strided",
DType::F16 => "elu_f16_strided",
+ DType::BF16 => "elu_bf16_strided",
dtype => crate::bail!("Metal strided elu {dtype:?} not implemented"),
};
candle_metal_kernels::call_elu_strided(
@@ -1292,6 +1296,67 @@ impl BackendStorage for MetalStorage {
))
}
+ fn copy2d(
+ &self,
+ dst: &mut Self,
+ d1: usize,
+ d2: usize,
+ src_s: usize,
+ dst_s: usize,
+ src_o: usize,
+ dst_o: usize,
+ ) -> Result<()> {
+ if self.dtype() != dst.dtype() {
+ crate::bail!(
+ "copy2d with inconsistent dtypes {:?} {:?}",
+ self.dtype(),
+ dst.dtype()
+ )
+ }
+ let command_buffer = self.device.command_buffer()?;
+ if src_s == d2 && dst_s == d2 {
+ command_buffer.set_label("copy2d_contiguous");
+ let blit = command_buffer.new_blit_command_encoder();
+ blit.set_label("copy2d_contiguous");
+ let src_offset = (src_o * self.dtype.size_in_bytes()) as NSUInteger;
+ let length = (d1 * d2 * self.dtype.size_in_bytes()) as NSUInteger;
+ let dst_offset = (dst_o * dst.dtype().size_in_bytes()) as NSUInteger;
+ blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length);
+ blit.end_encoding();
+ } else {
+ let el_count = d1 * d2;
+ if el_count == 0 {
+ return Ok(());
+ }
+ let kernel_name = match self.dtype {
+ DType::F32 => candle_metal_kernels::copy2d::FLOAT,
+ DType::F16 => candle_metal_kernels::copy2d::HALF,
+ DType::BF16 => candle_metal_kernels::copy2d::BFLOAT,
+ DType::I64 => candle_metal_kernels::copy2d::I64,
+ DType::U32 => candle_metal_kernels::copy2d::U32,
+ DType::U8 => candle_metal_kernels::copy2d::U8,
+ dtype => crate::bail!("Metal copy2d {dtype:?} not implemented"),
+ };
+ candle_metal_kernels::call_copy2d(
+ &self.device.device,
+ &command_buffer,
+ &self.device.kernels,
+ kernel_name,
+ &self.buffer,
+ &dst.buffer,
+ d1,
+ d2,
+ src_s,
+ dst_s,
+ src_o * self.dtype.size_in_bytes(),
+ dst_o * self.dtype.size_in_bytes(),
+ )
+ .map_err(MetalError::from)?;
+ command_buffer.set_label("copy2d");
+ }
+ Ok(())
+ }
+
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
let command_buffer = self.device.command_buffer()?;
if src_l.is_contiguous() && self.dtype == dst.dtype() {
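A note on the Metal implementation above: when src_s == d2 and dst_s == d2, consecutive rows are packed back to back in both buffers, so the whole transfer is one contiguous run of d1 * d2 elements and is served by a single blit; in every other case a compute kernel assigns one thread per element. A trivial sketch of that fast-path condition (illustrative only, not a new API):

    // A copy2d call collapses to one contiguous copy exactly when both the
    // source and destination rows are packed with no gap between them.
    fn copy2d_is_single_blit(d2: usize, src_s: usize, dst_s: usize) -> bool {
        src_s == d2 && dst_s == d2
    }

    fn main() {
        assert!(copy2d_is_single_blit(4, 4, 4));
        // Rows interleave with other tensors' rows, as in cat along dim 1 below.
        assert!(!copy2d_is_single_blit(24, 24, 52));
    }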
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 65bcc6aa..3bd4b022 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -701,4 +701,32 @@ impl Storage {
.bt()),
}
}
+
+ #[allow(clippy::too_many_arguments)]
+ pub(crate) fn copy2d(
+ &self,
+ dst: &mut Self,
+ d1: usize,
+ d2: usize,
+ src_s: usize,
+ dst_s: usize,
+ src_o: usize,
+ dst_o: usize,
+ ) -> Result<()> {
+ match (self, dst) {
+ (Self::Cpu(src), Self::Cpu(dst)) => src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o),
+ (Self::Cuda(src), Self::Cuda(dst)) => {
+ Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
+ }
+ (Self::Metal(src), Self::Metal(dst)) => {
+ Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
+ }
+ (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
+ lhs: lhs.device().location(),
+ rhs: rhs.device().location(),
+ op: "copy2d",
+ }
+ .bt()),
+ }
+ }
}
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 0e2c3e8f..22cd4950 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -666,7 +666,7 @@ impl Tensor {
Ok(from_storage(storage, self.shape(), op, false))
}
- fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
+ pub(crate) fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
if dim >= self.dims().len() {
Err(Error::DimOutOfRange {
shape: self.shape().clone(),
@@ -2149,152 +2149,6 @@ impl Tensor {
Self::cat(&args, dim)
}
- /// Concatenates two or more tensors along a particular dimension.
- ///
- /// All tensors must of the same rank, and the output will have
- /// the same rank
- ///
- /// ```rust
- /// # use candle_core::{Tensor, DType, Device};
- /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
- /// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
- ///
- /// let c = Tensor::cat(&[&a, &b], 0)?;
- /// assert_eq!(c.shape().dims(), &[4, 3]);
- ///
- /// let c = Tensor::cat(&[&a, &b], 1)?;
- /// assert_eq!(c.shape().dims(), &[2, 6]);
- /// # Ok::<(), candle_core::Error>(())
- /// ```
- pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
- if args.is_empty() {
- Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
- }
- let arg0 = args[0].as_ref();
- if args.len() == 1 {
- return Ok(arg0.clone());
- }
- let dim = dim.to_index(arg0.shape(), "cat")?;
- for arg in args {
- arg.as_ref().check_dim(dim, "cat")?;
- }
- for (arg_idx, arg) in args.iter().enumerate() {
- let arg = arg.as_ref();
- if arg0.rank() != arg.rank() {
- Err(Error::UnexpectedNumberOfDims {
- expected: arg0.rank(),
- got: arg.rank(),
- shape: arg.shape().clone(),
- }
- .bt())?
- }
- for (dim_idx, (v1, v2)) in arg0
- .shape()
- .dims()
- .iter()
- .zip(arg.shape().dims().iter())
- .enumerate()
- {
- if dim_idx != dim && v1 != v2 {
- Err(Error::ShapeMismatchCat {
- dim: dim_idx,
- first_shape: arg0.shape().clone(),
- n: arg_idx + 1,
- nth_shape: arg.shape().clone(),
- }
- .bt())?
- }
- }
- }
- if dim == 0 {
- Self::cat0(args)
- } else {
- // TODO: Avoid these transpositions and have an implementation that works
- // for dim != 0...
- let args: Vec<Tensor> = args
- .iter()
- .map(|a| a.as_ref().transpose(0, dim))
- .collect::<Result<Vec<_>>>()?;
- let cat = Self::cat0(&args)?;
- cat.transpose(0, dim)
- }
- }
-
- fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
- if args.is_empty() {
- Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
- }
- let arg0 = args[0].as_ref();
- if args.len() == 1 {
- return Ok(arg0.clone());
- }
- let rank = arg0.rank();
- let device = arg0.device();
- let dtype = arg0.dtype();
- let first_dims = arg0.shape().dims();
- let mut cat_dims = first_dims.to_vec();
- cat_dims[0] = 0;
- let mut offsets = vec![0usize];
- for (arg_idx, arg) in args.iter().enumerate() {
- let arg = arg.as_ref();
- if arg.dtype() != dtype {
- Err(Error::DTypeMismatchBinaryOp {
- lhs: dtype,
- rhs: arg.dtype(),
- op: "cat",
- }
- .bt())?
- }
- if arg.device().location() != device.location() {
- Err(Error::DeviceMismatchBinaryOp {
- lhs: device.location(),
- rhs: arg.device().location(),
- op: "cat",
- }
- .bt())?
- }
- if rank != arg.rank() {
- Err(Error::UnexpectedNumberOfDims {
- expected: rank,
- got: arg.rank(),
- shape: arg.shape().clone(),
- }
- .bt())?
- }
- for (dim_idx, (v1, v2)) in arg0
- .shape()
- .dims()
- .iter()
- .zip(arg.shape().dims().iter())
- .enumerate()
- {
- if dim_idx == 0 {
- cat_dims[0] += v2;
- }
- if dim_idx != 0 && v1 != v2 {
- Err(Error::ShapeMismatchCat {
- dim: dim_idx,
- first_shape: arg0.shape().clone(),
- n: arg_idx + 1,
- nth_shape: arg.shape().clone(),
- }
- .bt())?
- }
- }
- let next_offset = offsets.last().unwrap() + arg.elem_count();
- offsets.push(next_offset);
- }
- let shape = Shape::from(cat_dims);
- let op = BackpropOp::new(args, |args| Op::Cat(args, 0));
- let mut storage = device.zeros(&shape, dtype)?;
- for (arg, &offset) in args.iter().zip(offsets.iter()) {
- let arg = arg.as_ref();
- arg.storage()
- .copy_strided_src(&mut storage, offset, arg.layout())?;
- }
- Ok(from_storage(storage, shape, op, false))
- }
-
/// Pad the input tensor using 0s along dimension `dim`. This adds `left` elements before the
/// input tensor values and `right` elements after.
pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
diff --git a/candle-core/src/tensor_cat.rs b/candle-core/src/tensor_cat.rs
new file mode 100644
index 00000000..25acc80e
--- /dev/null
+++ b/candle-core/src/tensor_cat.rs
@@ -0,0 +1,240 @@
+use crate::{shape::Dim, Error, Result, Shape, Tensor};
+
+impl Tensor {
+ /// Concatenates two or more tensors along a particular dimension.
+ ///
+ /// All tensors must have the same rank, and the output will have
+ /// the same rank.
+ ///
+ /// ```rust
+ /// # use candle_core::{Tensor, DType, Device};
+ /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+ /// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+ ///
+ /// let c = Tensor::cat(&[&a, &b], 0)?;
+ /// assert_eq!(c.shape().dims(), &[4, 3]);
+ ///
+ /// let c = Tensor::cat(&[&a, &b], 1)?;
+ /// assert_eq!(c.shape().dims(), &[2, 6]);
+ /// # Ok::<(), candle_core::Error>(())
+ /// ```
+ pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
+ if args.is_empty() {
+ Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
+ }
+ let arg0 = args[0].as_ref();
+ if args.len() == 1 {
+ return Ok(arg0.clone());
+ }
+ let dim = dim.to_index(arg0.shape(), "cat")?;
+ for arg in args {
+ arg.as_ref().check_dim(dim, "cat")?;
+ }
+ for (arg_idx, arg) in args.iter().enumerate() {
+ let arg = arg.as_ref();
+ if arg0.rank() != arg.rank() {
+ Err(Error::UnexpectedNumberOfDims {
+ expected: arg0.rank(),
+ got: arg.rank(),
+ shape: arg.shape().clone(),
+ }
+ .bt())?
+ }
+ for (dim_idx, (v1, v2)) in arg0
+ .shape()
+ .dims()
+ .iter()
+ .zip(arg.shape().dims().iter())
+ .enumerate()
+ {
+ if dim_idx != dim && v1 != v2 {
+ Err(Error::ShapeMismatchCat {
+ dim: dim_idx,
+ first_shape: arg0.shape().clone(),
+ n: arg_idx + 1,
+ nth_shape: arg.shape().clone(),
+ }
+ .bt())?
+ }
+ }
+ }
+ if dim == 0 {
+ Self::cat0(args)
+ } else {
+ let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
+ if all_contiguous {
+ Self::cat_contiguous(args, dim)
+ } else {
+ let args: Vec<Tensor> = args
+ .iter()
+ .map(|a| a.as_ref().transpose(0, dim))
+ .collect::<Result<Vec<_>>>()?;
+ let cat = Self::cat0(&args)?;
+ cat.transpose(0, dim)
+ }
+ }
+ }
+
+ fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
+ if args.is_empty() {
+ Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
+ }
+ let arg0 = args[0].as_ref();
+ if args.len() == 1 {
+ return Ok(arg0.clone());
+ }
+ let rank = arg0.rank();
+ let device = arg0.device();
+ let dtype = arg0.dtype();
+ let first_dims = arg0.shape().dims();
+ let mut cat_dims = first_dims.to_vec();
+ cat_dims[0] = 0;
+ let mut offsets = vec![0usize];
+ for (arg_idx, arg) in args.iter().enumerate() {
+ let arg = arg.as_ref();
+ if arg.dtype() != dtype {
+ Err(Error::DTypeMismatchBinaryOp {
+ lhs: dtype,
+ rhs: arg.dtype(),
+ op: "cat",
+ }
+ .bt())?
+ }
+ if arg.device().location() != device.location() {
+ Err(Error::DeviceMismatchBinaryOp {
+ lhs: device.location(),
+ rhs: arg.device().location(),
+ op: "cat",
+ }
+ .bt())?
+ }
+ if rank != arg.rank() {
+ Err(Error::UnexpectedNumberOfDims {
+ expected: rank,
+ got: arg.rank(),
+ shape: arg.shape().clone(),
+ }
+ .bt())?
+ }
+ for (dim_idx, (v1, v2)) in arg0
+ .shape()
+ .dims()
+ .iter()
+ .zip(arg.shape().dims().iter())
+ .enumerate()
+ {
+ if dim_idx == 0 {
+ cat_dims[0] += v2;
+ }
+ if dim_idx != 0 && v1 != v2 {
+ Err(Error::ShapeMismatchCat {
+ dim: dim_idx,
+ first_shape: arg0.shape().clone(),
+ n: arg_idx + 1,
+ nth_shape: arg.shape().clone(),
+ }
+ .bt())?
+ }
+ }
+ let next_offset = offsets.last().unwrap() + arg.elem_count();
+ offsets.push(next_offset);
+ }
+ let shape = Shape::from(cat_dims);
+ let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, 0));
+ let mut storage = device.zeros(&shape, dtype)?;
+ for (arg, &offset) in args.iter().zip(offsets.iter()) {
+ let arg = arg.as_ref();
+ arg.storage()
+ .copy_strided_src(&mut storage, offset, arg.layout())?;
+ }
+ Ok(crate::tensor::from_storage(storage, shape, op, false))
+ }
+
+ fn cat_contiguous<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
+ if args.is_empty() {
+ Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
+ }
+ let arg0 = args[0].as_ref();
+ if args.len() == 1 {
+ return Ok(arg0.clone());
+ }
+ let rank = arg0.rank();
+ let device = arg0.device();
+ let dtype = arg0.dtype();
+ let first_dims = arg0.shape().dims();
+ let mut cat_dims = first_dims.to_vec();
+ cat_dims[dim] = 0;
+ for (arg_idx, arg) in args.iter().enumerate() {
+ let arg = arg.as_ref();
+ if arg.dtype() != dtype {
+ Err(Error::DTypeMismatchBinaryOp {
+ lhs: dtype,
+ rhs: arg.dtype(),
+ op: "cat",
+ }
+ .bt())?
+ }
+ if arg.device().location() != device.location() {
+ Err(Error::DeviceMismatchBinaryOp {
+ lhs: device.location(),
+ rhs: arg.device().location(),
+ op: "cat",
+ }
+ .bt())?
+ }
+ if rank != arg.rank() {
+ Err(Error::UnexpectedNumberOfDims {
+ expected: rank,
+ got: arg.rank(),
+ shape: arg.shape().clone(),
+ }
+ .bt())?
+ }
+ for (dim_idx, (v1, v2)) in arg0
+ .shape()
+ .dims()
+ .iter()
+ .zip(arg.shape().dims().iter())
+ .enumerate()
+ {
+ if dim_idx == dim {
+ cat_dims[dim] += v2;
+ }
+ if dim_idx != dim && v1 != v2 {
+ Err(Error::ShapeMismatchCat {
+ dim: dim_idx,
+ first_shape: arg0.shape().clone(),
+ n: arg_idx + 1,
+ nth_shape: arg.shape().clone(),
+ }
+ .bt())?
+ }
+ }
+ }
+ let cat_target_dim_len = cat_dims[dim];
+ let block_size: usize = cat_dims.iter().skip(1 + dim).product();
+ let shape = Shape::from(cat_dims);
+ let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, dim));
+ let mut storage = device.zeros(&shape, dtype)?;
+ let mut dst_o = 0;
+ for arg in args.iter() {
+ let arg = arg.as_ref();
+ let arg_dims = arg.shape().dims();
+ let d1: usize = arg_dims.iter().take(dim).product();
+ let d2 = block_size * arg_dims[dim];
+ let dst_s = block_size * cat_target_dim_len;
+ let src_o = arg.layout().start_offset();
+ arg.storage().copy2d(
+ &mut storage,
+ d1,
+ d2,
+ /* src_s */ d2,
+ dst_s,
+ src_o,
+ dst_o,
+ )?;
+ dst_o += d2;
+ }
+ Ok(crate::tensor::from_storage(storage, shape, op, false))
+ }
+}
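To make the arithmetic in cat_contiguous concrete, here is a standalone sketch (not part of the patch) reproducing the d1/d2/stride values it hands to copy2d for the 3-D case exercised by the new test in tensor_tests.rs, i.e. shapes (2, 6, 4), (2, 3, 4) and (2, 4, 4) concatenated along dim 1:

    fn main() {
        let dims = [[2usize, 6, 4], [2, 3, 4], [2, 4, 4]];
        let dim = 1;
        let block_size: usize = dims[0][dim + 1..].iter().product(); // 4
        let cat_dim_len: usize = dims.iter().map(|d| d[dim]).sum();  // 6 + 3 + 4 = 13
        let dst_s = block_size * cat_dim_len;                        // 52
        let mut dst_o = 0;
        for d in dims {
            let d1: usize = d[..dim].iter().product();               // 2
            let d2 = block_size * d[dim];                            // 24, 12, 16
            println!("copy2d: d1={d1} d2={d2} src_s={d2} dst_s={dst_s} dst_o={dst_o}");
            dst_o += d2;
        }
        assert_eq!(dst_o, dst_s); // the per-tensor blocks exactly fill a destination row
    }

Each source tensor contributes d2 consecutive elements to every one of the d1 destination rows; the per-tensor destination offsets 0, 24 and 36 add up to the destination row stride of 52.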
diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs
index f0f1b7f2..ba60b778 100644
--- a/candle-core/tests/conv_tests.rs
+++ b/candle-core/tests/conv_tests.rs
@@ -53,6 +53,12 @@ fn conv1d(dev: &Device) -> Result<()> {
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
);
+
+ // conv-transposes are not implemented for metal.
+ if dev.is_metal() {
+ return Ok(());
+ }
+
let w = w.transpose(0, 1)?;
// The CPU kernels applied in the contiguous and non contiguous cases are different.
for w in [w.clone(), w.contiguous()?] {
@@ -162,31 +168,33 @@ fn conv2d(dev: &Device) -> Result<()> {
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
]
);
- let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
- assert_eq!(res.dims(), [1, 2, 7, 7]);
- assert_eq!(
- test_utils::to_vec3_round(&res.i(0)?, 4)?,
- [
- [
- [-1.9918, 2.6797, -0.4599, -1.6037, 1.4131, -2.4012, 2.9277],
- [1.8016, -3.5361, 1.0757, 3.5395, -8.2168, -3.2023, 0.5375],
- [0.8243, 1.8675, 7.8929, -4.0746, -6.4415, 5.1139, 1.6889],
- [0.2722, 8.9679, 3.3477, 1.8514, -4.2896, -3.8228, -7.5632],
- [-8.5412, -5.8142, -7.1587, -1.6095, 0.4651, 0.2748, -2.0985],
- [2.0833, -0.6482, -12.1692, -4.1284, -2.9765, -0.0656, -4.5114],
- [5.307, 2.6957, 2.3087, 1.0478, 0.7808, -1.1519, -0.9579]
- ],
+ if !dev.is_metal() {
+ let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
+ assert_eq!(res.dims(), [1, 2, 7, 7]);
+ assert_eq!(
+ test_utils::to_vec3_round(&res.i(0)?, 4)?,
[
- [1.089, 0.1872, -0.6408, -0.9897, 0.8503, 1.1019, -0.9211],
- [-0.1741, -0.2915, 4.2472, 1.9417, 1.65, 0.6303, -4.7131],
- [1.6555, 2.4026, -2.9293, 2.9953, 0.5328, 3.5873, -0.9621],
- [-1.4289, -3.2787, 4.1747, -6.0341, -4.6341, -5.7945, 4.142],
- [7.5973, 6.4431, 5.9872, 2.1639, -8.6566, 3.3143, -3.4059],
- [-0.8775, -3.048, 11.6543, 0.6442, 2.3218, -0.4765, 1.1516],
- [-5.5423, -2.5188, 1.0754, -0.0563, -2.9386, -1.1504, 1.0171]
+ [
+ [-1.9918, 2.6797, -0.4599, -1.6037, 1.4131, -2.4012, 2.9277],
+ [1.8016, -3.5361, 1.0757, 3.5395, -8.2168, -3.2023, 0.5375],
+ [0.8243, 1.8675, 7.8929, -4.0746, -6.4415, 5.1139, 1.6889],
+ [0.2722, 8.9679, 3.3477, 1.8514, -4.2896, -3.8228, -7.5632],
+ [-8.5412, -5.8142, -7.1587, -1.6095, 0.4651, 0.2748, -2.0985],
+ [2.0833, -0.6482, -12.1692, -4.1284, -2.9765, -0.0656, -4.5114],
+ [5.307, 2.6957, 2.3087, 1.0478, 0.7808, -1.1519, -0.9579]
+ ],
+ [
+ [1.089, 0.1872, -0.6408, -0.9897, 0.8503, 1.1019, -0.9211],
+ [-0.1741, -0.2915, 4.2472, 1.9417, 1.65, 0.6303, -4.7131],
+ [1.6555, 2.4026, -2.9293, 2.9953, 0.5328, 3.5873, -0.9621],
+ [-1.4289, -3.2787, 4.1747, -6.0341, -4.6341, -5.7945, 4.142],
+ [7.5973, 6.4431, 5.9872, 2.1639, -8.6566, 3.3143, -3.4059],
+ [-0.8775, -3.048, 11.6543, 0.6442, 2.3218, -0.4765, 1.1516],
+ [-5.5423, -2.5188, 1.0754, -0.0563, -2.9386, -1.1504, 1.0171]
+ ]
]
- ]
- );
+ );
+ }
// Dilations.
let res = t.conv2d(&w, 0, 1, 2, 1)?;
assert_eq!(res.dims(), [1, 2, 1, 1]);
@@ -195,36 +203,44 @@ fn conv2d(dev: &Device) -> Result<()> {
[2.45, -2.3504],
);
- // Transpose and dilations.
- let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?;
- assert_eq!(res.dims(), [1, 2, 9, 9]);
- assert_eq!(
- test_utils::to_vec3_round(&res.i(0)?, 4)?,
- [
- [
- [-1.9918, 3.1652, -0.6778, -4.3442, 4.4351, 0.6652, -3.0124, -0.6031, 2.9277],
- [2.7036, -1.7156, -0.3969, 1.0516, 1.6381, -2.8886, -0.205, 2.4682, -1.0499],
- [-0.9459, 3.1631, 3.707, -4.8369, -8.5166, -1.4496, -2.7559, -3.2698, 1.4376],
- [-0.2157, 3.7786, -2.0252, -4.2633, 3.6731, -1.5142, 5.9391, -0.2622, -0.141],
- [-6.8121, -3.1744, 1.5945, 3.0637, -9.6088, 1.4446, 2.9489, -3.0082, -7.3822],
- [0.2371, 3.3303, 0.3861, 2.2646, -4.6784, 4.1235, -0.0109, 0.3176, -0.03],
- [-2.5339, -2.9564, -3.4518, -4.4594, -9.1873, -1.9709, -0.4676, 0.51, -3.5024],
- [4.007, 0.3067, -2.2954, 1.1105, -0.1992, 1.6372, -2.9268, 0.2807, -1.2787],
- [5.307, 1.1317, 1.3518, 0.9049, 3.8116, -0.4075, -0.8874, -0.2241, -0.9579]
- ],
+ if !dev.is_metal() {
+ // Transpose and dilations.
+ let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?;
+ assert_eq!(res.dims(), [1, 2, 9, 9]);
+ assert_eq!(
+ test_utils::to_vec3_round(&res.i(0)?, 4)?,
[
- [1.089, -0.6483, 0.0726, -0.4752, -1.3283, 1.7103, 1.0703, 0.1076, -0.9211],
- [-0.8629, 0.1376, 0.3202, 2.0955, 0.9696, 2.8988, -1.0012, 1.5049, -0.1278],
- [1.9286, -1.5255, -2.9563, 2.4589, 3.3611, -0.6951, 0.3525, -1.7724, -5.9861],
- [1.1226, 2.1561, 3.6417, 4.7546, -0.692, 4.4126, -5.1902, 6.0805, 2.3185],
- [1.0111, 0.3604, 0.6432, -3.6605, 7.9517, -9.2955, -5.2988, -3.7803, -2.0642],
- [3.3172, -1.7967, -3.6576, -2.0942, 1.3158, 0.112, -1.7405, 2.9167, 0.7957],
- [5.1001, 1.8995, -1.8639, 1.1262, 9.9629, 2.683, -3.6319, -1.1607, 0.5856],
- [-4.8445, -0.5642, 4.2317, 0.0856, 1.2267, -0.5712, 1.736, 1.0997, 0.6908],
- [-5.5423, -1.1831, -1.2176, 0.0843, 0.0446, -0.7545, -2.4798, -0.0827, 1.0171]
+ [
+ [-1.9918, 3.1652, -0.6778, -4.3442, 4.4351, 0.6652, -3.0124, -0.6031, 2.9277],
+ [2.7036, -1.7156, -0.3969, 1.0516, 1.6381, -2.8886, -0.205, 2.4682, -1.0499],
+ [-0.9459, 3.1631, 3.707, -4.8369, -8.5166, -1.4496, -2.7559, -3.2698, 1.4376],
+ [-0.2157, 3.7786, -2.0252, -4.2633, 3.6731, -1.5142, 5.9391, -0.2622, -0.141],
+ [-6.8121, -3.1744, 1.5945, 3.0637, -9.6088, 1.4446, 2.9489, -3.0082, -7.3822],
+ [0.2371, 3.3303, 0.3861, 2.2646, -4.6784, 4.1235, -0.0109, 0.3176, -0.03],
+ [
+ -2.5339, -2.9564, -3.4518, -4.4594, -9.1873, -1.9709, -0.4676, 0.51,
+ -3.5024
+ ],
+ [4.007, 0.3067, -2.2954, 1.1105, -0.1992, 1.6372, -2.9268, 0.2807, -1.2787],
+ [5.307, 1.1317, 1.3518, 0.9049, 3.8116, -0.4075, -0.8874, -0.2241, -0.9579]
+ ],
+ [
+ [1.089, -0.6483, 0.0726, -0.4752, -1.3283, 1.7103, 1.0703, 0.1076, -0.9211],
+ [-0.8629, 0.1376, 0.3202, 2.0955, 0.9696, 2.8988, -1.0012, 1.5049, -0.1278],
+ [1.9286, -1.5255, -2.9563, 2.4589, 3.3611, -0.6951, 0.3525, -1.7724, -5.9861],
+ [1.1226, 2.1561, 3.6417, 4.7546, -0.692, 4.4126, -5.1902, 6.0805, 2.3185],
+ [1.0111, 0.3604, 0.6432, -3.6605, 7.9517, -9.2955, -5.2988, -3.7803, -2.0642],
+ [3.3172, -1.7967, -3.6576, -2.0942, 1.3158, 0.112, -1.7405, 2.9167, 0.7957],
+ [5.1001, 1.8995, -1.8639, 1.1262, 9.9629, 2.683, -3.6319, -1.1607, 0.5856],
+ [-4.8445, -0.5642, 4.2317, 0.0856, 1.2267, -0.5712, 1.736, 1.0997, 0.6908],
+ [
+ -5.5423, -1.1831, -1.2176, 0.0843, 0.0446, -0.7545, -2.4798, -0.0827,
+ 1.0171
+ ]
+ ]
]
- ]
- );
+ );
+ }
Ok(())
}
@@ -278,6 +294,12 @@ fn conv2d_small(dev: &Device) -> Result<()> {
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
]
);
+
+ // conv-transposes are not implemented for metal
+ if dev.is_metal() {
+ return Ok(());
+ }
+
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 1, 3, 3]);
assert_eq!(
@@ -379,6 +401,10 @@ print(w.grad.shape)
print(w.grad[0])
*/
fn conv2d_grad(dev: &Device) -> Result<()> {
+ // conv-transposes are not implemented for metal
+ if dev.is_metal() {
+ return Ok(());
+ }
use candle_core::Var;
let t = Var::from_slice(
&[
diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs
index a4d81618..b8b6be8d 100644
--- a/candle-core/tests/grad_tests.rs
+++ b/candle-core/tests/grad_tests.rs
@@ -1,3 +1,4 @@
+#![allow(clippy::approx_constant)]
use anyhow::{Context, Result};
use candle_core::{test_device, test_utils, Device, Shape, Tensor, Var};
@@ -96,24 +97,24 @@ fn unary_grad(device: &Device) -> Result<()> {
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(
- y.to_vec1::<f32>()?,
- [20.085537, 2.7182817, 54.59815, 1.1618342]
+ test_utils::to_vec1_round(&y, 4)?,
+ [20.0855, 2.7183, 54.5982, 1.1618]
);
assert_eq!(
- grad_x.to_vec1::<f32>()?,
- [20.085537, 2.7182817, 54.59815, 1.1618342]
+ test_utils::to_vec1_round(grad_x, 4)?,
+ [20.0855, 2.7183, 54.5982, 1.1618]
);
let y = x.exp()?.sqr()?;
let grads = y.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(
- y.to_vec1::<f32>()?,
- [403.4288, 7.3890557, 2980.9578, 1.3498588]
+ test_utils::to_vec1_round(&y, 3)?,
+ [403.429, 7.389, 2980.958, 1.35]
);
// exp(x)^2 = exp(2*x)
assert_eq!(
- grad_x.to_vec1::<f32>()?,
- [806.8576, 14.778111, 5961.9155, 2.6997175]
+ test_utils::to_vec1_round(grad_x, 2)?,
+ [806.86, 14.78, 5961.92, 2.7]
);
let y = x.sin()?;
let grads = y.backward()?;
@@ -261,6 +262,7 @@ fn unary_grad(device: &Device) -> Result<()> {
let y = elu_x.elu(2.)?;
let grads = y.backward()?;
let grad_x = grads.get(&elu_x).context("no grad for x")?;
+
assert_eq!(
test_utils::to_vec1_round(&y, 4)?,
[-1.2642, 0.0000, -1.7293, 3.0000]
diff --git a/candle-core/tests/pool_tests.rs b/candle-core/tests/pool_tests.rs
index a3708ec4..a6530e03 100644
--- a/candle-core/tests/pool_tests.rs
+++ b/candle-core/tests/pool_tests.rs
@@ -2,6 +2,9 @@ use candle_core::{test_device, test_utils, Device, IndexOp, Result, Tensor};
// https://github.com/huggingface/candle/issues/364
fn avg_pool2d(dev: &Device) -> Result<()> {
+ if dev.is_metal() {
+ return Ok(());
+ }
let data: Vec<f32> = vec![
1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
];
@@ -19,6 +22,9 @@ fn avg_pool2d(dev: &Device) -> Result<()> {
}
fn max_pool2d(dev: &Device) -> Result<()> {
+ if dev.is_metal() {
+ return Ok(());
+ }
let data: Vec<f32> = vec![
1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1.,
];
@@ -43,6 +49,9 @@ res = torch.nn.functional.avg_pool2d(t, 2)
print(res)
*/
fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
+ if dev.is_metal() {
+ return Ok(());
+ }
let t = Tensor::new(
&[
0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 31a27422..b2475adc 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -672,6 +672,31 @@ fn cat(device: &Device) -> Result<()> {
[2.0, 7.0, 1.0, 8.0, 2.0, 2.0, 7.0, 1.0, 8.0, 2.0]
]
);
+
+ // 3D
+ let t1 = Tensor::arange(0, 48i64, device)?.reshape((2, 6, 4))?;
+ let t2 = Tensor::arange(100, 124i64, device)?.reshape((2, 3, 4))?;
+ let t3 = Tensor::arange(10000, 10032i64, device)?.reshape((2, 4, 4))?;
+
+ let t_cat = Tensor::cat(&[&t1, &t2, &t3], 1)?;
+
+ let t1 = t1.t()?.contiguous()?.t()?;
+ let t2 = t2.t()?.contiguous()?.t()?;
+ let t3 = t3.t()?.contiguous()?.t()?;
+ let t_cat2 = Tensor::cat(&[&t1, &t2, &t3], 1)?;
+
+ let diff = t_cat.eq(&t_cat2)?.to_dtype(DType::F32)?.sum_all()?;
+ assert_eq!(diff.to_vec0::<f32>()?, 104.0);
+ assert_eq!(t_cat.i((0, 0, 0))?.to_vec0::<i64>()?, 0);
+ assert_eq!(t_cat.i((0, 4, 0))?.to_vec0::<i64>()?, 16);
+ assert_eq!(t_cat.i((0, 5, 0))?.to_vec0::<i64>()?, 20);
+ assert_eq!(t_cat.i((1, 5, 0))?.to_vec0::<i64>()?, 44);
+ assert_eq!(t_cat.i((0, 6, 0))?.to_vec0::<i64>()?, 100);
+ assert_eq!(t_cat.i((1, 6, 0))?.to_vec0::<i64>()?, 112);
+ assert_eq!(t_cat.i((0, 6, 1))?.to_vec0::<i64>()?, 101);
+ assert_eq!(t_cat.i((0, 7, 1))?.to_vec0::<i64>()?, 105);
+ assert_eq!(t_cat.i((0, 12, 1))?.to_vec0::<i64>()?, 10013);
+ assert_eq!(t_cat.i((1, 12, 3))?.to_vec0::<i64>()?, 10031);
Ok(())
}
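For context on the expected values just above: t_cat has shape (2, 13, 4), i.e. 2 × (6 + 3 + 4) × 4 = 104 elements, so an element-wise equality that sums to 104.0 means t_cat (built from contiguous inputs via the new copy2d path) and t_cat2 (built from non-contiguous views of the same data, which take the transpose-based fallback) agree at every position.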
diff --git a/candle-kernels/src/fill.cu b/candle-kernels/src/fill.cu
index 883ca072..ca448d98 100644
--- a/candle-kernels/src/fill.cu
+++ b/candle-kernels/src/fill.cu
@@ -10,11 +10,39 @@ __device__ void fill_with(T *buf, T value, const size_t numel) {
extern "C" __global__ void fill_u8(uint8_t *buf, uint8_t value, const size_t numel) { fill_with(buf, value, numel); }
extern "C" __global__ void fill_u32(uint32_t *buf, uint32_t value, const size_t numel) { fill_with(buf, value, numel); }
extern "C" __global__ void fill_i64(int64_t *buf, int64_t value, const size_t numel) { fill_with(buf, value, numel); }
-extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
extern "C" __global__ void fill_f32(float *buf, float value, const size_t numel) { fill_with(buf, value, numel); }
extern "C" __global__ void fill_f64(double *buf, double value, const size_t numel) { fill_with(buf, value, numel); }
+template<typename T>
+__device__ void copy2d(const T *src, T *dst, uint32_t d1, uint32_t d2, uint32_t src_s, uint32_t dst_s) {
+ uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (idx >= d1 * d2) {
+ return;
+ }
+ uint32_t idx1 = idx / d2;
+ uint32_t idx2 = idx - d2 * idx1;
+ dst[idx1 * dst_s + idx2] = src[idx1 * src_s + idx2];
+}
+
+#define COPY2D_OP(TYPENAME, FNNAME) \
+extern "C" __global__ \
+void FNNAME(const TYPENAME *src, TYPENAME *dst, uint32_t d1, uint32_t d2, uint32_t src_s, uint32_t dst_s) { \
+ copy2d(src, dst, d1, d2, src_s, dst_s); \
+} \
+
+COPY2D_OP(float, copy2d_f32)
+COPY2D_OP(double, copy2d_f64)
+COPY2D_OP(uint8_t, copy2d_u8)
+COPY2D_OP(uint32_t, copy2d_u32)
+COPY2D_OP(int64_t, copy2d_i64)
+
+#if __CUDA_ARCH__ >= 530
+extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
+COPY2D_OP(__half, copy2d_f16)
+#endif
+
#if __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
extern "C" __global__ void fill_bf16(__nv_bfloat16 *buf, __nv_bfloat16 value, const size_t numel) { fill_with(buf, value, numel); }
+COPY2D_OP(__nv_bfloat16, copy2d_bf16)
#endif
diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal
index a4484998..76c0365a 100644
--- a/candle-metal-kernels/src/affine.metal
+++ b/candle-metal-kernels/src/affine.metal
@@ -89,7 +89,7 @@ kernel void FN_NAME( \
return; \
} \
const TYPENAME x = input[id]; \
- output[id] = TYPENAME((x > 0)?x: mul * exp(x - 1)); \
+ output[id] = TYPENAME((x > 0)?x: mul * (exp(x) - 1)); \
} \
kernel void FN_NAME##_strided( \
constant size_t &dim, \
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 47ce7e96..a879c86a 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -127,6 +127,16 @@ pub enum Source {
Quantized,
}
+pub mod copy2d {
+ pub struct Kernel(pub &'static str);
+ pub const FLOAT: Kernel = Kernel("copy2d_f32");
+ pub const HALF: Kernel = Kernel("copy2d_f16");
+ pub const BFLOAT: Kernel = Kernel("copy2d_bf16");
+ pub const I64: Kernel = Kernel("copy2d_i64");
+ pub const U32: Kernel = Kernel("copy2d_u32");
+ pub const U8: Kernel = Kernel("copy2d_u8");
+}
+
macro_rules! ops{
($($name:ident),+) => {
@@ -366,6 +376,46 @@ pub fn call_unary_contiguous(
}
#[allow(clippy::too_many_arguments)]
+pub fn call_copy2d(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: copy2d::Kernel,
+ input: &Buffer,
+ output: &Buffer,
+ d1: usize,
+ d2: usize,
+ src_s: usize,
+ dst_s: usize,
+ src_o_in_bytes: usize,
+ dst_o_in_bytes: usize,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.set_compute_pipeline_state(&pipeline);
+ set_params!(
+ encoder,
+ (
+ d1,
+ d2,
+ src_s,
+ dst_s,
+ (input, src_o_in_bytes),
+ (output, dst_o_in_bytes)
+ )
+ );
+
+ let width: usize = d1 * d2;
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
+
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.end_encoding();
+ Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
pub fn call_unary_strided(
device: &Device,
command_buffer: &CommandBufferRef,
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index 1e0d5526..bdc13f9e 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -102,6 +102,30 @@ UNARY(NAME, half, NAME##_f16, NAME##_f16_strided);
#define BFLOAT_UNARY_OP(NAME) \
UNARY(NAME, bfloat, NAME##_bf16, NAME##_bf16_strided);
+#define COPY2D(FN_NAME, TYPENAME) \
+kernel void FN_NAME( \
+ constant size_t &d1, \
+ constant size_t &d2, \
+ constant size_t &src_s, \
+ constant size_t &dst_s, \
+ device const TYPENAME *input, \
+ device TYPENAME *output, \
+ uint tid [[ thread_position_in_grid ]] \
+) { \
+ if (tid >= d1 * d2) { \
+ return; \
+ } \
+ size_t idx1 = tid / d2; \
+ size_t idx2 = tid - idx1 * d2; \
+ size_t src_idx = idx1 * src_s + idx2; \
+ size_t dst_idx = idx1 * dst_s + idx2; \
+ output[dst_idx] = input[src_idx]; \
+}
+
+COPY2D(copy2d_f32, float)
+COPY2D(copy2d_f16, half)
+COPY2D(copy2d_u8, uint8_t)
+COPY2D(copy2d_u32, uint32_t)
UNARY_OP(cos)
UNARY_OP(sin)
@@ -128,6 +152,7 @@ UNARY(id, uint32_t, copy_u32, copy_u32_strided)
#if __METAL_VERSION__ >= 220
UNARY(id, int64_t, copy_i64, copy_i64_strided)
+COPY2D(copy2d_i64, int64_t)
#endif
#if defined(__HAVE_BFLOAT__)
@@ -151,4 +176,6 @@ BFLOAT_UNARY_OP(recip)
BFLOAT_UNARY_OP(relu)
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
+
+COPY2D(copy2d_bf16, bfloat)
#endif
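Both the CUDA copy2d kernel in fill.cu and the Metal COPY2D kernel above use the same per-element mapping: the flat thread id is split into a row index idx1 = tid / d2 and a column index idx2 = tid - idx1 * d2, and the element moves between the two strided rows. A host-side Rust reference model of that mapping (illustrative only):

    fn copy2d_ref(src: &[f32], dst: &mut [f32], d1: u32, d2: u32, src_s: u32, dst_s: u32) {
        // One loop iteration per GPU thread; the loop bound plays the role of
        // the `if (idx >= d1 * d2) return;` guard in the kernels.
        for idx in 0..d1 * d2 {
            let idx1 = idx / d2;
            let idx2 = idx - d2 * idx1;
            dst[(idx1 * dst_s + idx2) as usize] = src[(idx1 * src_s + idx2) as usize];
        }
    }

    fn main() {
        let src: Vec<f32> = (0..6).map(|v| v as f32).collect();
        let mut dst = vec![0f32; 8];
        copy2d_ref(&src, &mut dst, 2, 3, 3, 4);
        assert_eq!(dst, [0., 1., 2., 0., 3., 4., 5., 0.]);
    }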
diff --git a/candle-nn/examples/cpu_benchmarks.rs b/candle-nn/examples/cpu_benchmarks.rs
index 001be116..430316b8 100644
--- a/candle-nn/examples/cpu_benchmarks.rs
+++ b/candle-nn/examples/cpu_benchmarks.rs
@@ -238,6 +238,23 @@ impl Benchmark for QMatMul {
const ITERS: usize = 100;
}
+struct Cat;
+impl Benchmark for Cat {
+ type PreProcessData = (Tensor, Tensor);
+ type RunResult = Tensor;
+ fn preprocess() -> Result<Self::PreProcessData> {
+ let lhs = Tensor::randn(0f32, 1., (1, 32, 2000, 128), &Device::Cpu)?;
+ let rhs = Tensor::randn(0f32, 1., (1, 32, 1, 128), &Device::Cpu)?;
+ Ok((lhs, rhs))
+ }
+
+ fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
+ Tensor::cat(&[&d.0, &d.1], 2)
+ }
+
+ const ITERS: usize = 1000;
+}
+
struct Softmax;
impl Benchmark for Softmax {
type PreProcessData = Tensor;
@@ -295,6 +312,7 @@ enum Task {
Qmatmul,
Softmax,
SoftmaxLastDim,
+ Cat,
}
#[derive(Parser, Debug)]
@@ -319,6 +337,7 @@ fn main() -> Result<()> {
Task::Softmax => run::<Softmax>(args.iters)?,
Task::SoftmaxLastDim => run::<SoftmaxLastDim>(args.iters)?,
Task::Qmatmul => run::<QMatMul>(args.iters)?,
+ Task::Cat => run::<Cat>(args.iters)?,
}
Ok(())
}
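The shapes in the new Cat benchmark, (1, 32, 2000, 128) concatenated with (1, 32, 1, 128) along dim 2, look like appending a single position to a 32-head cache with head dimension 128: a concatenation along a middle dimension where both inputs are contiguous, so the new copy2d-based path applies (this reading is inferred from the shapes, not stated in the patch). A quick shape check:

    fn main() {
        let (lhs, rhs) = ([1usize, 32, 2000, 128], [1usize, 32, 1, 128]);
        let out = [lhs[0], lhs[1], lhs[2] + rhs[2], lhs[3]];
        assert_eq!(out, [1, 32, 2001, 128]);
    }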