-rw-r--r--  Cargo.toml                                  1
-rw-r--r--  candle-core/Cargo.toml                      3
-rw-r--r--  candle-core/src/metal_backend.rs          821
-rw-r--r--  candle-metal-kernels/Cargo.toml            19
-rw-r--r--  candle-metal-kernels/README.md              3
-rw-r--r--  candle-metal-kernels/src/affine.metal      46
-rw-r--r--  candle-metal-kernels/src/binary.metal      78
-rw-r--r--  candle-metal-kernels/src/cast.metal        58
-rw-r--r--  candle-metal-kernels/src/indexing.metal    75
-rw-r--r--  candle-metal-kernels/src/lib.rs          1246
-rw-r--r--  candle-metal-kernels/src/reduce.metal     124
-rw-r--r--  candle-metal-kernels/src/ternary.metal     57
-rw-r--r--  candle-metal-kernels/src/unary.metal       82
13 files changed, 2612 insertions(+), 1 deletion(-)
diff --git a/Cargo.toml b/Cargo.toml
index 0fea0423..c37bd75b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -13,6 +13,7 @@ members = [
exclude = [
"candle-flash-attn",
"candle-kernels",
+ "candle-metal-kernels",
"candle-onnx",
]
resolver = "2"
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index 5d5e70a3..592f5bdf 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -13,6 +13,7 @@ readme = "README.md"
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle-kernels = { path = "../candle-kernels", version = "0.3.1", optional = true }
+candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
metal = { workspace = true, optional = true}
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
@@ -40,4 +41,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
-metal = ["dep:metal"]
+metal = ["dep:metal", "dep:candle-metal-kernels"]
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
new file mode 100644
index 00000000..04a2c3dd
--- /dev/null
+++ b/candle-core/src/metal_backend.rs
@@ -0,0 +1,821 @@
+use crate::backend::{BackendDevice, BackendStorage};
+use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose2D};
+use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
+use crate::{CpuStorage, DType, Layout, Result, Shape};
+use candle_metal_kernels::Kernels;
+use core::mem;
+use half::{bf16, f16};
+use metal::mps::matrix::encode_gemm;
+use metal::mps::Float32;
+use metal::{Buffer, CommandQueue, MTLResourceOptions, NSUInteger};
+use std::sync::Arc;
+use tracing::debug;
+
+/// Metal related errors
+#[derive(thiserror::Error, Debug)]
+pub enum MetalError {
+ #[error("{0}")]
+ Message(String),
+ #[error(transparent)]
+ KernelError(#[from] candle_metal_kernels::MetalKernelError),
+}
+
+impl From<String> for MetalError {
+ fn from(e: String) -> Self {
+ MetalError::Message(e)
+ }
+}
+
+#[derive(Clone)]
+pub struct MetalDevice {
+ device: metal::Device,
+ command_queue: metal::CommandQueue,
+ kernels: Arc<candle_metal_kernels::Kernels>,
+}
+
+impl std::fmt::Debug for MetalDevice {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "MetalDevice({:?})", self.device.registry_id())
+ }
+}
+
+impl std::ops::Deref for MetalDevice {
+ type Target = metal::DeviceRef;
+
+ fn deref(&self) -> &Self::Target {
+ &self.device
+ }
+}
+
+impl MetalDevice {
+ // pub fn metal_device(&self) -> &metal::DeviceRef {
+ // self.device.as_ref()
+ // }
+
+ pub fn id(&self) -> u64 {
+ self.registry_id()
+ }
+
+ pub fn command_queue(&self) -> &CommandQueue {
+ &self.command_queue
+ }
+
+ pub fn kernels(&self) -> &Kernels {
+ &self.kernels
+ }
+
+ pub fn device(&self) -> &metal::Device {
+ &self.device
+ }
+
+ pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
+ let size = (element_count * dtype.size_in_bytes()) as u64;
+ // debug!("Allocate 1 - buffer size {size}");
+ self.device
+ .new_buffer(size, MTLResourceOptions::StorageModeManaged)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct MetalStorage {
+ buffer: metal::Buffer,
+ device: MetalDevice,
+ dtype: DType,
+}
+
+impl BackendStorage for MetalStorage {
+ type Device = MetalDevice;
+
+ fn try_clone(&self, _: &Layout) -> Result<Self> {
+ Ok(self.clone())
+ }
+
+ fn dtype(&self) -> DType {
+ self.dtype
+ }
+
+ fn device(&self) -> &Self::Device {
+ &self.device
+ }
+
+ fn to_cpu_storage(&self) -> Result<CpuStorage> {
+ match self.dtype {
+ DType::F32 => Ok(CpuStorage::F32(
+ self.buffer.read_to_vec(self.buffer.length() as usize / mem::size_of::<f32>()),
+ )),
+ dtype => todo!("Unsupported dtype {dtype:?}"),
+ }
+ }
+
+ fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
+ let device = self.device().clone();
+
+ let shape = layout.shape();
+ let dims = shape.dims();
+ let el = shape.elem_count();
+ let dtype = self.dtype;
+
+ assert!(layout.is_contiguous());
+ assert_eq!(dtype, DType::F32);
+
+ let mut buffer = device.new_buffer(el, self.dtype);
+ let command_buffer = self.device.command_queue.new_command_buffer();
+ candle_metal_kernels::call_affine(
+ &device.device,
+ &command_buffer,
+ &device.kernels,
+ el,
+ &self.buffer,
+ &mut buffer,
+ mul as f32,
+ add as f32,
+ )
+ .unwrap();
+ command_buffer.commit();
+ Ok(Self {
+ buffer,
+ device,
+ dtype,
+ })
+ }
+
+ fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
+ todo!()
+ }
+
+ fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
+ todo!()
+ }
+
+ fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
+ // debug!("TODO reduce_op {op:?} {sum_dims:?}");
+ assert!(sum_dims.len() == 1);
+ assert!(sum_dims[0] == layout.shape().rank() - 1);
+ assert!(layout.is_contiguous());
+ let device = self.device.clone();
+ let src_stride = layout.stride();
+ let src_dims = layout.shape().dims();
+ let src_el: usize = src_dims.iter().product();
+ // Source dims and strides with the sum dims at the end.
+ let mut dims = vec![];
+ let mut stride = vec![];
+ let mut dst_el: usize = 1;
+ for (dim_idx, &d) in src_dims.iter().enumerate() {
+ if !sum_dims.contains(&dim_idx) {
+ dst_el *= d;
+ dims.push(d);
+ stride.push(src_stride[dim_idx]);
+ }
+ }
+ for &dim_idx in sum_dims.iter() {
+ dims.push(src_dims[dim_idx]);
+ stride.push(src_stride[dim_idx]);
+ }
+
+ let el_to_sum_per_block = src_el / dst_el;
+ // The reduction loop requires the shared array to be properly initialized and for
+ // this we want the number of threads to be a power of two.
+ let block_dim = usize::min(1024, el_to_sum_per_block).next_power_of_two();
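+ // Example: src_dims = [2, 3, 4] summed over the last dim gives dst_el = 6,
+ // el_to_sum_per_block = 4 and block_dim = min(1024, 4).next_power_of_two() = 4.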
+ let (name, check_empty, return_index) = match (op, self.dtype) {
+ (ReduceOp::Sum, DType::F32) => ("fast_sum_float", false, false),
+ (ReduceOp::Min, DType::F32) => ("fast_min_float", true, false),
+ (ReduceOp::Max, DType::F32) => ("fast_max_float", true, false),
+ (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_float", true, true),
+ (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_float", true, true),
+ _ => todo!("Reduce op for non float"),
+ };
+ if check_empty && layout.shape().elem_count() == 0 {
+ Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
+ }
+ let dtype = if return_index { DType::U32 } else { self.dtype };
+ let mut buffer = device.new_buffer(dst_el, dtype);
+ let command_buffer = self.device.command_queue.new_command_buffer();
+ candle_metal_kernels::call_reduce_contiguous(
+ &device.device,
+ &command_buffer,
+ &device.kernels,
+ name,
+ src_el,
+ dst_el,
+ &self.buffer,
+ &mut buffer,
+ )
+ .map_err(MetalError::from)?;
+ command_buffer.commit();
+
+ Ok(Self {
+ buffer,
+ device,
+ dtype,
+ })
+ }
+
+ fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
+ todo!()
+ }
+
+ fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
+ let device = self.device();
+ let shape = layout.shape();
+ let dims = shape.dims();
+ let el_count = shape.elem_count();
+ let mut buffer = device.new_buffer(el_count, dtype);
+ let command_buffer = device.command_queue.new_command_buffer();
+ if layout.is_contiguous() {
+ use candle_metal_kernels::unary::contiguous;
+
+ let kernel_name = match (self.dtype, dtype) {
+ (DType::U32, DType::F32) => "cast_u32_f32",
+ (left, right) => todo!("to dtype {left:?} - {right:?}"),
+ };
+ candle_metal_kernels::call_cast_contiguous(
+ &device.device,
+ &command_buffer,
+ &device.kernels,
+ kernel_name,
+ el_count,
+ &self.buffer,
+ &mut buffer,
+ )
+ .map_err(MetalError::from)?;
+ } else {
+ todo!(
+ "TODO Implement the kernel calling cast {:?}-{:?}",
+ self.dtype,
+ dtype
+ );
+ }
+
+ command_buffer.commit();
+ // command_buffer.wait_until_scheduled();
+ debug!(
+ "cast {:?} - {:?} - {:?}",
+ dtype,
+ self.buffer.length(),
+ buffer.length()
+ );
+ Ok(Self {
+ buffer,
+ device: device.clone(),
+ dtype,
+ })
+ }
+
+ fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
+ let device = self.device();
+ let dtype = self.dtype;
+ let shape = layout.shape();
+ let dims = shape.dims();
+ let el_count = shape.elem_count();
+ let mut buffer = device.new_buffer(el_count, dtype);
+ // TODO remove
+ // return Ok(Self {
+ // buffer,
+ // device: device.clone(),
+ // dtype,
+ // });
+ let command_buffer = device.command_queue.new_command_buffer();
+ if layout.is_contiguous() {
+ use candle_metal_kernels::unary::contiguous;
+
+ let kernel_name = match (B::KERNEL, dtype) {
+ ("ucos", DType::F32) => contiguous::cos::FLOAT,
+ ("usin", DType::F32) => contiguous::sin::FLOAT,
+ ("usqr", DType::F32) => contiguous::sqr::FLOAT,
+ ("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
+ ("uneg", DType::F32) => contiguous::neg::FLOAT,
+ ("uexp", DType::F32) => contiguous::exp::FLOAT,
+ (name, dtype) => todo!("Match {name} - {dtype:?}"),
+ };
+ candle_metal_kernels::call_unary_contiguous(
+ &device.device,
+ &command_buffer,
+ &device.kernels,
+ kernel_name,
+ el_count,
+ &self.buffer,
+ &mut buffer,
+ )
+ .map_err(MetalError::from)?;
+ } else {
+ todo!("TODO Implement the kernel calling {}", B::KERNEL);
+ }
+
+ let start = std::time::Instant::now();
+ command_buffer.commit();
+ // command_buffer.wait_until_scheduled();
+ debug!(
+ "Unary {:?} - {:?} - {:?} - {:?}",
+ B::KERNEL,
+ start.elapsed(),
+ self.buffer.length(),
+ buffer.length()
+ );
+
+ Ok(Self {
+ buffer,
+ device: device.clone(),
+ dtype,
+ })
+ }
+
+ fn binary_impl<B: BinaryOpT>(
+ &self,
+ rhs: &Self,
+ lhs_l: &Layout,
+ rhs_l: &Layout,
+ ) -> Result<Self> {
+ let device = self.device();
+ let dtype = self.dtype;
+ let shape = lhs_l.shape();
+ let dims = shape.dims();
+ let el_count = shape.elem_count();
+ let mut buffer = device.new_buffer(el_count, dtype);
+ let command_buffer = device.command_queue.new_command_buffer();
+ if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
+ use candle_metal_kernels::binary::contiguous;
+
+ let kernel_name = match (B::KERNEL, dtype) {
+ ("add", DType::F32) => contiguous::add::FLOAT,
+ ("badd", DType::F32) => contiguous::add::FLOAT,
+ ("sub", DType::F32) => contiguous::sub::FLOAT,
+ ("bsub", DType::F32) => contiguous::sub::FLOAT,
+ ("mul", DType::F32) => contiguous::mul::FLOAT,
+ ("bmul", DType::F32) => contiguous::mul::FLOAT,
+ ("div", DType::F32) => contiguous::div::FLOAT,
+ ("bdiv", DType::F32) => contiguous::div::FLOAT,
+ (name, dtype) => todo!("Match {name} - {dtype:?}"),
+ };
+ candle_metal_kernels::call_binary_contiguous(
+ &device.device,
+ &command_buffer,
+ &device.kernels,
+ kernel_name,
+ el_count,
+ &self.buffer,
+ &rhs.buffer,
+ &mut buffer,
+ )
+ .map_err(MetalError::from)?;
+ } else {
+ use candle_metal_kernels::binary::strided;
+
+ let kernel_name = match (B::KERNEL, dtype) {
+ ("badd", DType::F32) => strided::add::FLOAT,
+ ("bsub", DType::F32) => strided::sub::FLOAT,
+ ("bmul", DType::F32) => strided::mul::FLOAT,
+ ("bdiv", DType::F32) => strided::div::FLOAT,
+ (name, dtype) => todo!("Match {name} - {dtype:?}"),
+ };
+ candle_metal_kernels::call_binary_strided(
+ &device.device,
+ &command_buffer,
+ &device.kernels,
+ kernel_name,
+ lhs_l.dims(),
+ &self.buffer,
+ &lhs_l.stride(),
+ lhs_l.start_offset(),
+ &rhs.buffer,
+ &rhs_l.stride(),
+ rhs_l.start_offset(),
+ &mut buffer,
+ )
+ .map_err(MetalError::from)?;
+ }
+
+ let start = std::time::Instant::now();
+ command_buffer.commit();
+ // command_buffer.wait_until_scheduled();
+ debug!(
+ "Binary {:?} - {:?} - {:?} - {:?}",
+ B::KERNEL,
+ start.elapsed(),
+ self.buffer.length(),
+ buffer.length()
+ );
+
+ Ok(Self {
+ buffer,
+ device: device.clone(),
+ dtype,
+ })
+ }
+
+ fn where_cond(
+ &self,
+ layout: &Layout,
+ t: &Self,
+ t_l: &Layout,
+ f: &Self,
+ f_l: &Layout,
+ ) -> Result<Self> {
+ let device = self.device.clone();
+ let shape = t_l.shape();
+ let dims = shape.dims();
+ let el = shape.elem_count();
+ let dtype = t.dtype;
+ let mut buffer = self.device.new_buffer(el, dtype);
+ let command_buffer = self.device.command_queue.new_command_buffer();
+ candle_metal_kernels::call_where_cond_strided(
+ &device.device,
+ &command_buffer,
+ &device.kernels,
+ "where_u8_f32",
+ &dims,
+ &self.buffer,
+ (layout.stride(), layout.start_offset()),
+ &t.buffer,
+ (&t_l.stride(), t_l.start_offset()),
+ &f.buffer,
+ (&f_l.stride(), f_l.start_offset()),
+ &mut buffer,
+ )
+ .map_err(MetalError::from)?;
+ command_buffer.commit();
+ Ok(Self {
+ buffer,
+ device,
+ dtype,
+ })
+ }
+
+ fn conv1d(
+ &self,
+ _l: &Layout,
+ _kernel: &Self,
+ _kernel_l: &Layout,
+ _params: &ParamsConv1D,
+ ) -> Result<Self> {
+ todo!()
+ }
+
+ fn conv2d(
+ &self,
+ _l: &Layout,
+ _kernel: &Self,
+ _kernel_l: &Layout,
+ _params: &ParamsConv2D,
+ ) -> Result<Self> {
+ todo!()
+ }
+
+ fn conv_transpose2d(
+ &self,
+ _l: &Layout,
+ _kernel: &Self,
+ _kernel_l: &Layout,
+ _params: &ParamsConvTranspose2D,
+ ) -> Result<Self> {
+ todo!()
+ }
+
+ fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
+ todo!()
+ }
+
+ fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
+ todo!()
+ }
+
+ fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
+ todo!()
+ }
+
+ fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
+ todo!()
+ }
+
+ fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
+ todo!()
+ }
+
+ fn scatter_add(
+ &self,
+ _: &Layout,
+ _: &Self,
+ _: &Layout,
+ _: &Self,
+ _: &Layout,
+ _: usize,
+ ) -> Result<Self> {
+ todo!()
+ }
+
+ fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
+ debug!(
+ "TODO Index select {:?} {:?} {src_l:?} {ids_l:?} {dim:?}",
+ self.buffer.length(),
+ ids.buffer.length(),
+ );
+ let src = self;
+ let ids_shape = ids_l.shape();
+ let ids_dims = ids_shape.dims();
+ // let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?;
+ // let src = match src_l.contiguous_offsets() {
+ // Some((o1, o2)) => src.slice(o1..o2),
+ // None => Err(crate::Error::RequiresContiguous { op: "index-select" }.bt())?,
+ // };
+ let left_size: usize = src_l.dims()[..dim].iter().product();
+ let right_size: usize = src_l.dims()[dim + 1..].iter().product();
+ let src_dim_size = src_l.dims()[dim];
+ let ids_dim_size = ids_shape.elem_count();
+ let dst_el = ids_shape.elem_count() * left_size * right_size;
+ let dtype = self.dtype;
+ let device = self.device();
+ let buffer = device.new_buffer(dst_el, dtype);
+ Ok(Self {
+ buffer,
+ device: device.clone(),
+ dtype,
+ })
+ // todo!()
+ }
+
+ fn index_add(
+ &self,
+ _: &Layout,
+ _: &Self,
+ _: &Layout,
+ _: &Self,
+ _: &Layout,
+ _: usize,
+ ) -> Result<Self> {
+ todo!()
+ }
+
+ fn matmul(
+ &self,
+ rhs: &Self,
+ (b, m, n, k): (usize, usize, usize, usize),
+ lhs_l: &Layout,
+ rhs_l: &Layout,
+ ) -> Result<Self> {
+ let transpose_left = false;
+ let transpose_right = !rhs_l.is_contiguous();
+ let alpha = 1.0;
+ let beta = 0.0;
+ self.matmul_generic(
+ rhs,
+ (b, m, n, k),
+ lhs_l,
+ rhs_l,
+ transpose_left,
+ transpose_right,
+ alpha,
+ beta,
+ )
+ }
+
+ fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
+ let src_shape = src_l.shape();
+ let dims = src_shape.dims();
+ let el_count = src_shape.elem_count();
+ if el_count == 0 {
+ return Ok(());
+ }
+ if src_l.is_contiguous() {
+ let command_buffer = self.device.command_queue.new_command_buffer();
+ let blit = command_buffer.new_blit_command_encoder();
+ // Offsets and length are element counts; the blit encoder expects bytes.
+ let item_size = self.dtype.size_in_bytes();
+ blit.copy_from_buffer(
+ &self.buffer,
+ (src_l.start_offset() * item_size) as u64,
+ &dst.buffer,
+ (dst_offset * item_size) as u64,
+ (el_count * item_size) as u64,
+ );
+ blit.end_encoding();
+ command_buffer.commit();
+ } else {
+ let command_buffer = self.device.command_queue.new_command_buffer();
+ let kernel_name = match self.dtype {
+ DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
+ DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
+ DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
+ dtype => todo!("copy_strided not implemented for {dtype:?}"),
+ };
+ candle_metal_kernels::call_unary_strided(
+ &self.device.device,
+ &command_buffer,
+ &self.device.kernels,
+ kernel_name,
+ src_l.dims(),
+ &self.buffer,
+ &src_l.stride(),
+ src_l.start_offset(),
+ &mut dst.buffer,
+ dst_offset,
+ )
+ .map_err(MetalError::from)?;
+ command_buffer.commit();
+ }
+ Ok(())
+ }
+}
+
+impl MetalStorage {
+ pub fn new(buffer: Buffer, device: MetalDevice, dtype: DType) -> Self {
+ Self {
+ buffer,
+ device,
+ dtype,
+ }
+ }
+ pub(crate) fn matmul_generic(
+ &self,
+ rhs: &Self,
+ (b, m, n, k): (usize, usize, usize, usize),
+ lhs_l: &Layout,
+ rhs_l: &Layout,
+ transpose_left: bool,
+ transpose_right: bool,
+ alpha: f64,
+ beta: f64,
+ ) -> Result<Self> {
+ let elem_count = b * m * n;
+ match (self.dtype, rhs.dtype) {
+ (DType::F32, DType::F32) => {
+ let mut out_buffer = self.device.new_buffer(elem_count, self.dtype);
+ if b != 1 {
+ debug!("TODO implement batched matmul for B={b}");
+ // bail!("Didn't implemented strided matmul yet");
+ return Ok(Self {
+ buffer: out_buffer,
+ device: self.device.clone(),
+ dtype: self.dtype(),
+ });
+ }
+ if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
+ debug!(
+ "TODO non contiguous matmul yet {:?} {:?} - {:?} - {transpose_right}",
+ lhs_l.is_contiguous(),
+ rhs_l.is_contiguous(),
+ rhs_l
+ );
+ return Ok(Self {
+ buffer: out_buffer,
+ device: self.device.clone(),
+ dtype: self.dtype(),
+ });
+ }
+
+ debug!("TODO GEMM");
+ let command_buffer = self.device.command_queue.new_command_buffer();
+ encode_gemm::<Float32, Float32, Float32>(
+ &self.device,
+ &command_buffer,
+ transpose_left,
+ transpose_right,
+ &self.buffer,
+ &rhs.buffer,
+ &mut out_buffer,
+ m as NSUInteger,
+ n as NSUInteger,
+ k as NSUInteger,
+ alpha as f32,
+ beta as f32,
+ Some(b as NSUInteger),
+ )
+ .map_err(MetalError::from)?;
+
+ command_buffer.commit();
+ // command_buffer.wait_until_scheduled();
+
+ Ok(Self {
+ buffer: out_buffer,
+ device: self.device.clone(),
+ dtype: self.dtype(),
+ })
+ }
+ _ => todo!("Unimplemented matmul for this pair"),
+ }
+ }
+
+ pub fn buffer(&self) -> &Buffer {
+ &self.buffer
+ }
+}
+
+impl BackendDevice for MetalDevice {
+ type Storage = MetalStorage;
+
+ fn new(ordinal: usize) -> Result<Self> {
+ let device = metal::Device::all().swap_remove(ordinal);
+
+ // let capture = metal::CaptureManager::shared();
+ // let descriptor = metal::CaptureDescriptor::new();
+ // descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
+ // descriptor.set_capture_device(&device);
+ // let mut dir = std::env::current_dir()?;
+ // dir.push("out.gputrace");
+ // descriptor.set_output_url(dir);
+
+ // capture
+ // .start_capture(&descriptor)
+ // .map_err(MetalError::from)?;
+ let command_queue = device.new_command_queue();
+ // let command_buffer = _command_queue.new_owned_command_buffer();
+ let kernels = Arc::new(Kernels::new());
+ Ok(Self {
+ device,
+ command_queue,
+ // command_buffer,
+ kernels,
+ })
+ }
+
+ fn set_seed(&self, _seed: u64) -> Result<()> {
+ todo!("set_seed")
+ }
+
+ fn location(&self) -> crate::DeviceLocation {
+ crate::DeviceLocation::Metal
+ }
+
+ fn same_device(&self, rhs: &Self) -> bool {
+ self.device.registry_id() == rhs.device.registry_id()
+ }
+
+ fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
+ // TODO Is there a faster way ?
+ let cpu_storage = crate::cpu_backend::CpuDevice.zeros_impl(shape, dtype)?;
+ self.storage_from_cpu_storage(&cpu_storage)
+ }
+
+ fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
+ // TODO Is there a faster way ?
+ let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
+ self.storage_from_cpu_storage(&cpu_storage)
+ }
+
+ fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
+ let option = metal::MTLResourceOptions::StorageModeManaged;
+ let buffer = match storage {
+ CpuStorage::U8(storage) => self.device.new_buffer_with_data(
+ storage.as_ptr() as *const core::ffi::c_void,
+ (storage.len() * mem::size_of::<u8>()) as u64,
+ option,
+ ),
+ CpuStorage::U32(storage) => self.device.new_buffer_with_data(
+ storage.as_ptr() as *const core::ffi::c_void,
+ (storage.len() * mem::size_of::<u32>()) as u64,
+ option,
+ ),
+ CpuStorage::I64(storage) => self.device.new_buffer_with_data(
+ storage.as_ptr() as *const core::ffi::c_void,
+ (storage.len() * mem::size_of::<i64>()) as u64,
+ option,
+ ),
+ CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
+ storage.as_ptr() as *const core::ffi::c_void,
+ (storage.len() * mem::size_of::<bf16>()) as u64,
+ option,
+ ),
+ CpuStorage::F16(storage) => self.device.new_buffer_with_data(
+ storage.as_ptr() as *const core::ffi::c_void,
+ (storage.len() * mem::size_of::<f16>()) as u64,
+ option,
+ ),
+ CpuStorage::F32(storage) => self.device.new_buffer_with_data(
+ storage.as_ptr() as *const core::ffi::c_void,
+ (storage.len() * mem::size_of::<f32>()) as u64,
+ option,
+ ),
+ CpuStorage::F64(storage) => self.device.new_buffer_with_data(
+ storage.as_ptr() as *const core::ffi::c_void,
+ (storage.len() * mem::size_of::<f64>()) as u64,
+ option,
+ ),
+ };
+ // debug!("Allocate 2 - buffer size {}", buffer.length());
+ Ok(Self::Storage {
+ buffer,
+ device: self.clone(),
+ dtype: storage.dtype(),
+ })
+ }
+
+ fn rand_uniform(
+ &self,
+ shape: &Shape,
+ dtype: DType,
+ mean: f64,
+ stddev: f64,
+ ) -> Result<Self::Storage> {
+ // TODO is there a better way ?
+ let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?;
+ self.storage_from_cpu_storage(&cpu_storage)
+ }
+
+ fn rand_normal(
+ &self,
+ shape: &Shape,
+ dtype: DType,
+ mean: f64,
+ stddev: f64,
+ ) -> Result<Self::Storage> {
+ // TODO is there a better way ?
+ let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
+ self.storage_from_cpu_storage(&cpu_storage)
+ }
+}
diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml
new file mode 100644
index 00000000..ff5ede1a
--- /dev/null
+++ b/candle-metal-kernels/Cargo.toml
@@ -0,0 +1,19 @@
+[package]
+name = "candle-metal-kernels"
+version = "0.3.0"
+edition = "2021"
+
+description = "CUDA kernels for Candle"
+repository = "https://github.com/huggingface/candle"
+keywords = ["blas", "tensor", "machine-learning"]
+categories = ["science"]
+license = "MIT OR Apache-2.0"
+
+[dependencies]
+metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
+once_cell = "1.18.0"
+thiserror = "1"
+tracing = "0.1.37"
+
+[dev-dependencies]
+half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
diff --git a/candle-metal-kernels/README.md b/candle-metal-kernels/README.md
new file mode 100644
index 00000000..ec923e9a
--- /dev/null
+++ b/candle-metal-kernels/README.md
@@ -0,0 +1,3 @@
+# candle-metal-kernels
+
+This crate contains Metal kernels used by candle.
\ No newline at end of file
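
(Note: a minimal usage sketch of this crate's API, mirroring the test helpers
further down in lib.rs; it assumes a machine with a default Metal device.)

    use candle_metal_kernels::{call_affine, Kernels};
    use metal::{Device, MTLResourceOptions};

    fn main() {
        let device = Device::system_default().expect("no Metal device");
        let kernels = Kernels::new();
        let command_queue = device.new_command_queue();
        let command_buffer = command_queue.new_command_buffer();
        let options = MTLResourceOptions::StorageModeManaged;
        let v = [1.0f32, 2.0, 3.0];
        let input = device.new_buffer_with_data(
            v.as_ptr() as *const core::ffi::c_void,
            std::mem::size_of_val(&v) as u64,
            options,
        );
        let mut output = device.new_buffer(std::mem::size_of_val(&v) as u64, options);
        // Computes output[i] = input[i] * 1.5 + 1.1 on the GPU.
        call_affine(&device, command_buffer, &kernels, v.len(), &input, &mut output, 1.5, 1.1)
            .unwrap();
        command_buffer.commit();
        command_buffer.wait_until_completed();
        let result = output.read_to_vec::<f32>(v.len());
        assert_eq!(result, vec![2.6f32, 4.1, 5.6]);
    }
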
diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal
new file mode 100644
index 00000000..c388c04e
--- /dev/null
+++ b/candle-metal-kernels/src/affine.metal
@@ -0,0 +1,46 @@
+#include <metal_stdlib>
+
+METAL_FUNC uint get_strided_index(
+ uint idx,
+ constant size_t &num_dims,
+ constant size_t *dims,
+ constant size_t *strides
+) {
+ uint strided_i = 0;
+ for (uint d = 0; d < num_dims; d++) {
+ uint dim_idx = num_dims - 1 - d;
+ strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
+ idx /= dims[dim_idx];
+ }
+ return strided_i;
+}
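+// Example: dims = {2, 3} with contiguous strides {3, 1} maps idx 4 to
+// (4 % 3) * 1 + (1 % 2) * 3 = 4, while transposed strides {1, 2} map it to
+// (4 % 3) * 2 + (1 % 2) * 1 = 3.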
+
+using namespace metal;
+
+#define AFFINE(FN_NAME, TYPENAME) \
+kernel void FN_NAME( \
+ constant size_t &dim, \
+ constant float &mul, \
+ constant float &add, \
+ device const TYPENAME *input, \
+ device TYPENAME *output, \
+ uint threadgroup_size [[threads_per_threadgroup]], \
+ uint thread_index [[thread_index_in_threadgroup]] \
+) { \
+ const TYPENAME m = TYPENAME(mul); \
+ const TYPENAME a = TYPENAME(add); \
+ const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
+ const size_t start = thread_index * length; \
+ const size_t stop = min(start + length, dim); \
+ for (size_t i = start; i < stop; i++){ \
+ output[i] = input[i] * m + a; \
+ } \
+} \
+
+AFFINE(affine_float, float)
+AFFINE(affine_half, half)
+
+
+#if __METAL_VERSION__ >= 310
+AFFINE(affine_bfloat, bfloat);
+#endif
diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal
new file mode 100644
index 00000000..cfd34416
--- /dev/null
+++ b/candle-metal-kernels/src/binary.metal
@@ -0,0 +1,78 @@
+#include <metal_stdlib>
+
+METAL_FUNC uint get_strided_index(
+ uint idx,
+ constant size_t &num_dims,
+ constant size_t *dims,
+ constant size_t *strides
+) {
+ uint strided_i = 0;
+ for (uint d = 0; d < num_dims; d++) {
+ uint dim_idx = num_dims - 1 - d;
+ strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
+ idx /= dims[dim_idx];
+ }
+ return strided_i;
+}
+
+using namespace metal;
+
+#define BINARY(FN, TYPENAME, OUT_TYPENAME, FN_NAME, FN_NAME_STRIDED) \
+kernel void FN_NAME( \
+ constant size_t &dim, \
+ device const TYPENAME *left, \
+ device const TYPENAME *right, \
+ device TYPENAME *output, \
+ uint threadgroup_size [[threads_per_threadgroup]], \
+ uint thread_index [[thread_index_in_threadgroup]] \
+) { \
+ const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
+ const size_t start = thread_index * length; \
+ const size_t stop = min(start + length, dim); \
+ for (size_t i = start; i < stop; i++){ \
+ TYPENAME x = left[i]; \
+ TYPENAME y = right[i]; \
+ output[i] = OUT_TYPENAME(FN); \
+ } \
+}\
+kernel void FN_NAME_STRIDED( \
+ constant size_t &dim, \
+ constant size_t &num_dims, \
+ constant size_t *dims, \
+ constant size_t *left_strides, \
+ constant size_t *right_strides, \
+ device const TYPENAME *left, \
+ device const TYPENAME *right, \
+ device TYPENAME *output, \
+ uint threadgroup_size [[threads_per_threadgroup]], \
+ uint thread_index [[thread_index_in_threadgroup]] \
+) { \
+ const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
+ const size_t start = thread_index * length; \
+ const size_t stop = min(start + length, dim); \
+ for (size_t i = start; i < stop; i++){ \
+ TYPENAME x = left[get_strided_index(i, num_dims, dims, left_strides)]; \
+ TYPENAME y = right[get_strided_index(i, num_dims, dims, right_strides)]; \
+ output[i] = OUT_TYPENAME(FN); \
+ } \
+}
+
+#define BINARY_OP(FN, NAME) \
+BINARY(FN, float, float, NAME##_float, NAME##_float_strided); \
+BINARY(FN, half, half, NAME##_half, NAME##_half_strided);
+
+#define BFLOAT_BINARY_OP(FN, NAME) \
+BINARY(FN, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
+
+
+BINARY_OP(x + y, add)
+BINARY_OP(x - y, sub)
+BINARY_OP(x * y, mul)
+BINARY_OP(x / y, div)
+
+#if __METAL_VERSION__ >= 310
+BFLOAT_BINARY_OP(x + y, add)
+BFLOAT_BINARY_OP(x - y, sub)
+BFLOAT_BINARY_OP(x * y, mul)
+BFLOAT_BINARY_OP(x / y, div)
+#endif
diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal
new file mode 100644
index 00000000..52e63662
--- /dev/null
+++ b/candle-metal-kernels/src/cast.metal
@@ -0,0 +1,58 @@
+#include <metal_stdlib>
+
+METAL_FUNC uint get_strided_index(
+ uint idx,
+ constant size_t &num_dims,
+ constant size_t *dims,
+ constant size_t *strides
+) {
+ uint strided_i = 0;
+ for (uint d = 0; d < num_dims; d++) {
+ uint dim_idx = num_dims - 1 - d;
+ strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
+ idx /= dims[dim_idx];
+ }
+ return strided_i;
+}
+
+
+using namespace metal;
+
+#define CAST(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME) \
+kernel void FN_NAME( \
+ constant size_t &dim, \
+ device const LEFT_TYPENAME *input, \
+ device RIGHT_TYPENAME *output, \
+ uint threadgroup_size [[threads_per_threadgroup]], \
+ uint thread_index [[thread_index_in_threadgroup]] \
+) { \
+ const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
+ const size_t start = thread_index * length; \
+ const size_t stop = min(start + length, dim); \
+ for (size_t i = start; i < stop; i++){ \
+ output[i] = RIGHT_TYPENAME(input[i]); \
+ } \
+} \
+kernel void FN_NAME_STRIDED( \
+ constant size_t &dim, \
+ constant size_t &num_dims, \
+ constant size_t *dims, \
+ constant size_t *strides, \
+ device const LEFT_TYPENAME *input, \
+ device RIGHT_TYPENAME *output, \
+ uint threadgroup_size [[threads_per_threadgroup]], \
+ uint thread_index [[thread_index_in_threadgroup]] \
+) { \
+ const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
+ const size_t start = thread_index * length; \
+ const size_t stop = min(start + length, dim); \
+ for (size_t i = start; i < stop; i++){ \
+ output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
+ } \
+}
+
+
+CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
+
+#if __METAL_VERSION__ >= 310
+#endif
diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal
new file mode 100644
index 00000000..528c109d
--- /dev/null
+++ b/candle-metal-kernels/src/indexing.metal
@@ -0,0 +1,75 @@
+#include <metal_stdlib>
+using namespace metal;
+
+template <typename T, typename I>
+void index_add(
+ device I *ids [[buffer(0)]],
+ device T *inp [[buffer(1)]],
+ device T *out [[buffer(2)]],
+
+ constant uint &ids_dim_size,
+ constant uint &left_size,
+ constant uint &dst_dim_size,
+ constant uint &right_size,
+
+ uint threadgroup_size [[threads_per_threadgroup]],
+ uint threadgroup_position_in_grid [[threadgroup_position_in_grid]],
+ uint thread_index [[thread_index_in_threadgroup]]
+) {
+
+ const uint gid = thread_index + (threadgroup_position_in_grid * threadgroup_size);
+ if (gid >= left_size * right_size) {
+ return;
+ }
+
+ const uint i = gid;
+ const uint pre = i / right_size;
+ const uint post = i % right_size;
+
+ for (uint j = 0; j < ids_dim_size; j++) {
+ const uint idx = ids[j];
+ const uint src_i = (pre * ids_dim_size + j) * right_size + post;
+ const uint dst_i = (pre * dst_dim_size + idx) * right_size + post;
+ out[dst_i] += inp[src_i];
+ }
+}
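+// Example: with left_size = 1, right_size = 1, ids_dim_size = 3,
+// dst_dim_size = 15 and ids = [0, 4, 2], the single thread performs
+// out[0] += inp[0], out[4] += inp[1] and out[2] += inp[2].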
+
+#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
+kernel void FN_NAME( \
+ device INDEX_TYPENAME *ids [[buffer(0)]], \
+ device TYPENAME *inp [[buffer(1)]], \
+ device TYPENAME *out [[buffer(2)]], \
+ constant uint &ids_dim_size, \
+ constant uint &left_size, \
+ constant uint &dst_dim_size, \
+ constant uint &right_size, \
+ uint threadgroup_size [[threads_per_threadgroup]], \
+ uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
+ uint thread_index [[thread_index_in_threadgroup]] \
+) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, threadgroup_size, threadgroup_position_in_grid, thread_index); } \
+
+
+
+#if __METAL_VERSION__ >= 310
+IA_OP(bfloat, int64_t, ia_i64_bf16)
+IA_OP(bfloat, uint32_t, ia_u32_bf16)
+IA_OP(bfloat, uint8_t, ia_u8_bf16)
+#endif
+
+IA_OP(half, uint32_t, ia_u32_f16)
+IA_OP(half, uint8_t, ia_u8_f16)
+
+IA_OP(float, int64_t, ia_i64_f32)
+IA_OP(uint8_t, int64_t, ia_i64_u8)
+IA_OP(int64_t, int64_t, ia_i64_i64)
+IA_OP(uint32_t, int64_t, ia_i64_u32)
+
+IA_OP(float, uint32_t, ia_u32_f32)
+IA_OP(uint8_t, uint32_t, ia_u32_u8)
+IA_OP(int64_t, uint32_t, ia_u32_i64)
+IA_OP(uint32_t, uint32_t, ia_u32_u32)
+
+IA_OP(float, uint8_t, ia_u8_f32)
+IA_OP(uint8_t, uint8_t, ia_u8_u8)
+IA_OP(uint32_t, uint8_t, ia_u8_u32)
+IA_OP(int64_t, uint8_t, ia_u8_i64)
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
new file mode 100644
index 00000000..d2c63115
--- /dev/null
+++ b/candle-metal-kernels/src/lib.rs
@@ -0,0 +1,1246 @@
+#![allow(clippy::too_many_arguments)]
+use metal::{
+ Buffer, CommandBufferRef, CompileOptions, ComputePipelineDescriptor, Device, Function, Library,
+ MTLSize,
+};
+use std::collections::HashMap;
+use std::ffi::c_void;
+use std::sync::RwLock;
+
+const AFFINE: &str = include_str!("affine.metal");
+const INDEXING: &str = include_str!("indexing.metal");
+const UNARY: &str = include_str!("unary.metal");
+const BINARY: &str = include_str!("binary.metal");
+const TERNARY: &str = include_str!("ternary.metal");
+const CAST: &str = include_str!("cast.metal");
+const REDUCE: &str = include_str!("reduce.metal");
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum Source {
+ Affine,
+ Indexing,
+ Unary,
+ Binary,
+ Ternary,
+ Cast,
+ Reduce,
+}
+
+macro_rules! ops{
+ ($($name:ident),+) => {
+
+ pub mod contiguous {
+ pub struct Kernel(pub(crate) &'static str);
+ $(
+ pub mod $name {
+ use super::Kernel;
+ pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float"));
+ pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half"));
+ pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat"));
+ }
+ )+
+ }
+
+ pub mod strided {
+ pub struct Kernel(pub(crate) &'static str);
+ $(
+ pub mod $name {
+ use super::Kernel;
+ pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float_strided"));
+ pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half_strided"));
+ pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided"));
+ }
+ )+
+ }
+ };
+}
+
+pub mod unary {
+ ops!(cos, sin, exp, sqr, sqrt, neg, copy);
+}
+pub mod binary {
+ ops!(add, sub, mul, div);
+}
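+
+// For example, `unary::contiguous::cos::FLOAT` expands to `Kernel("cos_float")`
+// and `binary::strided::add::BFLOAT` to `Kernel("add_bfloat_strided")`, matching
+// the function names generated by the #define blocks in the .metal sources.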
+
+// static LIBRARY_SOURCES: Lazy<HashMap<&'static str, &'static str>> = Lazy::new(|| {
+// let mut l = HashMap::new();
+// l.insert("affine", AFFINE);
+// l.insert("indexing", INDEXING);
+// l.insert("unary", UNARY);
+// l
+// });
+//
+#[derive(thiserror::Error, Debug)]
+pub enum MetalKernelError {
+ #[error("Could not lock kernel map: {0}")]
+ LockError(String),
+ #[error("Error while loading library: {0}")]
+ LoadLibraryError(String),
+ #[error("Error while loading function: {0}")]
+ LoadFunctionError(String),
+}
+
+impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
+ fn from(e: std::sync::PoisonError<T>) -> Self {
+ Self::LockError(e.to_string())
+ }
+}
+
+type KernelMap<T> = HashMap<&'static str, T>;
+type Libraries = HashMap<Source, Library>;
+type Functions = KernelMap<Function>;
+
+#[derive(Debug, Default)]
+pub struct Kernels {
+ libraries: RwLock<Libraries>,
+ funcs: RwLock<Functions>,
+}
+
+impl Kernels {
+ pub fn new() -> Self {
+ let libraries = RwLock::new(Libraries::new());
+ let funcs = RwLock::new(Functions::new());
+ Self { libraries, funcs }
+ }
+
+ // pub fn init(device: &Device) -> Result<Self, MetalKernelError> {
+ // let kernels = Self::new();
+ // kernels.load_libraries(device)?;
+ // Ok(kernels)
+ // }
+
+ // fn load_libraries(&self, device: &Device) -> Result<(), MetalKernelError> {
+ // for name in LIBRARY_SOURCES.keys() {
+ // self.load_library(device, name)?;
+ // }
+ // Ok(())
+ // }
+
+ fn get_library_source(&self, source: Source) -> &'static str {
+ // LIBRARY_SOURCES.get(name).cloned()
+ match source {
+ Source::Affine => AFFINE,
+ Source::Unary => UNARY,
+ Source::Binary => BINARY,
+ Source::Ternary => TERNARY,
+ Source::Indexing => INDEXING,
+ Source::Cast => CAST,
+ Source::Reduce => REDUCE,
+ }
+ }
+
+ pub fn load_library(
+ &self,
+ device: &Device,
+ source: Source,
+ ) -> Result<Library, MetalKernelError> {
+ let mut libraries = self.libraries.write()?;
+ if let Some(lib) = libraries.get(&source) {
+ Ok(lib.clone())
+ } else {
+ let source_content = self.get_library_source(source);
+ let lib = device
+ .new_library_with_source(source_content, &CompileOptions::new())
+ .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?;
+ libraries.insert(source, lib.clone());
+ Ok(lib)
+ }
+ }
+
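+ /// Loads (and caches) the function `name` from the library for `source`.
+ /// Note that the function cache is keyed by name alone, so kernel names are
+ /// assumed to be unique across all sources.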
+ pub fn load_function(
+ &self,
+ device: &Device,
+ source: Source,
+ name: &'static str,
+ ) -> Result<Function, MetalKernelError> {
+ let mut funcs = self.funcs.write()?;
+ if let Some(func) = funcs.get(name) {
+ Ok(func.clone())
+ } else {
+ let func = self
+ .load_library(device, source)?
+ .get_function(name, None)
+ .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
+ funcs.insert(name, func.clone());
+ Ok(func)
+ }
+ }
+}
+
+pub fn call_unary_contiguous(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ kernel_name: unary::contiguous::Kernel,
+ length: usize,
+ input: &Buffer,
+ output: &mut Buffer,
+) -> Result<(), MetalKernelError> {
+ // println!("Kernel {:?}", kernel_name.0);
+ // assert_eq!(input.length(), output.length());
+ let func = kernels.load_function(device, Source::Unary, kernel_name.0)?;
+ let pipeline_state_descriptor = ComputePipelineDescriptor::new();
+ pipeline_state_descriptor.set_compute_function(Some(&func));
+
+ let pipeline = device
+ .new_compute_pipeline_state_with_function(
+ pipeline_state_descriptor.compute_function().unwrap(),
+ )
+ .unwrap();
+
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ encoder.set_bytes(0, std::mem::size_of::<usize>() as u64, void_ptr(&length));
+ encoder.set_buffer(1, Some(input), 0);
+ encoder.set_buffer(2, Some(output), 0);
+
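+ // A single threadgroup is dispatched; the kernel itself splits the work into
+ // per-thread chunks of ceil(dim / threadgroup_size), so one group of up to
+ // max_total_threads_per_threadgroup threads covers the whole buffer.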
+ let thread_group_count = MTLSize {
+ width: 1,
+ height: 1,
+ depth: 1,
+ };
+
+ let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64);
+ let thread_group_size = MTLSize {
+ width,
+ height: 1,
+ depth: 1,
+ };
+
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.end_encoding();
+ Ok(())
+}
+pub fn call_unary_strided(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: unary::strided::Kernel,
+ shape: &[usize],
+ input: &Buffer,
+ strides: &[usize],
+ offset: usize,
+ output: &mut Buffer,
+ output_offset: usize,
+) -> Result<(), MetalKernelError> {
+ let func = kernels.load_function(device, Source::Unary, name.0)?;
+ let pipeline_state_descriptor = ComputePipelineDescriptor::new();
+ pipeline_state_descriptor.set_compute_function(Some(&func));
+
+ let pipeline = device
+ .new_compute_pipeline_state_with_function(
+ pipeline_state_descriptor.compute_function().unwrap(),
+ )
+ .unwrap();
+
+ let num_dims: usize = shape.len();
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ let length: usize = shape.iter().product();
+ encoder.set_bytes(0, std::mem::size_of::<usize>() as u64, void_ptr(&length));
+ encoder.set_bytes(1, std::mem::size_of::<usize>() as u64, void_ptr(&num_dims));
+ encoder.set_bytes(
+ 2,
+ std::mem::size_of_val(shape) as u64,
+ shape.as_ptr() as *const c_void,
+ );
+ encoder.set_bytes(
+ 3,
+ std::mem::size_of_val(strides) as u64,
+ strides.as_ptr() as *const c_void,
+ );
+
+ encoder.set_buffer(4, Some(input), offset as u64);
+ encoder.set_buffer(5, Some(output), output_offset as u64);
+
+ let width = length as u64;
+
+ let thread_group_count = MTLSize {
+ width: 1,
+ height: 1,
+ depth: 1,
+ };
+
+ let thread_group_size = MTLSize {
+ width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width),
+ height: 1,
+ depth: 1,
+ };
+
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.end_encoding();
+ Ok(())
+}
+
+pub fn call_binary_contiguous(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ kernel_name: binary::contiguous::Kernel,
+ length: usize,
+ left: &Buffer,
+ right: &Buffer,
+ output: &mut Buffer,
+) -> Result<(), MetalKernelError> {
+ // println!("Kernel {:?}", kernel_name.0);
+ // assert_eq!(input.length(), output.length());
+ let func = kernels.load_function(device, Source::Binary, kernel_name.0)?;
+ let pipeline_state_descriptor = ComputePipelineDescriptor::new();
+ pipeline_state_descriptor.set_compute_function(Some(&func));
+
+ let pipeline = device
+ .new_compute_pipeline_state_with_function(
+ pipeline_state_descriptor.compute_function().unwrap(),
+ )
+ .unwrap();
+
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ encoder.set_bytes(0, std::mem::size_of::<usize>() as u64, void_ptr(&length));
+ encoder.set_buffer(1, Some(left), 0);
+ encoder.set_buffer(2, Some(right), 0);
+ encoder.set_buffer(3, Some(output), 0);
+
+ let thread_group_count = MTLSize {
+ width: 1,
+ height: 1,
+ depth: 1,
+ };
+
+ let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64);
+ let thread_group_size = MTLSize {
+ width,
+ height: 1,
+ depth: 1,
+ };
+
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.end_encoding();
+ Ok(())
+}
+
+pub fn call_binary_strided(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: binary::strided::Kernel,
+ shape: &[usize],
+ left_input: &Buffer,
+ left_strides: &[usize],
+ left_offset: usize,
+ right_input: &Buffer,
+ right_strides: &[usize],
+ right_offset: usize,
+ output: &mut Buffer,
+) -> Result<(), MetalKernelError> {
+ let func = kernels.load_function(device, Source::Binary, name.0)?;
+ let pipeline_state_descriptor = ComputePipelineDescriptor::new();
+ pipeline_state_descriptor.set_compute_function(Some(&func));
+
+ let pipeline = device
+ .new_compute_pipeline_state_with_function(
+ pipeline_state_descriptor.compute_function().unwrap(),
+ )
+ .unwrap();
+
+ let num_dims: usize = shape.len();
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ let length: usize = shape.iter().product();
+ encoder.set_bytes(0, std::mem::size_of::<usize>() as u64, void_ptr(&length));
+ encoder.set_bytes(1, std::mem::size_of::<usize>() as u64, void_ptr(&num_dims));
+ encoder.set_bytes(
+ 2,
+ std::mem::size_of_val(shape) as u64,
+ shape.as_ptr() as *const c_void,
+ );
+ encoder.set_bytes(
+ 3,
+ std::mem::size_of_val(left_strides) as u64,
+ left_strides.as_ptr() as *const c_void,
+ );
+ encoder.set_bytes(
+ 4,
+ std::mem::size_of_val(right_strides) as u64,
+ right_strides.as_ptr() as *const c_void,
+ );
+
+ encoder.set_buffer(5, Some(left_input), left_offset as u64);
+ encoder.set_buffer(6, Some(right_input), right_offset as u64);
+ encoder.set_buffer(7, Some(output), 0);
+
+ let width = length as u64;
+
+ let thread_group_count = MTLSize {
+ width: 1,
+ height: 1,
+ depth: 1,
+ };
+
+ let thread_group_size = MTLSize {
+ width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width),
+ height: 1,
+ depth: 1,
+ };
+
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.end_encoding();
+ Ok(())
+}
+
+pub fn call_cast_contiguous(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ kernel_name: &'static str,
+ length: usize,
+ input: &Buffer,
+ output: &mut Buffer,
+) -> Result<(), MetalKernelError> {
+ // println!("Kernel {:?}", kernel_name.0);
+ // assert_eq!(input.length(), output.length());
+ let func = kernels.load_function(device, Source::Cast, kernel_name)?;
+ let pipeline_state_descriptor = ComputePipelineDescriptor::new();
+ pipeline_state_descriptor.set_compute_function(Some(&func));
+
+ let pipeline = device
+ .new_compute_pipeline_state_with_function(
+ pipeline_state_descriptor.compute_function().unwrap(),
+ )
+ .unwrap();
+
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ encoder.set_bytes(0, std::mem::size_of::<usize>() as u64, void_ptr(&length));
+ encoder.set_buffer(1, Some(input), 0);
+ encoder.set_buffer(2, Some(output), 0);
+
+ let thread_group_count = MTLSize {
+ width: 1,
+ height: 1,
+ depth: 1,
+ };
+
+ let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64);
+ let thread_group_size = MTLSize {
+ width,
+ height: 1,
+ depth: 1,
+ };
+
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.end_encoding();
+ Ok(())
+}
+
+pub fn call_reduce_contiguous(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ kernel_name: &'static str,
+ length: usize,
+ out_length: usize,
+ input: &Buffer,
+ output: &mut Buffer,
+) -> Result<(), MetalKernelError> {
+ let func = kernels.load_function(device, Source::Reduce, kernel_name)?;
+ let pipeline_state_descriptor = ComputePipelineDescriptor::new();
+ pipeline_state_descriptor.set_compute_function(Some(&func));
+
+ let pipeline = device
+ .new_compute_pipeline_state_with_function(
+ pipeline_state_descriptor.compute_function().unwrap(),
+ )
+ .unwrap();
+
+ let elements_to_sum = length / out_length;
+
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ encoder.set_bytes(0, core::mem::size_of::<usize>() as u64, void_ptr(&length));
+ encoder.set_bytes(
+ 1,
+ core::mem::size_of::<usize>() as u64,
+ void_ptr(&elements_to_sum),
+ );
+ encoder.set_buffer(2, Some(input), 0);
+ encoder.set_buffer(3, Some(output), 0);
+
+ let thread_group_count = MTLSize {
+ width: out_length as u64,
+ height: 1,
+ depth: 1,
+ };
+
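+ // Use ceil(elements_to_sum / 2) threads, rounded up to a power of two: the
+ // shared-memory reduction assumes a power-of-two threadgroup size (see the
+ // matching comment in MetalStorage::reduce_op).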
+ let width = std::cmp::min(
+ pipeline.max_total_threads_per_threadgroup(),
+ (elements_to_sum as u64 + 2 - 1) / 2,
+ )
+ .next_power_of_two();
+
+ let thread_group_size = MTLSize {
+ width,
+ height: 1,
+ depth: 1,
+ };
+
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.end_encoding();
+ Ok(())
+}
+
+pub fn call_last_softmax(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ kernel_name: &'static str,
+ length: usize,
+ elements_to_sum: usize,
+ input: &Buffer,
+ output: &mut Buffer,
+) -> Result<(), MetalKernelError> {
+ let func = kernels.load_function(device, Source::Reduce, kernel_name)?;
+ let pipeline_state_descriptor = ComputePipelineDescriptor::new();
+ pipeline_state_descriptor.set_compute_function(Some(&func));
+
+ let pipeline = device
+ .new_compute_pipeline_state_with_function(
+ pipeline_state_descriptor.compute_function().unwrap(),
+ )
+ .unwrap();
+
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ encoder.set_bytes(0, core::mem::size_of::<usize>() as u64, void_ptr(&length));
+ encoder.set_bytes(
+ 1,
+ core::mem::size_of::<usize>() as u64,
+ void_ptr(&elements_to_sum),
+ );
+ encoder.set_buffer(2, Some(input), 0);
+ encoder.set_buffer(3, Some(output), 0);
+
+ let out_length = length / elements_to_sum;
+
+ let thread_group_count = MTLSize {
+ width: out_length as u64,
+ height: 1,
+ depth: 1,
+ };
+
+ let width = std::cmp::min(
+ pipeline.max_total_threads_per_threadgroup(),
+ // (elements_to_sum as u64 + 2 - 1) / 2,
+ elements_to_sum as u64,
+ )
+ .next_power_of_two();
+
+ let thread_group_size = MTLSize {
+ width,
+ height: 1,
+ depth: 1,
+ };
+
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.end_encoding();
+ Ok(())
+}
+
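+/// Returns an untyped `*const c_void` pointer to `v`, for passing host scalars
+/// to the encoder's `set_bytes`.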
+pub fn void_ptr<T>(v: &T) -> *const c_void {
+ (v as *const T).cast()
+}
+
+pub fn call_affine(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ size: usize,
+ input: &Buffer,
+ output: &mut Buffer,
+ mul: f32,
+ add: f32,
+) -> Result<(), MetalKernelError> {
+ let func = kernels.load_function(device, Source::Affine, "affine_float")?;
+ let pipeline_state_descriptor = ComputePipelineDescriptor::new();
+ pipeline_state_descriptor.set_compute_function(Some(&func));
+
+ let pipeline = device
+ .new_compute_pipeline_state_with_function(
+ pipeline_state_descriptor.compute_function().unwrap(),
+ )
+ .unwrap();
+
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ encoder.set_bytes(0, core::mem::size_of::<usize>() as u64, void_ptr(&size));
+ encoder.set_bytes(1, core::mem::size_of::<f32>() as u64, void_ptr(&mul));
+ encoder.set_bytes(2, core::mem::size_of::<f32>() as u64, void_ptr(&add));
+ encoder.set_buffer(3, Some(input), 0);
+ encoder.set_buffer(4, Some(output), 0);
+
+ let thread_group_count = MTLSize {
+ width: 1,
+ height: 1,
+ depth: 1,
+ };
+
+ let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size as u64);
+ let thread_group_size = MTLSize {
+ width,
+ height: 1,
+ depth: 1,
+ };
+
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.end_encoding();
+ Ok(())
+}
+
+pub fn call_where_cond_strided(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: &'static str,
+ shape: &[usize],
+ cond: &Buffer,
+ (cond_stride, cond_offset): (&[usize], usize),
+ left: &Buffer,
+ (left_stride, left_offset): (&[usize], usize),
+ right: &Buffer,
+ (right_stride, right_offset): (&[usize], usize),
+ output: &mut Buffer,
+) -> Result<(), MetalKernelError> {
+ let func = kernels.load_function(device, Source::Ternary, name)?;
+ let pipeline_state_descriptor = ComputePipelineDescriptor::new();
+ pipeline_state_descriptor.set_compute_function(Some(&func));
+
+ let pipeline = device
+ .new_compute_pipeline_state_with_function(
+ pipeline_state_descriptor.compute_function().unwrap(),
+ )
+ .unwrap();
+
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ let size: usize = shape.iter().product();
+ encoder.set_bytes(0, core::mem::size_of::<usize>() as u64, void_ptr(&size));
+ encoder.set_bytes(
+ 1,
+ core::mem::size_of::<usize>() as u64,
+ void_ptr(&shape.len()),
+ );
+ encoder.set_bytes(
+ 2,
+ std::mem::size_of_val(shape) as u64,
+ shape.as_ptr() as *const c_void,
+ );
+ encoder.set_bytes(
+ 3,
+ std::mem::size_of_val(cond_stride) as u64,
+ cond_stride.as_ptr() as *const c_void,
+ );
+ encoder.set_bytes(
+ 4,
+ std::mem::size_of_val(left_stride) as u64,
+ left_stride.as_ptr() as *const c_void,
+ );
+ encoder.set_bytes(
+ 5,
+ std::mem::size_of_val(right_stride) as u64,
+ right_stride.as_ptr() as *const c_void,
+ );
+ encoder.set_buffer(6, Some(cond), cond_offset as u64);
+ encoder.set_buffer(7, Some(left), left_offset as u64);
+ encoder.set_buffer(8, Some(right), right_offset as u64);
+ encoder.set_buffer(9, Some(output), 0);
+
+ let thread_group_count = MTLSize {
+ width: 1,
+ height: 1,
+ depth: 1,
+ };
+
+ let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size as u64);
+ let thread_group_size = MTLSize {
+ width,
+ height: 1,
+ depth: 1,
+ };
+
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.end_encoding();
+ Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use half::f16;
+ use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger};
+ use std::mem;
+
+ fn device() -> Device {
+ Device::system_default().unwrap()
+ }
+
+ fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> {
+ let b = 10f32.powi(digits);
+ v.iter().map(|t| f32::round(t * b) / b).collect()
+ }
+
+ fn approx_f16(v: Vec<f16>, digits: i32) -> Vec<f32> {
+ let b = 10f32.powi(digits);
+ v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
+ }
+
+ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
+ let device = device();
+ let kernels = Kernels::new();
+ let command_queue = device.new_command_queue();
+ let command_buffer = command_queue.new_command_buffer();
+ let options = MTLResourceOptions::StorageModeManaged;
+ let input = device.new_buffer_with_data(
+ v.as_ptr() as *const core::ffi::c_void,
+ std::mem::size_of_val(v) as u64,
+ options,
+ );
+ let mut output = device.new_buffer(std::mem::size_of_val(v) as u64, options);
+ call_unary_contiguous(
+ &device,
+ command_buffer,
+ &kernels,
+ name,
+ v.len(),
+ &input,
+ &mut output,
+ )
+ .unwrap();
+ command_buffer.commit();
+ command_buffer.wait_until_completed();
+ output.read_to_vec::<T>(v.len())
+ }
+
+ fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> {
+ let device = device();
+ let kernels = Kernels::new();
+ let command_queue = device.new_command_queue();
+ let command_buffer = command_queue.new_command_buffer();
+ let options = MTLResourceOptions::StorageModeManaged;
+ let left = device.new_buffer_with_data(
+ x.as_ptr() as *const core::ffi::c_void,
+ std::mem::size_of_val(x) as u64,
+ options,
+ );
+ let right = device.new_buffer_with_data(
+ y.as_ptr() as *const core::ffi::c_void,
+ std::mem::size_of_val(y) as u64,
+ options,
+ );
+ let mut output = device.new_buffer(std::mem::size_of_val(x) as u64, options);
+ call_binary_contiguous(
+ &device,
+ command_buffer,
+ &kernels,
+ name,
+ x.len(),
+ &left,
+ &right,
+ &mut output,
+ )
+ .unwrap();
+ command_buffer.commit();
+ command_buffer.wait_until_completed();
+ output.read_to_vec::<T>(x.len())
+ }
+
+ fn run_strided<T: Clone>(
+ v: &[T],
+ kernel: unary::strided::Kernel,
+ shape: &[usize],
+ strides: &[usize],
+ offset: usize,
+ ) -> Vec<T> {
+ let device = device();
+ let options = MTLResourceOptions::StorageModeManaged;
+ let command_queue = device.new_command_queue();
+ let command_buffer = command_queue.new_command_buffer();
+ let input = device.new_buffer_with_data(
+ v.as_ptr() as *const core::ffi::c_void,
+ std::mem::size_of_val(v) as u64,
+ options,
+ );
+ let mut output = device.new_buffer(std::mem::size_of_val(v) as u64, options);
+ let kernels = Kernels::new();
+ call_unary_strided(
+ &device,
+ command_buffer,
+ &kernels,
+ kernel,
+ shape,
+ &input,
+ strides,
+ offset,
+ &mut output,
+ 0,
+ )
+ .unwrap();
+ command_buffer.commit();
+ command_buffer.wait_until_completed();
+ output.read_to_vec::<T>(v.len())
+ }
+
+ #[test]
+ fn cos_f32() {
+ let v = vec![1.0f32, 2.0, 3.0];
+ let results = run(&v, unary::contiguous::cos::FLOAT);
+ let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
+ assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]);
+ assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]);
+
+ let v = vec![1.0f32; 10_000];
+ let results = run(&v, unary::contiguous::cos::FLOAT);
+ let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
+ assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
+ assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
+ }
+
+ #[test]
+ fn cos_f32_strided() {
+ let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
+ // Shape = [6], strides = [1];
+ let shape = vec![6];
+ let strides = vec![1];
+ let offset = 0;
+ let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
+ let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
+ assert_eq!(
+ approx(results, 4),
+ vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
+ );
+ assert_eq!(
+ approx(expected, 4),
+ vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
+ );
+
+ // Contiguous
+ let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
+ let shape = vec![3, 2];
+ let strides = vec![2, 1];
+ let offset = 0;
+ let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
+ let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
+ assert_eq!(
+ approx(results, 4),
+ vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
+ );
+ assert_eq!(
+ approx(expected, 4),
+ vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
+ );
+
+ // Transposed
+ let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
+ let shape = vec![3, 2];
+ let strides = vec![1, 3];
+ let offset = 0;
+ let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
+ let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
+ assert_eq!(
+ approx(results, 4),
+ vec![0.5403, -0.6536, -0.4161, 0.2837, -0.99, 0.9602]
+ );
+ assert_eq!(
+ approx(expected, 4),
+ vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
+ );
+
+ // Very large
+ let v = vec![1.0f32; 10_000];
+ let shape = vec![2, 5_000];
+ let strides = vec![2, 1];
+ let offset = 0;
+ let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
+ let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
+ assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
+ assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
+ }
+
+ #[test]
+ fn binary_add_f32() {
+ let left = vec![1.0f32, 2.0, 3.0];
+ let right = vec![2.0f32, 3.1, 4.2];
+ let results = run_binary(&left, &right, binary::contiguous::add::FLOAT);
+ let expected: Vec<_> = left
+ .iter()
+ .zip(right.iter())
+ .map(|(&x, &y)| x + y)
+ .collect();
+ assert_eq!(approx(results, 4), vec![3.0f32, 5.1, 7.2]);
+ assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]);
+ }
+
+ fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
+ let device = device();
+ let kernels = Kernels::new();
+ let command_queue = device.new_command_queue();
+ let command_buffer = command_queue.new_command_buffer();
+ let options = MTLResourceOptions::StorageModeManaged;
+ let input = device.new_buffer_with_data(
+ v.as_ptr() as *const core::ffi::c_void,
+ std::mem::size_of_val(v) as u64,
+ options,
+ );
+ let mut output = device.new_buffer((v.len() * core::mem::size_of::<U>()) as u64, options);
+ call_cast_contiguous(
+ &device,
+ command_buffer,
+ &kernels,
+ name,
+ v.len(),
+ &input,
+ &mut output,
+ )
+ .unwrap();
+ command_buffer.commit();
+ command_buffer.wait_until_completed();
+ output.read_to_vec::<U>(v.len())
+ }
+
+ #[test]
+ fn cast_u32_f32() {
+ let v = vec![1u32, 2, 3];
+ let results = cast(&v, "cast_u32_f32");
+ let expected: Vec<_> = v.iter().map(|&v| v as f32).collect();
+ assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]);
+ assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]);
+
+        // Large case; the original block here re-ran the cos kernel (a
+        // copy-paste from cos_f32) and exercised no casting at all.
+        let v = vec![1u32; 10_000];
+        let results = cast(&v, "cast_u32_f32");
+        let expected: Vec<_> = v.iter().map(|&v| v as f32).collect();
+        assert_eq!(approx(results, 4), vec![1.0f32; 10_000]);
+        assert_eq!(approx(expected, 4), vec![1.0f32; 10_000]);
+ }
+
+ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
+ let device = device();
+ let kernels = Kernels::new();
+ let command_queue = device.new_command_queue();
+ let command_buffer = command_queue.new_command_buffer();
+ let options = MTLResourceOptions::StorageModeManaged;
+
+ let input = device.new_buffer_with_data(
+ v.as_ptr() as *const core::ffi::c_void,
+ std::mem::size_of_val(v) as u64,
+ options,
+ );
+ let mut output = device.new_buffer(std::mem::size_of_val(v) as u64, options);
+
+ let size = v.len();
+
+ call_affine(
+ &device,
+ command_buffer,
+ &kernels,
+ size,
+ &input,
+ &mut output,
+ mul as f32,
+ add as f32,
+ )
+ .unwrap();
+ command_buffer.commit();
+ command_buffer.wait_until_completed();
+
+ output.read_to_vec::<T>(v.len())
+ }
+
+ #[test]
+ fn affine() {
+ let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
+ let mul = 1.5;
+ let add = 1.1;
+ let result = run_affine(&input, mul, add);
+ assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]);
+
+ let input = [1.0f32; 40_000];
+ let mul = 1.5;
+ let add = 1.1;
+ let result = run_affine(&input, mul, add);
+ assert_eq!(result, vec![2.6; 40_000]);
+ }
+
+ #[test]
+ fn index_add() {
+ let device = Device::system_default().expect("no device found");
+
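+        // Drives the raw ia_u32_f32 kernel: scatter-adds each 3-element row
+        // of `left` into `right` at the destination rows given by `index`.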
+ let options = CompileOptions::new();
+ let library = device.new_library_with_source(INDEXING, &options).unwrap();
+
+ let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
+ let right = [1.0f32; 15];
+ let index = [0u32, 4, 2];
+ let ids_dim_size = index.len() as u32;
+ let dst_dim_size: u32 = 15;
+ let left_size: u32 = 3;
+ let right_size: u32 = 3;
+
+ let function = library.get_function("ia_u32_f32", None).unwrap();
+ let pipeline = device
+ .new_compute_pipeline_state_with_function(&function)
+ .unwrap();
+ let options = MTLResourceOptions::StorageModeManaged;
+
+ let command_queue = device.new_command_queue();
+ let command_buffer = command_queue.new_command_buffer();
+ let encoder = command_buffer.new_compute_command_encoder();
+
+ let ids_size = (index.len() * mem::size_of::<u32>()) as NSUInteger;
+ let input_size = (left.len() * mem::size_of::<f32>()) as NSUInteger;
+ let output_size = (right.len() * mem::size_of::<f32>()) as NSUInteger;
+
+ encoder.set_compute_pipeline_state(&pipeline);
+ encoder.set_threadgroup_memory_length(0, output_size as NSUInteger);
+
+ let index_buffer = device.new_buffer_with_data(void_ptr(&index), ids_size, options);
+ let inputs_buffer = device.new_buffer_with_data(void_ptr(&left), input_size, options);
+ let outputs_buffer = device.new_buffer_with_data(void_ptr(&right), output_size, options);
+
+ encoder.set_buffer(0, Some(&index_buffer), 0);
+ encoder.set_buffer(1, Some(&inputs_buffer), 0);
+ encoder.set_buffer(2, Some(&outputs_buffer), 0);
+
+ encoder.set_bytes(3, 4, void_ptr(&ids_dim_size));
+ encoder.set_bytes(4, 4, void_ptr(&left_size));
+ encoder.set_bytes(5, 4, void_ptr(&dst_dim_size));
+ encoder.set_bytes(6, 4, void_ptr(&right_size));
+
+ let grid_size = MTLSize {
+ width: right.len() as NSUInteger,
+ height: 1,
+ depth: 1,
+ };
+
+ let thread_group_size = MTLSize {
+ width: pipeline.max_total_threads_per_threadgroup(),
+ height: 1,
+ depth: 1,
+ };
+
+ encoder.dispatch_thread_groups(grid_size, thread_group_size);
+ encoder.end_encoding();
+ command_buffer.commit();
+ command_buffer.wait_until_completed();
+
+ let expected = vec![
+ 2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
+ ];
+ let result = outputs_buffer.read_to_vec::<f32>(right.len());
+ assert_eq!(result, expected);
+ }
+
+ #[test]
+ fn cos_f16() {
+ let v: Vec<f16> = [1.0f32, 2.0, 3.0]
+ .iter()
+ .map(|v| f16::from_f32(*v))
+ .collect();
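+        // The Metal kernel and the CPU reference round f16 slightly
+        // differently, hence the two separate expected vectors.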
+ let results = run(&v, unary::contiguous::cos::HALF);
+ let expected: Vec<f16> = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect();
+ assert_eq!(approx_f16(results, 4), vec![0.54, -0.4165, -0.9902]);
+ assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]);
+ }
+
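+    // Reduces `v` down to `out_length` partial results with the named
+    // contiguous reduce kernel; the kernel maps one threadgroup to each
+    // output element.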
+ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
+ let device = device();
+ let kernels = Kernels::new();
+ let command_queue = device.new_command_queue();
+ let command_buffer = command_queue.new_command_buffer();
+ let options = MTLResourceOptions::StorageModeManaged;
+ let input = device.new_buffer_with_data(
+ v.as_ptr() as *const core::ffi::c_void,
+ std::mem::size_of_val(v) as u64,
+ options,
+ );
+ let mut output =
+ device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
+ call_reduce_contiguous(
+ &device,
+ command_buffer,
+ &kernels,
+ name,
+ v.len(),
+ out_length,
+ &input,
+ &mut output,
+ )
+ .unwrap();
+ command_buffer.commit();
+ command_buffer.wait_until_completed();
+
+ output.read_to_vec::<T>(out_length)
+ }
+
+ fn run_softmax<T: Clone + std::fmt::Debug>(
+ v: &[T],
+ last_dim: usize,
+ name: &'static str,
+ ) -> Vec<T> {
+ let device = device();
+ let kernels = Kernels::new();
+ let command_queue = device.new_command_queue();
+ let command_buffer = command_queue.new_command_buffer();
+ let options = MTLResourceOptions::StorageModeManaged;
+ let input = device.new_buffer_with_data(
+ v.as_ptr() as *const core::ffi::c_void,
+ std::mem::size_of_val(v) as u64,
+ options,
+ );
+ let mut output = device.new_buffer(std::mem::size_of_val(v) as u64, options);
+ call_last_softmax(
+ &device,
+ command_buffer,
+ &kernels,
+ name,
+ v.len(),
+ last_dim,
+ &input,
+ &mut output,
+ )
+ .unwrap();
+ command_buffer.commit();
+ command_buffer.wait_until_completed();
+
+ output.read_to_vec::<T>(v.len())
+ }
+
+ #[test]
+ fn reduce_sum() {
+ let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
+ let out_length = 1;
+
+ let results = run_reduce(&v, out_length, "fast_sum_float");
+ assert_eq!(approx(results, 4), vec![21.0]);
+ }
+
+ #[test]
+ fn reduce_sum2() {
+ let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
+ let out_length = 2;
+
+ let results = run_reduce(&v, out_length, "fast_sum_float");
+ assert_eq!(approx(results, 4), vec![6.0, 15.0]);
+ }
+
+ #[test]
+ fn softmax() {
+ let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
+ let last_dim = 6;
+ let results = run_softmax(&v, last_dim, "softmax_float");
+ assert_eq!(
+ approx(results, 4),
+ vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
+ );
+
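+        // Softmax is invariant to shifting all inputs by a constant, so the
+        // expected values match the previous case.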
+ let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
+ let last_dim = 6;
+ let results = run_softmax(&v, last_dim, "softmax_float");
+ assert_eq!(
+ approx(results, 4),
+ vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
+ );
+
+ let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
+ let last_dim = 3;
+ let results = run_softmax(&v, last_dim, "softmax_float");
+ assert_eq!(
+ approx(results, 4),
+ vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
+ );
+ }
+
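+    // Runs the strided where/select kernel, with a separate (strides, offset)
+    // layout for the condition, true and false inputs.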
+ fn run_where_cond<I: Clone, T: Clone>(
+ shape: &[usize],
+ cond: &[I],
+ (cond_stride, cond_offset): (Vec<usize>, usize),
+ left_true: &[T],
+ (left_stride, left_offset): (Vec<usize>, usize),
+ right_false: &[T],
+        (right_stride, right_offset): (Vec<usize>, usize),
+ name: &'static str,
+ ) -> Vec<T> {
+ let device = device();
+ let kernels = Kernels::new();
+ let command_queue = device.new_command_queue();
+ let command_buffer = command_queue.new_command_buffer();
+ let options = MTLResourceOptions::StorageModeManaged;
+
+ let length = cond.len();
+ let cond = device.new_buffer_with_data(
+ cond.as_ptr() as *const core::ffi::c_void,
+ std::mem::size_of_val(cond) as u64,
+ options,
+ );
+ let left = device.new_buffer_with_data(
+ left_true.as_ptr() as *const core::ffi::c_void,
+ (length * core::mem::size_of::<T>()) as u64,
+ options,
+ );
+ let right = device.new_buffer_with_data(
+ right_false.as_ptr() as *const core::ffi::c_void,
+ (length * core::mem::size_of::<T>()) as u64,
+ options,
+ );
+
+ let mut output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
+ call_where_cond_strided(
+ &device,
+ command_buffer,
+ &kernels,
+ name,
+ shape,
+ &cond,
+ (&cond_stride, cond_offset),
+ &left,
+ (&left_stride, left_offset),
+ &right,
+            (&right_stride, right_offset),
+ &mut output,
+ )
+ .unwrap();
+ command_buffer.commit();
+ command_buffer.wait_until_completed();
+
+ output.read_to_vec::<T>(length)
+ }
+
+ #[test]
+ fn where_cond() {
+ let shape = vec![6];
+ let cond = vec![0u8, 1, 0, 0, 1, 1];
+ let cond_l = (vec![1], 0);
+ let left_true = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
+ let left_l = (vec![1], 0);
+ let right_false = vec![-1.0f32, -2.0, -3.0, -4.0, -5.0, -6.0];
+ let right_l = (vec![1], 0);
+ let results = run_where_cond(
+ &shape,
+ &cond,
+ cond_l,
+ &left_true,
+ left_l,
+ &right_false,
+ right_l,
+ "where_u8_f32",
+ );
+ assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
+ }
+}
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
new file mode 100644
index 00000000..4dfc46c2
--- /dev/null
+++ b/candle-metal-kernels/src/reduce.metal
@@ -0,0 +1,124 @@
+#include <metal_stdlib>
+using namespace metal;
+
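+// Maps a linear element index to a storage offset for a possibly
+// non-contiguous layout, consuming dims from innermost to outermost.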
+METAL_FUNC uint get_strided_index(
+ uint idx,
+ constant size_t &num_dims,
+ constant size_t *dims,
+ constant size_t *strides
+) {
+ uint strided_i = 0;
+ for (uint d = 0; d < num_dims; d++) {
+ uint dim_idx = num_dims - 1 - d;
+ strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
+ idx /= dims[dim_idx];
+ }
+ return strided_i;
+}
+
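+// Shared-memory size for the kernels below. The tree reductions assume the
+// actual threadgroup size is a power of two no larger than this.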
+constant int THREADGROUP_SIZE = 256;
+
+kernel void fast_sum_float(
+ constant size_t &src_numel,
+ constant size_t &el_to_sum_per_block,
+ device const float *src,
+ device float *dst,
+ uint id [[ thread_position_in_grid ]],
+ uint tid [[ thread_index_in_threadgroup ]],
+ uint dst_id [[ threadgroup_position_in_grid ]],
+ uint blockDim [[ threads_per_threadgroup ]]
+) {
+
+ threadgroup float shared_memory[THREADGROUP_SIZE];
+
+ shared_memory[tid] = 0;
+ // Elements summed in this block range from dst_id * el_to_sum_per_block
+ // to (dst_id + 1) * el_to_sum_per_block.
+ size_t start_idx = dst_id * el_to_sum_per_block;
+ size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
+ size_t idx = start_idx + tid;
+
+ while (idx < stop_idx) {
+ // TODO: Fast version for the contiguous case.
+ // size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
+ shared_memory[tid] += src[idx];
+ idx += blockDim;
+ }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // Tree reduction in shared memory, halving the active threads each step.
+ for (uint s = blockDim / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ shared_memory[tid] += shared_memory[tid + s];
+ }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+
+    if (tid == 0) {
+        dst[dst_id] = shared_memory[0];
+    }
+}
+
+kernel void softmax_float(
+ constant size_t &src_numel,
+ constant size_t &el_to_sum_per_block,
+ device const float *src,
+ device float *dst,
+ uint id [[ thread_position_in_grid ]],
+ uint tid [[ thread_index_in_threadgroup ]],
+ uint dst_id [[ threadgroup_position_in_grid ]],
+ uint blockDim [[ threads_per_threadgroup ]]
+) {
+
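+    // Three passes over the block: (1) reduce the max, (2) write
+    // exp(x - max) and accumulate the sum, (3) scale everything by 1 / sum.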
+ threadgroup float shared_memory[THREADGROUP_SIZE];
+
+ shared_memory[tid] = -INFINITY;
+ // Elements summed in this block range from dst_id * el_to_sum_per_block
+ // to (dst_id + 1) * el_to_sum_per_block.
+ size_t start_idx = dst_id * el_to_sum_per_block;
+ size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
+ size_t idx = start_idx + tid;
+
+ while (idx < stop_idx) {
+ // TODO: Fast version for the contiguous case.
+ shared_memory[tid] = max(shared_memory[tid], src[idx]);
+ idx += blockDim;
+ }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // Tree-reduce the per-thread maxima in shared memory.
+ for (uint s = blockDim / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ shared_memory[tid] = max(shared_memory[tid], shared_memory[tid + s]);
+ }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+
+    const float max_val = shared_memory[0];
+    // Every thread must read the max before the buffer is reused for the sum.
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    shared_memory[tid] = 0;
+
+ // Restart
+ idx = start_idx + tid;
+ while (idx < stop_idx) {
+ // TODO: Fast version for the contiguous case.
+        const float val = exp(src[idx] - max_val);
+ dst[idx] = val;
+ shared_memory[tid] += val;
+ idx += blockDim;
+ }
+    // Wait for all partial sums to land in shared memory before reducing.
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    // reduction in shared memory
+ for (uint s = blockDim / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ shared_memory[tid] += shared_memory[tid + s];
+ }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+
+    const float inv_acc = 1.0f / shared_memory[0];
+ idx = start_idx + tid;
+ while (idx < stop_idx) {
+ dst[idx] *= inv_acc;
+ idx += blockDim;
+ }
+}
diff --git a/candle-metal-kernels/src/ternary.metal b/candle-metal-kernels/src/ternary.metal
new file mode 100644
index 00000000..0945b355
--- /dev/null
+++ b/candle-metal-kernels/src/ternary.metal
@@ -0,0 +1,57 @@
+#include <metal_stdlib>
+using namespace metal;
+
+METAL_FUNC uint get_strided_index(
+ uint idx,
+ constant size_t &num_dims,
+ constant size_t *dims,
+ constant size_t *strides
+) {
+ uint strided_i = 0;
+ for (uint d = 0; d < num_dims; d++) {
+ uint dim_idx = num_dims - 1 - d;
+ strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
+ idx /= dims[dim_idx];
+ }
+ return strided_i;
+}
+
+
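+// Generates a strided select kernel: out[i] = ids[i] ? t[i] : f[i], where the
+// condition, true and false operands are each read through their own strides.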
+#define WHERE_OP(TYPENAME, ID_TYPENAME, FN_NAME) \
+kernel void FN_NAME( \
+ constant size_t &numel, \
+ constant size_t &num_dims, \
+ constant size_t *dims, \
+ constant size_t *strides, \
+ constant size_t *strides_t, \
+ constant size_t *strides_f, \
+ device const ID_TYPENAME *ids, \
+ device const TYPENAME *t, \
+ device const TYPENAME *f, \
+    device TYPENAME *out, \
+ uint i [[ thread_position_in_grid ]] \
+) { \
+ uint strided_i = get_strided_index(i, num_dims, dims, strides); \
+ uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
+ uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \
+ out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f]; \
+}
+
+// WHERE_OP(float, int64_t, where_i64_f32)
+// WHERE_OP(double, int64_t, where_i64_f64)
+// WHERE_OP(uint8_t, int64_t, where_i64_u8)
+// WHERE_OP(uint32_t, int64_t, where_i64_u32)
+// WHERE_OP(int64_t, int64_t, where_i64_i64)
+//
+// WHERE_OP(float, uint32_t, where_u32_f32)
+// WHERE_OP(double, uint32_t, where_u32_f64)
+// WHERE_OP(uint8_t, uint32_t, where_u32_u8)
+// WHERE_OP(uint32_t, uint32_t, where_u32_u32)
+// WHERE_OP(int64_t, uint32_t, where_u32_i64)
+
+WHERE_OP(float, uint8_t, where_u8_f32)
+// WHERE_OP(double, uint8_t, where_u8_f64)
+// WHERE_OP(uint8_t, uint8_t, where_u8_u8)
+// WHERE_OP(uint32_t, uint8_t, where_u8_u32)
+// WHERE_OP(int64_t, uint8_t, where_u8_i64)
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
new file mode 100644
index 00000000..77de214e
--- /dev/null
+++ b/candle-metal-kernels/src/unary.metal
@@ -0,0 +1,82 @@
+#include <metal_stdlib>
+
+METAL_FUNC uint get_strided_index(
+ uint idx,
+ constant size_t &num_dims,
+ constant size_t *dims,
+ constant size_t *strides
+) {
+ uint strided_i = 0;
+ for (uint d = 0; d < num_dims; d++) {
+ uint dim_idx = num_dims - 1 - d;
+ strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
+ idx /= dims[dim_idx];
+ }
+ return strided_i;
+}
+
+template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
+template <typename T> METAL_FUNC T neg(T in){ return -in; }
+template <typename T> METAL_FUNC T id(T in){ return in; }
+
+
+using namespace metal;
+
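+// Generates a contiguous and a strided variant for each unary op. Each thread
+// handles a contiguous chunk of ceil(dim / threadgroup_size) elements, so a
+// single threadgroup covers the whole buffer.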
+#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
+kernel void FN_NAME( \
+ constant size_t &dim, \
+ device const TYPENAME *input, \
+ device TYPENAME *output, \
+ uint threadgroup_size [[threads_per_threadgroup]], \
+ uint thread_index [[thread_index_in_threadgroup]] \
+) { \
+ const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
+ const size_t start = thread_index * length; \
+ const size_t stop = min(start + length, dim); \
+ for (size_t i = start; i < stop; i++){ \
+ output[i] = TYPENAME(FN(input[i])); \
+ } \
+}\
+kernel void FN_NAME_STRIDED( \
+ constant size_t &dim, \
+ constant size_t &num_dims, \
+ constant size_t *dims, \
+ constant size_t *strides, \
+ device const TYPENAME *input, \
+ device TYPENAME *output, \
+ uint threadgroup_size [[threads_per_threadgroup]], \
+ uint thread_index [[thread_index_in_threadgroup]] \
+) { \
+ const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
+ const size_t start = thread_index * length; \
+ const size_t stop = min(start + length, dim); \
+ for (size_t i = start; i < stop; i++){ \
+ output[i] = TYPENAME(FN(input[get_strided_index(i, num_dims, dims, strides)])); \
+ } \
+}
+
+#define UNARY_OP(NAME) \
+UNARY(NAME, float, NAME##_float, NAME##_float_strided); \
+UNARY(NAME, half, NAME##_half, NAME##_half_strided);
+
+#define BFLOAT_UNARY_OP(NAME) \
+UNARY(NAME, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
+
+
+UNARY_OP(cos)
+UNARY_OP(sin)
+UNARY_OP(sqr)
+UNARY_OP(sqrt)
+UNARY_OP(neg)
+UNARY_OP(exp)
+UNARY(id, float, copy_float, copy_float_strided)
+UNARY(id, half, copy_half, copy_half_strided)
+
+#if __METAL_VERSION__ >= 310
+BFLOAT_UNARY_OP(cos)
+BFLOAT_UNARY_OP(sin)
+BFLOAT_UNARY_OP(sqr)
+BFLOAT_UNARY_OP(sqrt)
+BFLOAT_UNARY_OP(neg)
+BFLOAT_UNARY_OP(exp)
+#endif