-rw-r--r--  Cargo.toml                                                   3
-rw-r--r--  candle-core/Cargo.toml                                       7
-rw-r--r--  candle-core/benches/matmul.rs                               43
-rw-r--r--  candle-core/src/device.rs                                    7
-rw-r--r--  candle-core/src/metal_backend.rs                          1240
-rw-r--r--  candle-core/src/tensor.rs                                    5
-rw-r--r--  candle-examples/Cargo.toml                                   1
-rw-r--r--  candle-metal-kernels/Cargo.toml                              2
-rw-r--r--  candle-metal-kernels/src/affine.metal                       93
-rw-r--r--  candle-metal-kernels/src/binary.metal                       48
-rw-r--r--  candle-metal-kernels/src/cast.metal                         19
-rw-r--r--  candle-metal-kernels/src/indexing.metal                    228
-rw-r--r--  candle-metal-kernels/src/lib.rs                           1045
-rw-r--r--  candle-metal-kernels/src/libMetalFlashAttention.metallib   bin  0 -> 102760 bytes
-rw-r--r--  candle-metal-kernels/src/reduce.metal                      317
-rw-r--r--  candle-metal-kernels/src/ternary.metal                       3
-rw-r--r--  candle-metal-kernels/src/tests.rs                          365
-rw-r--r--  candle-metal-kernels/src/unary.metal                        73
-rw-r--r--  candle-metal-kernels/tmp/affine.rs (renamed from candle-metal-kernels/examples/affine.rs)   1
-rw-r--r--  candle-metal-kernels/tmp/binary.rs (renamed from candle-metal-kernels/examples/binary.rs)   0
-rw-r--r--  candle-metal-kernels/tmp/cast.rs (renamed from candle-metal-kernels/examples/cast.rs)       0
-rw-r--r--  candle-metal-kernels/tmp/unary.rs (renamed from candle-metal-kernels/examples/unary.rs)     6
-rw-r--r--  candle-nn/Cargo.toml                                         3
-rw-r--r--  candle-nn/src/ops.rs                                        41
-rw-r--r--  candle-transformers/Cargo.toml                               1
25 files changed, 2775 insertions, 776 deletions
diff --git a/Cargo.toml b/Cargo.toml
index c0e520f3..c13b2bc6 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -32,6 +32,7 @@ accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
clap = { version = "4.2.4", features = ["derive"] }
+criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.9.14", features = ["f16"] }
gemm = { version = "0.16.6", features = ["wasm-simd128-enable"] }
hf-hub = "0.3.0"
@@ -61,7 +62,7 @@ tracing-subscriber = "0.3.7"
wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false }
-metal = { version = "0.27.1", features = ["mps"], package="candle-metal" }
+metal = { version = "0.27.0", features = ["mps"]}
[profile.release-with-debug]
inherits = "release"
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index 42e5be2a..52e79a5a 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -34,6 +34,8 @@ zip = { workspace = true }
[dev-dependencies]
anyhow = { workspace = true }
clap = { workspace = true }
+criterion = { workspace = true }
+
[features]
default = []
@@ -42,3 +44,8 @@ cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal", "dep:candle-metal-kernels"]
+
+[[bench]]
+name = "matmul"
+harness = false
+
diff --git a/candle-core/benches/matmul.rs b/candle-core/benches/matmul.rs
new file mode 100644
index 00000000..8732f451
--- /dev/null
+++ b/candle-core/benches/matmul.rs
@@ -0,0 +1,43 @@
+use candle_core::{DType, Device, Tensor};
+use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput};
+use std::time::Instant;
+
+fn run(a: &Tensor, b: &Tensor) {
+ a.matmul(&b.t().unwrap()).unwrap();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+ let b = 1;
+ let m = 1;
+ let n = 2048;
+ let k = 2048;
+
+ let device = Device::new_metal(0).unwrap();
+ let dtype = DType::F32;
+ let lhs = Tensor::zeros((b, m, k), dtype, &device).unwrap();
+ let rhs = Tensor::zeros((b, n, k), dtype, &device).unwrap();
+
+ let flops = b * m * n * k;
+
+ let mut group = c.benchmark_group("matmul_metal");
+ group.throughput(Throughput::Bytes(flops as u64));
+ group.bench_function("iter", move |b| {
+ b.iter_custom(|iters| {
+ let start = Instant::now();
+ for _i in 0..iters {
+ run(black_box(&lhs), black_box(&rhs));
+ }
+ if let Device::Metal(device) = &device {
+ device.wait_until_completed().unwrap();
+ } else {
+ panic!("Expected metal device");
+ }
+ start.elapsed()
+ })
+ });
+ group.finish();
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
+
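As a side note, here is a rough sketch of how the time measured above could be turned into GFLOP/s by hand. The conventional 2·m·n·k multiply-add count is an assumption of this sketch, not something the benchmark itself asserts, and the gflops helper name is made up:

use std::time::Duration;

// Convert the elapsed time for `iters` batched matmuls of shape
// (b, m, k) x (b, k, n) into GFLOP/s, counting one multiply and one add
// per accumulation step.
fn gflops(b: usize, m: usize, n: usize, k: usize, iters: u64, elapsed: Duration) -> f64 {
    let flops = 2.0 * (b * m * n * k) as f64 * iters as f64;
    flops / elapsed.as_secs_f64() / 1e9
}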
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 3eb7f8b7..1e33021b 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -201,10 +201,9 @@ impl Device {
Ok(Storage::Cuda(storage))
}
}
- Device::Metal(_device) => {
- // let storage = device.rand_uniform(shape, dtype, lo, up)?;
- // Ok(Storage::Metal(storage))
- crate::bail!("Metal rand_uniform not implemented")
+ Device::Metal(device) => {
+ let storage = device.rand_uniform(shape, dtype, lo, up)?;
+ Ok(Storage::Metal(storage))
}
}
}
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 0b72f080..27b2824f 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -4,11 +4,30 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels;
use candle_metal_kernels::Kernels;
-use core::mem;
-use half::{bf16, f16};
use metal;
-use metal::{Buffer, CommandQueue, MTLResourceOptions, NSUInteger};
-use std::sync::Arc;
+use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
+use std::collections::HashMap;
+use std::path::Path;
+use std::sync::{Arc, RwLock, TryLockError};
+
+/// Simple way to catch lock errors without
+/// depending on T.
+#[derive(thiserror::Error, Debug)]
+pub enum LockError {
+ #[error("{0}")]
+ Poisoned(String),
+ #[error("Would block")]
+ WouldBlock,
+}
+
+impl<T> From<TryLockError<T>> for MetalError {
+ fn from(value: TryLockError<T>) -> Self {
+ match value {
+ TryLockError::Poisoned(p) => MetalError::LockError(LockError::Poisoned(p.to_string())),
+ TryLockError::WouldBlock => MetalError::LockError(LockError::WouldBlock),
+ }
+ }
+}
/// Metal related errors
#[derive(thiserror::Error, Debug)]
@@ -24,6 +43,14 @@ pub enum MetalError {
rhs_stride: Vec<usize>,
mnk: (usize, usize, usize),
},
+ #[error("{0:?}")]
+ LockError(LockError),
+ #[error("{msg}, expected: {expected:?}, got: {got:?}")]
+ UnexpectedDType {
+ msg: &'static str,
+ expected: DType,
+ got: DType,
+ },
}
impl From<String> for MetalError {
@@ -32,11 +59,53 @@ impl From<String> for MetalError {
}
}
+type AllocatedBuffers = Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>;
+
#[derive(Clone)]
pub struct MetalDevice {
+ /// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
device: metal::Device,
+
+ /// Single command queue for the entire device.
command_queue: metal::CommandQueue,
+ /// One command buffer at a time.
+ /// The scheduler works by allowing multiple
+ /// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
+ /// on a single command buffer. Using a single command buffer would be fastest on the GPU, but it
+ /// prevents overlapping CPU and GPU work (a command buffer has to be committed before it starts
+ /// executing).
+ /// Despite what the documentation says, command buffers are NOT ordered. They are ordered by
+ /// their START time, but there is no guarantee that command buffer 1 will finish before
+ /// command buffer 2 starts (or there are Metal bugs in that area).
+ command_buffer: Arc<RwLock<metal::CommandBuffer>>,
+ /// Keeps track of the current number of compute command encoders on the current
+ /// command buffer.
+ /// Arc and RwLock are needed for interior mutability.
+ command_buffer_index: Arc<RwLock<usize>>,
+ /// The maximum number of [compute command encoders](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
+ compute_per_buffer: usize,
+ /// Every compute command encoder (and blit encoder) is guarded by this Fence, forcing the
+ /// execution order to be linear.
+ /// This could be relaxed in some circumstances by managing the dependencies in the
+ /// compute graph ourselves.
+ fence: metal::Fence,
+ /// Simple keeper struct to keep track of already compiled kernels so we can reuse them.
+ /// Heavily used by [`candle_metal_kernels`]; both fences need to match.
kernels: Arc<candle_metal_kernels::Kernels>,
+ /// Simple allocator struct.
+ /// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
+ /// We store the buffers in [`Arc`] because it's much faster than Objective-C's internal ref counting
+ /// (likely due to FFI communication overhead).
+ ///
+ /// Whenever a buffer has strong_count == 1, we can reuse it: it was dropped in the
+ /// graph calculation and only we, the allocator, kept a reference to it, so it is free
+ /// to be reused. However, for this to work we need to guarantee the order of
+ /// operations, so that the buffer is not being used by another kernel at the same time.
+ /// Arc is the CPU reference count; it means nothing on the GPU side of things.
+ ///
+ /// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers
+ /// (strong_count == 1).
+ buffers: AllocatedBuffers,
}
impl std::fmt::Debug for MetalDevice {
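As an aside, the reuse rule described in the allocator comment above boils down to checking the Arc strong count: a buffer whose strong_count is exactly 1 is referenced only by the allocator's own bucket, so it can safely be handed out again. A minimal stand-alone sketch of that test (find_reusable is a made-up name; the real allocate_buffer appears later in this diff):

use std::sync::Arc;

// Return a buffer from the bucket that nobody else holds a reference to,
// i.e. one whose only remaining Arc is the allocator's own copy.
fn find_reusable<T>(bucket: &[Arc<T>]) -> Option<Arc<T>> {
    bucket.iter().find(|b| Arc::strong_count(b) == 1).cloned()
}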
@@ -58,10 +127,47 @@ impl MetalDevice {
self.registry_id()
}
+ pub fn metal_device(&self) -> &metal::Device {
+ &self.device
+ }
+
pub fn command_queue(&self) -> &CommandQueue {
&self.command_queue
}
+ pub fn command_buffer(&self) -> Result<CommandBuffer> {
+ let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?;
+ let mut command_buffer = command_buffer_lock.to_owned();
+ let mut index = self
+ .command_buffer_index
+ .try_write()
+ .map_err(MetalError::from)?;
+ if *index > self.compute_per_buffer {
+ command_buffer.commit();
+ command_buffer = self.command_queue.new_command_buffer().to_owned();
+ *command_buffer_lock = command_buffer.clone();
+ *index = 0;
+ }
+ *index += 1;
+ Ok(command_buffer)
+ }
+
+ pub fn wait_until_completed(&self) -> Result<()> {
+ let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?;
+ match command_buffer.status() {
+ metal::MTLCommandBufferStatus::Committed
+ | metal::MTLCommandBufferStatus::Scheduled
+ | metal::MTLCommandBufferStatus::Completed => {
+ panic!("Already committed");
+ }
+ _ => {}
+ }
+ command_buffer.commit();
+ command_buffer.wait_until_completed();
+ *command_buffer = self.command_queue.new_command_buffer().to_owned();
+ Ok(())
+ }
+
pub fn kernels(&self) -> &Kernels {
&self.kernels
}
@@ -70,17 +176,119 @@ impl MetalDevice {
&self.device
}
- pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
+ /// Creates a new buffer (not necessarily zeroed).
+ /// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode)
+ /// This means the buffer data cannot be read on the CPU directly.
+ ///
+ /// [`name`] is only used to keep track of the resource origin in case of bugs
+ pub fn new_buffer(
+ &self,
+ element_count: usize,
+ dtype: DType,
+ name: &str,
+ ) -> Result<Arc<Buffer>> {
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
- self.device
- .new_buffer(size, MTLResourceOptions::StorageModeManaged)
+ self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
+ }
+
+ /// Creates a new buffer (not necessarily zeroed).
+ /// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
+ /// This means the buffer can be read on the CPU but requires manual
+ /// synchronization when the CPU memory is modified.
+ /// Used as a bridge to gather data back from the GPU.
+ pub fn new_buffer_managed(&self, size: NSUInteger) -> Result<Arc<Buffer>> {
+ self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
+ }
+
+ /// Creates a new buffer from data.
+ /// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode)
+ ///
+ /// This method blocks the computation because of the
+ /// lack of lifetime management on the GPU side.
+ /// See the internal comment for technical details.
+ pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
+ let size = core::mem::size_of_val(data) as NSUInteger;
+ let tmp = self.device.new_buffer_with_data(
+ data.as_ptr() as *const core::ffi::c_void,
+ size,
+ metal::MTLResourceOptions::StorageModeManaged,
+ );
+ let real = self.allocate_buffer(
+ size,
+ metal::MTLResourceOptions::StorageModePrivate,
+ "with_data",
+ )?;
+ let command_buffer = self.command_buffer()?;
+ command_buffer.set_label("with_data");
+ let blit = command_buffer.new_blit_command_encoder();
+ blit.wait_for_fence(&self.fence);
+ blit.set_label("with_data_blit");
+ blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
+ blit.update_fence(&self.fence);
+ blit.end_encoding();
+
+ // This is necessary for mmapped safetensors, because of the unsafe
+ // slice cast we're doing: the slice might not live long enough for
+ // Metal to actually fill the GPU buffer.
+ // This wait forces the GPU buffer to be filled with the actual data,
+ // allowing the CPU storage to be deallocated properly.
+ self.wait_until_completed()?;
+ Ok(real)
+ }
+
+ /// The critical allocator algorithm
+ fn allocate_buffer(
+ &self,
+ size: NSUInteger,
+ option: MTLResourceOptions,
+ _name: &str,
+ ) -> Result<Arc<Buffer>> {
+ let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
+ let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
+
+ for sub in &mut *subbuffers {
+ if Arc::strong_count(sub) == 1 {
+ return Ok(sub.clone());
+ }
+ }
+ let new_buffer = self.device.new_buffer(size as NSUInteger, option);
+ let new_buffer = Arc::new(new_buffer);
+ subbuffers.push(new_buffer.clone());
+ for subbuffers in buffers.values_mut() {
+ let newbuffers = subbuffers
+ .iter()
+ .filter(|s| Arc::strong_count(s) > 1)
+ .map(Arc::clone)
+ .collect();
+ *subbuffers = newbuffers;
+ }
+ Ok(new_buffer)
+ }
+
+ /// Create a metal GPU capture trace on [`path`].
+ pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> {
+ let capture = metal::CaptureManager::shared();
+ let descriptor = metal::CaptureDescriptor::new();
+ descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
+ descriptor.set_capture_device(self);
+ descriptor.set_output_url(path);
+
+ capture
+ .start_capture(&descriptor)
+ .map_err(MetalError::from)?;
+ Ok(())
}
}
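As an aside, a hypothetical usage sketch for the capture helper added above, assuming a build with the metal feature enabled; the output path is an arbitrary example and stopping the capture is not shown:

use candle_core::Device;

fn capture_one_trace() -> candle_core::Result<()> {
    let device = Device::new_metal(0)?;
    if let Device::Metal(metal_device) = &device {
        // Writes a .gputrace document that can be opened in Xcode's Metal debugger.
        metal_device.capture("/tmp/candle.gputrace")?;
    }
    // ... run the ops to be captured on `device` here ...
    Ok(())
}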
#[derive(Debug, Clone)]
pub struct MetalStorage {
- buffer: metal::Buffer,
+ /// The actual buffer containing the data.
+ buffer: Arc<metal::Buffer>,
+ /// A reference to the device owning this buffer.
device: MetalDevice,
+ /// The dtype is kept since buffers are untyped.
dtype: DType,
}
@@ -108,14 +316,27 @@ impl BackendStorage for MetalStorage {
self.dtype
);
}
+ let buffer = self.device.new_buffer_managed(self.buffer.length())?;
+ {
+ let command_buffer = self.device.command_buffer()?;
+ command_buffer.set_label("to_cpu");
+ let blit = command_buffer.new_blit_command_encoder();
+ blit.set_label("blit_to_cpu");
+ blit.wait_for_fence(&self.device.fence);
+ blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
+ blit.update_fence(&self.device.fence);
+ blit.end_encoding();
+ }
+ self.device.wait_until_completed()?;
+
match self.dtype {
- DType::U8 => Ok(CpuStorage::U8(self.buffer.read_to_vec(length / size))),
- DType::U32 => Ok(CpuStorage::U32(self.buffer.read_to_vec(length / size))),
- DType::I64 => Ok(CpuStorage::I64(self.buffer.read_to_vec(length / size))),
- DType::F16 => Ok(CpuStorage::F16(self.buffer.read_to_vec(length / size))),
- DType::BF16 => Ok(CpuStorage::BF16(self.buffer.read_to_vec(length / size))),
- DType::F32 => Ok(CpuStorage::F32(self.buffer.read_to_vec(length / size))),
- DType::F64 => Ok(CpuStorage::F64(self.buffer.read_to_vec(length / size))),
+ DType::U8 => Ok(CpuStorage::U8(read_to_vec(&buffer, length / size))),
+ DType::U32 => Ok(CpuStorage::U32(read_to_vec(&buffer, length / size))),
+ DType::I64 => Ok(CpuStorage::I64(read_to_vec(&buffer, length / size))),
+ DType::F16 => Ok(CpuStorage::F16(read_to_vec(&buffer, length / size))),
+ DType::BF16 => Ok(CpuStorage::BF16(read_to_vec(&buffer, length / size))),
+ DType::F32 => Ok(CpuStorage::F32(read_to_vec(&buffer, length / size))),
+ DType::F64 => Ok(CpuStorage::F64(read_to_vec(&buffer, length / size))),
}
}
@@ -126,52 +347,152 @@ impl BackendStorage for MetalStorage {
let el = shape.elem_count();
let dtype = self.dtype;
- if layout.is_contiguous() || layout.start_offset() != 0 || dtype != DType::F32 {
- crate::bail!("Not contiguous, non-f32 affine is not implemented yet.");
+ let buffer = device.new_buffer(el, self.dtype, "affine")?;
+ let command_buffer = self.device.command_buffer()?;
+ if layout.is_contiguous() && layout.start_offset() == 0 {
+ let name = match self.dtype {
+ DType::F32 => "affine_f32",
+ DType::F16 => "affine_f16",
+ dtype => crate::bail!("Affine {dtype:?}"),
+ };
+ candle_metal_kernels::call_affine(
+ &device.device,
+ &command_buffer,
+ &device.kernels,
+ name,
+ el,
+ &self.buffer,
+ &buffer,
+ mul as f32,
+ add as f32,
+ )
+ .map_err(MetalError::from)?;
+ } else {
+ let name = match self.dtype {
+ DType::F32 => "affine_f32_strided",
+ DType::F16 => "affine_f16_strided",
+ dtype => crate::bail!("Affine {dtype:?}"),
+ };
+ candle_metal_kernels::call_affine_strided(
+ &device.device,
+ &command_buffer,
+ &device.kernels,
+ name,
+ layout.dims(),
+ &self.buffer,
+ layout.stride(),
+ layout.start_offset() * dtype.size_in_bytes(),
+ &buffer,
+ mul as f32,
+ add as f32,
+ )
+ .map_err(MetalError::from)?;
}
-
- let mut buffer = device.new_buffer(el, self.dtype);
- let command_buffer = self.device.command_queue.new_command_buffer();
- candle_metal_kernels::call_affine(
- &device.device,
- &command_buffer,
- &device.kernels,
- el,
- &self.buffer,
- &mut buffer,
- mul as f32,
- add as f32,
- )
- .map_err(MetalError::from)?;
- command_buffer.commit();
- command_buffer.wait_until_completed();
- return Ok(Self {
- buffer,
- device: device.clone(),
- dtype,
- });
+ Ok(Self::new(buffer, device.clone(), dtype))
}
- fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
- crate::bail!("powf metal")
+ fn powf(&self, layout: &Layout, pow: f64) -> Result<Self> {
+ let device = self.device().clone();
+
+ let shape = layout.shape();
+ let el = shape.elem_count();
+ let dtype = self.dtype;
+
+ let buffer = device.new_buffer(el, self.dtype, "powf")?;
+ let command_buffer = self.device.command_buffer()?;
+ if layout.is_contiguous() && layout.start_offset() == 0 {
+ let name = match self.dtype {
+ DType::F32 => "powf_f32",
+ DType::F16 => "powf_f16",
+ dtype => crate::bail!("Powf {dtype:?}"),
+ };
+ candle_metal_kernels::call_powf(
+ &device.device,
+ &command_buffer,
+ &device.kernels,
+ name,
+ el,
+ &self.buffer,
+ &buffer,
+ pow as f32,
+ )
+ .map_err(MetalError::from)?;
+ } else {
+ let name = match self.dtype {
+ DType::F32 => "powf_f32_strided",
+ DType::F16 => "powf_f16_strided",
+ dtype => crate::bail!("Powf {dtype:?}"),
+ };
+ candle_metal_kernels::call_powf_strided(
+ &device.device,
+ &command_buffer,
+ &device.kernels,
+ name,
+ layout.dims(),
+ &self.buffer,
+ layout.stride(),
+ layout.start_offset() * dtype.size_in_bytes(),
+ &buffer,
+ pow as f32,
+ )
+ .map_err(MetalError::from)?;
+ }
+ Ok(Self::new(buffer, device.clone(), dtype))
}
- fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
- crate::bail!("elu metal")
+ fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
+ let device = self.device().clone();
+
+ let shape = layout.shape();
+ let el = shape.elem_count();
+ let dtype = self.dtype;
+
+ let buffer = device.new_buffer(el, self.dtype, "elu")?;
+ let command_buffer = self.device.command_buffer()?;
+ if layout.is_contiguous() && layout.start_offset() == 0 {
+ let name = match self.dtype {
+ DType::F32 => "elu_f32",
+ DType::F16 => "elu_f16",
+ dtype => crate::bail!("Powf {dtype:?}"),
+ };
+ candle_metal_kernels::call_elu(
+ &device.device,
+ &command_buffer,
+ &device.kernels,
+ name,
+ el,
+ &self.buffer,
+ &buffer,
+ alpha as f32,
+ )
+ .map_err(MetalError::from)?;
+ } else {
+ let name = match self.dtype {
+ DType::F32 => "elu_f32_strided",
+ DType::F16 => "elu_f16_strided",
+ dtype => crate::bail!("Powf {dtype:?}"),
+ };
+ candle_metal_kernels::call_elu_strided(
+ &device.device,
+ &command_buffer,
+ &device.kernels,
+ name,
+ layout.dims(),
+ &self.buffer,
+ layout.stride(),
+ layout.start_offset() * dtype.size_in_bytes(),
+ &buffer,
+ alpha as f32,
+ )
+ .map_err(MetalError::from)?;
+ }
+ Ok(Self::new(buffer, device.clone(), dtype))
}
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
- if !(sum_dims.len() == 1
- && sum_dims[0] == layout.shape().rank() - 1
- && layout.is_contiguous()
- && layout.start_offset() == 0)
- {
- crate::bail!("Non contiguous reduce op not supported yet");
- }
let device = self.device.clone();
let src_stride = layout.stride();
let src_dims = layout.shape().dims();
- let src_el: usize = src_dims.iter().product();
// Source dims and strides with the sum dims at the end.
let mut dims = vec![];
let mut stride = vec![];
@@ -191,53 +512,77 @@ impl BackendStorage for MetalStorage {
// The reduction loop requires the shared array to be properly initialized and for
// this we want the number of threads to be a power of two.
let (name, check_empty, return_index) = match (op, self.dtype) {
- (ReduceOp::Sum, DType::F32) => ("fast_sum_float", false, false),
- (ReduceOp::Min, DType::F32) => ("fast_min_float", true, false),
- (ReduceOp::Max, DType::F32) => ("fast_max_float", true, false),
- (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_float", true, true),
- (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_float", true, true),
- _ => crate::bail!("Reduce op for non float"),
+ (ReduceOp::Sum, DType::F32) => ("fast_sum_f32_strided", false, false),
+ (ReduceOp::Min, DType::F32) => ("fast_min_f32_strided", true, false),
+ (ReduceOp::Max, DType::F32) => ("fast_max_f32_strided", true, false),
+ (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32_strided", true, true),
+ (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32_strided", true, true),
+ (ReduceOp::Sum, DType::U32) => ("fast_sum_u32_strided", false, false),
+ (ReduceOp::Min, DType::U32) => ("fast_min_u32_strided", true, false),
+ (ReduceOp::Max, DType::U32) => ("fast_max_u32_strided", true, false),
+ (ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32_strided", true, true),
+ (ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32_strided", true, true),
+ (ReduceOp::Sum, DType::F16) => ("fast_sum_f16_strided", false, false),
+ (ReduceOp::Min, DType::F16) => ("fast_min_f16_strided", true, false),
+ (ReduceOp::Max, DType::F16) => ("fast_max_f16_strided", true, false),
+ (ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16_strided", true, true),
+ (ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16_strided", true, true),
+ (ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16_strided", false, false),
+ (ReduceOp::Min, DType::BF16) => ("fast_min_bf16_strided", true, false),
+ (ReduceOp::Max, DType::BF16) => ("fast_max_bf16_strided", true, false),
+ (ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16_strided", true, true),
+ (ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16_strided", true, true),
+ (k, dtype) => crate::bail!("Reduce op for non float {k:?} {dtype:?}"),
};
if check_empty && layout.shape().elem_count() == 0 {
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
}
let dtype = if return_index { DType::U32 } else { self.dtype };
- let mut buffer = device.new_buffer(dst_el, dtype);
- let command_buffer = self.device.command_queue.new_command_buffer();
- candle_metal_kernels::call_reduce_contiguous(
+ let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
+ let command_buffer = self.device.command_buffer()?;
+ candle_metal_kernels::call_reduce_strided(
&device.device,
&command_buffer,
&device.kernels,
name,
- src_el,
+ &dims,
+ &stride,
dst_el,
&self.buffer,
- &mut buffer,
+ layout.start_offset() * self.dtype.size_in_bytes(),
+ &buffer,
)
.map_err(MetalError::from)?;
- command_buffer.commit();
- command_buffer.wait_until_completed();
- Ok(Self {
- buffer,
- device,
- dtype,
- })
+ Ok(Self::new(buffer, device, dtype))
}
- fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
- crate::bail!("cmp metal")
+ fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
+ let name = match op {
+ CmpOp::Eq => "eq",
+ CmpOp::Ne => "ne",
+ CmpOp::Le => "le",
+ CmpOp::Ge => "ge",
+ CmpOp::Lt => "lt",
+ CmpOp::Gt => "gt",
+ };
+ self.binary(name, rhs, lhs_l, rhs_l)
}
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
let device = self.device();
let shape = layout.shape();
let el_count = shape.elem_count();
- let mut buffer = device.new_buffer(el_count, dtype);
- let command_buffer = device.command_queue.new_command_buffer();
- if layout.is_contiguous() {
+ let buffer = device.new_buffer(el_count, dtype, "todtype")?;
+ let command_buffer = device.command_buffer()?;
+ if layout.is_contiguous() && layout.start_offset() == 0 {
let kernel_name = match (self.dtype, dtype) {
(DType::U32, DType::F32) => "cast_u32_f32",
+ (DType::U32, DType::U8) => "cast_u32_u8",
+ (DType::U8, DType::U32) => "cast_u8_u32",
+ (DType::U8, DType::F32) => "cast_u8_f32",
+ (DType::F32, DType::F16) => "cast_f32_f16",
+ (DType::F16, DType::F32) => "cast_f16_f32",
(left, right) => crate::bail!("to dtype {left:?} - {right:?}"),
};
candle_metal_kernels::call_cast_contiguous(
@@ -247,24 +592,35 @@ impl BackendStorage for MetalStorage {
kernel_name,
el_count,
&self.buffer,
- &mut buffer,
+ layout.start_offset() * self.dtype.size_in_bytes(),
+ &buffer,
)
.map_err(MetalError::from)?;
} else {
- crate::bail!(
- "TODO Implement the kernel calling cast {:?}-{:?}",
- self.dtype,
- dtype
- );
+ let kernel_name = match (self.dtype, dtype) {
+ (DType::U32, DType::F32) => "cast_u32_f32_strided",
+ (DType::U32, DType::U8) => "cast_u32_u8_strided",
+ (DType::U8, DType::U32) => "cast_u8_u32_strided",
+ (DType::U8, DType::F32) => "cast_u8_f32_strided",
+ (DType::F32, DType::F16) => "cast_f32_f16_strided",
+ (DType::F16, DType::F32) => "cast_f16_f32_strided",
+ (left, right) => crate::bail!("to dtype {left:?} - {right:?}"),
+ };
+ candle_metal_kernels::call_cast_strided(
+ &device.device,
+ &command_buffer,
+ &device.kernels,
+ kernel_name,
+ layout.dims(),
+ &self.buffer,
+ layout.stride(),
+ layout.start_offset() * self.dtype.size_in_bytes(),
+ &buffer,
+ )
+ .map_err(MetalError::from)?;
}
-
- command_buffer.commit();
- command_buffer.wait_until_completed();
- Ok(Self {
- buffer,
- device: device.clone(),
- dtype,
- })
+ command_buffer.set_label("to_dtype");
+ Ok(Self::new(buffer, device.clone(), dtype))
}
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
@@ -272,8 +628,9 @@ impl BackendStorage for MetalStorage {
let dtype = self.dtype;
let shape = layout.shape();
let el_count = shape.elem_count();
- let mut buffer = device.new_buffer(el_count, dtype);
- let command_buffer = device.command_queue.new_command_buffer();
+ let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?;
+ let command_buffer = device.command_buffer()?;
+ command_buffer.set_label(B::KERNEL);
if layout.is_contiguous() && layout.start_offset() == 0 {
use candle_metal_kernels::unary::contiguous;
@@ -285,6 +642,27 @@ impl BackendStorage for MetalStorage {
("uneg", DType::F32) => contiguous::neg::FLOAT,
("uexp", DType::F32) => contiguous::exp::FLOAT,
("ulog", DType::F32) => contiguous::log::FLOAT,
+ ("ugelu", DType::F32) => contiguous::gelu::FLOAT,
+ ("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
+ ("uerf", DType::F32) => contiguous::erf::FLOAT,
+ ("uceil", DType::F32) => contiguous::ceil::FLOAT,
+ ("ufloor", DType::F32) => contiguous::floor::FLOAT,
+ ("uround", DType::F32) => contiguous::round::FLOAT,
+ ("utanh", DType::F32) => contiguous::tanh::FLOAT,
+ ("ucos", DType::F16) => contiguous::cos::HALF,
+ ("usin", DType::F16) => contiguous::sin::HALF,
+ ("usqr", DType::F16) => contiguous::sqr::HALF,
+ ("usqrt", DType::F16) => contiguous::sqrt::HALF,
+ ("uneg", DType::F16) => contiguous::neg::HALF,
+ ("uexp", DType::F16) => contiguous::exp::HALF,
+ ("ulog", DType::F16) => contiguous::log::HALF,
+ ("ugelu", DType::F16) => contiguous::gelu::HALF,
+ ("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
+ ("uerf", DType::F16) => contiguous::erf::HALF,
+ ("uceil", DType::F16) => contiguous::ceil::HALF,
+ ("ufloor", DType::F16) => contiguous::floor::HALF,
+ ("uround", DType::F16) => contiguous::round::HALF,
+ ("utanh", DType::F16) => contiguous::tanh::HALF,
(name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_unary_contiguous(
@@ -294,95 +672,64 @@ impl BackendStorage for MetalStorage {
kernel_name,
el_count,
&self.buffer,
- &mut buffer,
- )
- .map_err(MetalError::from)?;
- } else {
- crate::bail!("TODO Implement the kernel calling {}", B::KERNEL);
- }
- command_buffer.commit();
- command_buffer.wait_until_completed();
-
- Ok(Self {
- buffer,
- device: device.clone(),
- dtype,
- })
- }
-
- fn binary_impl<B: BinaryOpT>(
- &self,
- rhs: &Self,
- lhs_l: &Layout,
- rhs_l: &Layout,
- ) -> Result<Self> {
- let device = self.device();
- let dtype = self.dtype;
- let shape = lhs_l.shape();
- let el_count = shape.elem_count();
- let mut buffer = device.new_buffer(el_count, dtype);
- let command_buffer = device.command_queue.new_command_buffer();
- if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
- && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
- {
- use candle_metal_kernels::binary::contiguous;
-
- let kernel_name = match (B::KERNEL, dtype) {
- ("add", DType::F32) => contiguous::add::FLOAT,
- ("badd", DType::F32) => contiguous::add::FLOAT,
- ("sub", DType::F32) => contiguous::sub::FLOAT,
- ("bsub", DType::F32) => contiguous::sub::FLOAT,
- ("mul", DType::F32) => contiguous::mul::FLOAT,
- ("bmul", DType::F32) => contiguous::mul::FLOAT,
- ("div", DType::F32) => contiguous::div::FLOAT,
- ("bdiv", DType::F32) => contiguous::div::FLOAT,
- (name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
- };
- candle_metal_kernels::call_binary_contiguous(
- &device.device,
- &command_buffer,
- &device.kernels,
- kernel_name,
- el_count,
- &self.buffer,
- &rhs.buffer,
- &mut buffer,
+ &buffer,
)
.map_err(MetalError::from)?;
} else {
- use candle_metal_kernels::binary::strided;
-
+ use candle_metal_kernels::unary::strided;
let kernel_name = match (B::KERNEL, dtype) {
- ("badd", DType::F32) => strided::add::FLOAT,
- ("bsub", DType::F32) => strided::sub::FLOAT,
- ("bmul", DType::F32) => strided::mul::FLOAT,
- ("bdiv", DType::F32) => strided::div::FLOAT,
+ ("ucos", DType::F32) => strided::cos::FLOAT,
+ ("usin", DType::F32) => strided::sin::FLOAT,
+ ("usqr", DType::F32) => strided::sqr::FLOAT,
+ ("usqrt", DType::F32) => strided::sqrt::FLOAT,
+ ("uneg", DType::F32) => strided::neg::FLOAT,
+ ("uexp", DType::F32) => strided::exp::FLOAT,
+ ("ulog", DType::F32) => strided::log::FLOAT,
+ ("ugelu", DType::F32) => strided::gelu::FLOAT,
+ ("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
+ ("uerf", DType::F32) => strided::erf::FLOAT,
+ ("uceil", DType::F32) => strided::ceil::FLOAT,
+ ("ufloor", DType::F32) => strided::floor::FLOAT,
+ ("uround", DType::F32) => strided::round::FLOAT,
+ ("ucos", DType::F16) => strided::cos::HALF,
+ ("usin", DType::F16) => strided::sin::HALF,
+ ("usqr", DType::F16) => strided::sqr::HALF,
+ ("usqrt", DType::F16) => strided::sqrt::HALF,
+ ("uneg", DType::F16) => strided::neg::HALF,
+ ("uexp", DType::F16) => strided::exp::HALF,
+ ("ulog", DType::F16) => strided::log::HALF,
+ ("ugelu", DType::F16) => strided::gelu::HALF,
+ ("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
+ ("uerf", DType::F16) => strided::erf::HALF,
+ ("uceil", DType::F16) => strided::ceil::HALF,
+ ("ufloor", DType::F16) => strided::floor::HALF,
+ ("uround", DType::F16) => strided::round::HALF,
(name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
};
- candle_metal_kernels::call_binary_strided(
+ candle_metal_kernels::call_unary_strided(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
- lhs_l.dims(),
+ layout.dims(),
&self.buffer,
- &lhs_l.stride(),
- lhs_l.start_offset() * self.dtype.size_in_bytes(),
- &rhs.buffer,
- &rhs_l.stride(),
- rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
- &mut buffer,
+ layout.stride(),
+ layout.start_offset() * self.dtype.size_in_bytes(),
+ &buffer,
+ 0,
)
.map_err(MetalError::from)?;
}
- command_buffer.commit();
- command_buffer.wait_until_completed();
+ Ok(Self::new(buffer, device.clone(), dtype))
+ }
- Ok(Self {
- buffer,
- device: device.clone(),
- dtype,
- })
+ fn binary_impl<B: BinaryOpT>(
+ &self,
+ rhs: &Self,
+ lhs_l: &Layout,
+ rhs_l: &Layout,
+ ) -> Result<Self> {
+ self.binary(B::KERNEL, rhs, lhs_l, rhs_l)
}
fn where_cond(
@@ -398,14 +745,26 @@ impl BackendStorage for MetalStorage {
let dims = shape.dims();
let el = shape.elem_count();
let dtype = t.dtype;
- let mut buffer = self.device.new_buffer(el, dtype);
- let command_buffer = self.device.command_queue.new_command_buffer();
+ let buffer = self.device.new_buffer(el, dtype, "where")?;
+ let command_buffer = self.device.command_buffer()?;
+ if t.dtype() != f.dtype() {
+ crate::bail!(
+ "Invalid where: different dtypes for values {:?} != {:?}",
+ t.dtype(),
+ f.dtype()
+ );
+ }
+ let name = match (self.dtype, t.dtype()) {
+ (DType::U8, DType::F32) => "where_u8_f32",
+ (DType::U8, DType::F16) => "where_u8_f16",
+ (left, right) => crate::bail!("where {left:?} - {right:?} not implemented"),
+ };
candle_metal_kernels::call_where_cond_strided(
&device.device,
&command_buffer,
&device.kernels,
- "where_u8_f32",
- &dims,
+ name,
+ dims,
&self.buffer,
(
layout.stride(),
@@ -415,16 +774,10 @@ impl BackendStorage for MetalStorage {
(&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
&f.buffer,
(&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
- &mut buffer,
+ &buffer,
)
.map_err(MetalError::from)?;
- command_buffer.commit();
- command_buffer.wait_until_completed();
- Ok(Self {
- buffer,
- device,
- dtype,
- })
+ Ok(Self::new(buffer, device, dtype))
}
fn conv1d(
@@ -483,20 +836,84 @@ impl BackendStorage for MetalStorage {
crate::bail!("upsample_nearest2d metal")
}
- fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
- crate::bail!("gather metal")
+ fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
+ let (ids_o1, _) = match ids_l.contiguous_offsets() {
+ Some(o12) => o12,
+ None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?,
+ };
+ let ids_el = ids_l.dims()[dim];
+ let dst_el = ids_l.shape().elem_count();
+ let dtype = self.dtype;
+ let device = self.device();
+ let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
+ let name = match (ids.dtype, self.dtype) {
+ (DType::U32, DType::F32) => "gather_u32_f32",
+ (DType::U32, DType::F16) => "gather_u32_f16",
+ (left, right) => crate::bail!("gather metal {left:?} {right:?} not implemented"),
+ };
+ let command_buffer = self.device.command_buffer()?;
+ candle_metal_kernels::call_gather(
+ &device.device,
+ &command_buffer,
+ &self.device.kernels,
+ name,
+ src_l.dims(),
+ ids_el,
+ dim,
+ &self.buffer,
+ src_l.start_offset() * dtype.size_in_bytes(),
+ &ids.buffer,
+ ids_o1 * ids.dtype.size_in_bytes(),
+ &buffer,
+ )
+ .map_err(MetalError::from)?;
+ Ok(Self::new(buffer, device.clone(), dtype))
}
fn scatter_add(
&self,
- _: &Layout,
- _: &Self,
- _: &Layout,
- _: &Self,
- _: &Layout,
- _: usize,
+ l: &Layout,
+ ids: &Self,
+ ids_l: &Layout,
+ src: &Self,
+ src_l: &Layout,
+ dim: usize,
) -> Result<Self> {
- crate::bail!("scatter_add metal")
+ let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
+ self.copy_strided_src(&mut acc, 0, l)?;
+ let (ids_offset, _) = match ids_l.contiguous_offsets() {
+ Some(o12) => o12,
+ None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
+ };
+ let src_offset = match src_l.contiguous_offsets() {
+ Some((o1, _)) => o1,
+ None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
+ };
+ let name = match (ids.dtype, self.dtype) {
+ (DType::U32, DType::F32) => "sa_u32_f32",
+ _ => Err(MetalError::UnexpectedDType {
+ msg: "scatter-add ids should be u8/u32/i64",
+ expected: DType::U32,
+ got: ids.dtype(),
+ })?,
+ };
+ let command_buffer = self.device.command_buffer()?;
+ candle_metal_kernels::call_scatter_add(
+ &self.device.device,
+ &command_buffer,
+ &self.device.kernels,
+ name,
+ src_l.dims(),
+ l.dims(),
+ dim,
+ &src.buffer,
+ src_offset * src.dtype.size_in_bytes(),
+ &ids.buffer,
+ ids_offset * ids.dtype.size_in_bytes(),
+ &acc.buffer,
+ )
+ .map_err(MetalError::from)?;
+ Ok(acc)
}
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
@@ -513,12 +930,13 @@ impl BackendStorage for MetalStorage {
let dst_el = ids_el * left_size * right_size;
let dtype = self.dtype;
let device = self.device();
- let mut buffer = device.new_buffer(dst_el, dtype);
+ let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
let name = match (ids.dtype, self.dtype) {
(DType::U32, DType::F32) => "is_u32_f32",
+ (DType::U32, DType::F16) => "is_u32_f16",
(left, right) => crate::bail!("index select metal {left:?} {right:?}"),
};
- let command_buffer = self.device.command_queue.new_command_buffer();
+ let command_buffer = self.device.command_buffer()?;
candle_metal_kernels::call_index_select(
&device.device,
&command_buffer,
@@ -529,30 +947,58 @@ impl BackendStorage for MetalStorage {
dim,
&self.buffer,
&ids.buffer,
- &mut buffer,
+ &buffer,
)
.map_err(MetalError::from)?;
- command_buffer.commit();
- command_buffer.wait_until_completed();
- Ok(Self {
- buffer,
- device: device.clone(),
- dtype,
- })
+ Ok(Self::new(buffer, device.clone(), dtype))
}
fn index_add(
&self,
- _: &Layout,
- _: &Self,
- _: &Layout,
- _: &Self,
- _: &Layout,
- _: usize,
+ l: &Layout,
+ ids: &Self,
+ ids_l: &Layout,
+ src: &Self,
+ src_l: &Layout,
+ dim: usize,
) -> Result<Self> {
- crate::bail!("index_add metal")
+ let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
+ self.copy_strided_src(&mut acc, 0, l)?;
+ let (ids_offset, _) = match ids_l.contiguous_offsets() {
+ Some(o12) => o12,
+ None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
+ };
+ let src_offset = match src_l.contiguous_offsets() {
+ Some((o1, _)) => o1,
+ None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
+ };
+ let name = match (ids.dtype, self.dtype) {
+ (DType::U32, DType::F32) => "ia_u32_f32",
+ _ => Err(MetalError::UnexpectedDType {
+ msg: "index-add ids should be u8/u32/i64",
+ expected: DType::U32,
+ got: ids.dtype(),
+ })?,
+ };
+ let command_buffer = self.device.command_buffer()?;
+ candle_metal_kernels::call_index_add(
+ &self.device.device,
+ &command_buffer,
+ &self.device.kernels,
+ name,
+ src_l.dims(),
+ l.dims(),
+ ids_l.dims(),
+ dim,
+ &src.buffer,
+ src_offset * src.dtype.size_in_bytes(),
+ &ids.buffer,
+ ids_offset * ids.dtype.size_in_bytes(),
+ &acc.buffer,
+ )
+ .map_err(MetalError::from)?;
+ Ok(acc)
}
-
fn matmul(
&self,
rhs: &Self,
@@ -560,147 +1006,81 @@ impl BackendStorage for MetalStorage {
lhs_l: &Layout,
rhs_l: &Layout,
) -> Result<Self> {
- // Create descriptors
- use metal::mps::matrix::*;
- let type_id = metal::mps::MPS_FLOATBIT_ENCODING | 32;
- let size = core::mem::size_of::<f32>() as NSUInteger;
-
- let elem_count = b * m * n;
-
- let lhs_stride = lhs_l.stride();
- let rhs_stride = rhs_l.stride();
- let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
- let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
- let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
- let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
- // The a tensor has dims batching, k, n (rhs)
- let transpose_left = if lhs_m1 == 1 && lhs_m2 == k {
- false
- } else if lhs_m1 == m && lhs_m2 == 1 {
- true
- } else {
- Err(MetalError::MatMulNonContiguous {
- lhs_stride: lhs_stride.to_vec(),
- rhs_stride: rhs_stride.to_vec(),
- mnk: (m, n, k),
- })?
- };
- let transpose_right = if rhs_m1 == 1 && rhs_m2 == n {
- false
- } else if rhs_m1 == k && rhs_m2 == 1 {
- true
- } else {
- Err(MetalError::MatMulNonContiguous {
- lhs_stride: lhs_stride.to_vec(),
- rhs_stride: rhs_stride.to_vec(),
- mnk: (m, n, k),
- })?
- };
-
- let b = b as NSUInteger;
- let m = m as NSUInteger;
- let n = n as NSUInteger;
- let k = k as NSUInteger;
-
- let left_descriptor = if transpose_left {
- MatrixDescriptor::init_single(k, m, m * size, type_id)
- } else {
- MatrixDescriptor::init_single(m, k, k * size, type_id)
- };
- let right_descriptor = if transpose_right {
- MatrixDescriptor::init_single(n, k, k * size, type_id)
- } else {
- MatrixDescriptor::init_single(k, n, n * size, type_id)
+ let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
+ let name = match self.dtype {
+ DType::F32 => "sgemm",
+ DType::F16 => "hgemm",
+ dtype => {
+ return Err(MetalError::Message(format!("matmul doesn't support {dtype:?}")).into())
+ }
};
- let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
-
- // Create matrix objects
- let left_matrix = Matrix::init_with_buffer_descriptor(&self.buffer, 0, &left_descriptor)
- .ok_or_else(|| {
- MetalError::from("Failed to create matrix multiplication kernel".to_string())
- })?;
- let right_matrix = Matrix::init_with_buffer_descriptor(&rhs.buffer, 0, &right_descriptor)
- .ok_or_else(|| {
- MetalError::from("Failed to create matrix multiplication kernel".to_string())
- })?;
-
- let out_buffer = self.device.new_buffer(elem_count, self.dtype);
- let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, 0, &result_descriptor)
- .ok_or_else(|| {
- MetalError::from("Failed to create matrix multiplication kernel".to_string())
- })?;
-
- let alpha = 1.0f64;
- let beta = 0.0f64;
- // Create kernel
- let matrix_multiplication = MatrixMultiplication::init(
- &self.device,
- transpose_left,
- transpose_right,
- m,
- n,
- k,
- alpha,
- beta,
- )
- .ok_or_else(|| {
- MetalError::from("Failed to create matrix multiplication kernel".to_string())
- })?;
-
- matrix_multiplication.set_batch_size(b);
-
- // Encode kernel to command buffer
- let command_buffer = self.device.command_queue.new_command_buffer();
- matrix_multiplication.encode_to_command_buffer(
- command_buffer,
- &left_matrix,
- &right_matrix,
- &result_matrix,
- );
- command_buffer.commit();
- command_buffer.wait_until_completed();
-
- Ok(Self {
- buffer: out_buffer,
- device: self.device.clone(),
- dtype: self.dtype(),
- })
- }
- fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
- let src_shape = src_l.shape();
- let el_count = src_shape.elem_count();
- if el_count == 0 {
- return Ok(());
- }
- let command_buffer = self.device.command_queue.new_command_buffer();
- let kernel_name = match self.dtype {
- DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
- DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
- DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
- dtype => crate::bail!("copy_strided not implemented for {dtype:?}"),
- };
- candle_metal_kernels::call_unary_strided(
+ let command_buffer = self.device.command_buffer()?;
+ command_buffer.set_label("matmul");
+ candle_metal_kernels::call_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
- kernel_name,
- src_l.dims(),
+ name,
+ (b, m, n, k),
+ lhs_l.stride(),
+ lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
- &src_l.stride(),
- src_l.start_offset() * self.dtype.size_in_bytes(),
- &mut dst.buffer,
- dst_offset,
+ rhs_l.stride(),
+ rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
+ &rhs.buffer,
+ &buffer,
)
.map_err(MetalError::from)?;
- command_buffer.commit();
- command_buffer.wait_until_completed();
+ Ok(Self::new(buffer, self.device.clone(), self.dtype()))
+ }
+
+ fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
+ let command_buffer = self.device.command_buffer()?;
+ if src_l.is_contiguous() && self.dtype == dst.dtype() {
+ command_buffer.set_label("copy_contiguous");
+ let blit = command_buffer.new_blit_command_encoder();
+ blit.set_label("copy_contiguous");
+ let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger;
+ let length = (src_l.shape().elem_count() * self.dtype.size_in_bytes()) as NSUInteger;
+ let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger;
+ blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length);
+ blit.end_encoding();
+ } else {
+ let src_shape = src_l.shape();
+ let el_count = src_shape.elem_count();
+ if el_count == 0 {
+ return Ok(());
+ }
+ let kernel_name = match self.dtype {
+ DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
+ DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
+ DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
+ DType::U32 => candle_metal_kernels::unary::strided::copy::U32,
+ DType::U8 => candle_metal_kernels::unary::strided::copy::U8,
+ dtype => crate::bail!("copy_strided not implemented for {dtype:?}"),
+ };
+ candle_metal_kernels::call_unary_strided(
+ &self.device.device,
+ &command_buffer,
+ &self.device.kernels,
+ kernel_name,
+ src_l.dims(),
+ &self.buffer,
+ src_l.stride(),
+ src_l.start_offset() * self.dtype.size_in_bytes(),
+ &dst.buffer,
+ dst_offset * dst.dtype.size_in_bytes(),
+ )
+ .map_err(MetalError::from)?;
+ command_buffer.set_label("copy_strided");
+ }
Ok(())
}
}
impl MetalStorage {
- pub fn new(buffer: Buffer, device: MetalDevice, dtype: DType) -> Self {
+ pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: DType) -> Self {
Self {
buffer,
device,
@@ -711,6 +1091,111 @@ impl MetalStorage {
pub fn buffer(&self) -> &Buffer {
&self.buffer
}
+
+ pub fn binary(
+ &self,
+ op: &'static str,
+ rhs: &Self,
+ lhs_l: &Layout,
+ rhs_l: &Layout,
+ ) -> Result<Self> {
+ let device = self.device();
+ let shape = lhs_l.shape();
+ let el_count = shape.elem_count();
+ let command_buffer = device.command_buffer()?;
+ let (buffer, dtype) = if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
+ && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
+ && &op[..1] != "b"
+ {
+ use candle_metal_kernels::binary::contiguous;
+
+ let (kernel_name, dtype) = match (op, self.dtype) {
+ ("add", DType::F32) => (contiguous::add::FLOAT, self.dtype),
+ ("sub", DType::F32) => (contiguous::sub::FLOAT, self.dtype),
+ ("mul", DType::F32) => (contiguous::mul::FLOAT, self.dtype),
+ ("div", DType::F32) => (contiguous::div::FLOAT, self.dtype),
+ ("eq", DType::F32) => (contiguous::eq::FLOAT, DType::U8),
+ ("ne", DType::F32) => (contiguous::ne::FLOAT, DType::U8),
+ ("le", DType::F32) => (contiguous::le::FLOAT, DType::U8),
+ ("lt", DType::F32) => (contiguous::lt::FLOAT, DType::U8),
+ ("ge", DType::F32) => (contiguous::ge::FLOAT, DType::U8),
+ ("gt", DType::F32) => (contiguous::gt::FLOAT, DType::U8),
+ ("add", DType::F16) => (contiguous::add::HALF, self.dtype),
+ ("sub", DType::F16) => (contiguous::sub::HALF, self.dtype),
+ ("mul", DType::F16) => (contiguous::mul::HALF, self.dtype),
+ ("div", DType::F16) => (contiguous::div::HALF, self.dtype),
+ ("eq", DType::F16) => (contiguous::eq::HALF, DType::U8),
+ ("ne", DType::F16) => (contiguous::ne::HALF, DType::U8),
+ ("le", DType::F16) => (contiguous::le::HALF, DType::U8),
+ ("lt", DType::F16) => (contiguous::lt::HALF, DType::U8),
+ ("ge", DType::F16) => (contiguous::ge::HALF, DType::U8),
+ ("gt", DType::F16) => (contiguous::gt::HALF, DType::U8),
+ (name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"),
+ };
+ let buffer = device.new_buffer(el_count, dtype, op)?;
+ candle_metal_kernels::call_binary_contiguous(
+ &device.device,
+ &command_buffer,
+ &device.kernels,
+ kernel_name,
+ el_count,
+ &self.buffer,
+ &rhs.buffer,
+ &buffer,
+ )
+ .map_err(MetalError::from)?;
+ (buffer, dtype)
+ } else {
+ use candle_metal_kernels::binary::strided;
+
+ let (kernel_name, dtype) = match (op, self.dtype) {
+ ("badd", DType::F32) => (strided::add::FLOAT, self.dtype),
+ ("bsub", DType::F32) => (strided::sub::FLOAT, self.dtype),
+ ("bmul", DType::F32) => (strided::mul::FLOAT, self.dtype),
+ ("bdiv", DType::F32) => (strided::div::FLOAT, self.dtype),
+ ("bminimum", DType::F32) => (strided::min::FLOAT, self.dtype),
+ ("bmaximum", DType::F32) => (strided::max::FLOAT, self.dtype),
+ ("eq", DType::F32) => (strided::eq::FLOAT, DType::U8),
+ ("ne", DType::F32) => (strided::ne::FLOAT, DType::U8),
+ ("le", DType::F32) => (strided::le::FLOAT, DType::U8),
+ ("lt", DType::F32) => (strided::lt::FLOAT, DType::U8),
+ ("ge", DType::F32) => (strided::ge::FLOAT, DType::U8),
+ ("gt", DType::F32) => (strided::gt::FLOAT, DType::U8),
+ ("badd", DType::F16) => (strided::add::HALF, self.dtype),
+ ("bsub", DType::F16) => (strided::sub::HALF, self.dtype),
+ ("bmul", DType::F16) => (strided::mul::HALF, self.dtype),
+ ("bdiv", DType::F16) => (strided::div::HALF, self.dtype),
+ ("bminimum", DType::F16) => (strided::min::HALF, self.dtype),
+ ("bmaximum", DType::F16) => (strided::max::HALF, self.dtype),
+ ("eq", DType::F16) => (strided::eq::HALF, DType::U8),
+ ("ne", DType::F16) => (strided::ne::HALF, DType::U8),
+ ("le", DType::F16) => (strided::le::HALF, DType::U8),
+ ("lt", DType::F16) => (strided::lt::HALF, DType::U8),
+ ("ge", DType::F16) => (strided::ge::HALF, DType::U8),
+ ("gt", DType::F16) => (strided::gt::HALF, DType::U8),
+ (name, dtype) => crate::bail!("Binary strided {name} - {dtype:?} not implemented"),
+ };
+ let buffer = device.new_buffer(el_count, dtype, op)?;
+ candle_metal_kernels::call_binary_strided(
+ &device.device,
+ &command_buffer,
+ &device.kernels,
+ kernel_name,
+ lhs_l.dims(),
+ &self.buffer,
+ lhs_l.stride(),
+ lhs_l.start_offset() * self.dtype.size_in_bytes(),
+ &rhs.buffer,
+ rhs_l.stride(),
+ rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
+ &buffer,
+ )
+ .map_err(MetalError::from)?;
+ (buffer, dtype)
+ };
+ command_buffer.set_label("binary");
+ Ok(Self::new(buffer, device.clone(), dtype))
+ }
}
impl BackendDevice for MetalDevice {
@@ -718,12 +1203,26 @@ impl BackendDevice for MetalDevice {
fn new(ordinal: usize) -> Result<Self> {
let device = metal::Device::all().swap_remove(ordinal);
-
let command_queue = device.new_command_queue();
- let kernels = Arc::new(Kernels::new());
+ let command_buffer = command_queue.new_command_buffer().to_owned();
+ command_buffer.enqueue();
+ let command_buffer = Arc::new(RwLock::new(command_buffer));
+ let command_buffer_index = Arc::new(RwLock::new(0));
+ let fence = device.new_fence();
+ let kernels = Arc::new(Kernels::new(fence.clone()));
+ let buffers = Arc::new(RwLock::new(HashMap::new()));
+ let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
+ Ok(val) => val.parse()?,
+ _ => 20,
+ };
Ok(Self {
device,
+ fence,
command_queue,
+ command_buffer,
+ command_buffer_index,
+ compute_per_buffer,
+ buffers,
kernels,
})
}
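A hypothetical way to exercise this knob from user code, assuming the variable is set before the Metal device is created (the value 50 is an arbitrary example; the default above is 20):

use candle_core::Device;

fn main() -> candle_core::Result<()> {
    // Batch more compute encoders per command buffer than the default.
    std::env::set_var("CANDLE_METAL_COMPUTE_PER_BUFFER", "50");
    let _device = Device::new_metal(0)?;
    Ok(())
}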
@@ -743,9 +1242,22 @@ impl BackendDevice for MetalDevice {
}
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
- // TODO Is there a faster way ?
- let cpu_storage = crate::cpu_backend::CpuDevice.zeros_impl(shape, dtype)?;
- self.storage_from_cpu_storage(&cpu_storage)
+ let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros")?;
+ let command_buffer = self.command_buffer()?;
+ command_buffer.set_label("zeros");
+ let blit = command_buffer.new_blit_command_encoder();
+ blit.wait_for_fence(&self.fence);
+ blit.fill_buffer(
+ &buffer,
+ metal::NSRange {
+ location: 0,
+ length: buffer.length(),
+ },
+ 0,
+ );
+ blit.update_fence(&self.fence);
+ blit.end_encoding();
+ Ok(MetalStorage::new(buffer, self.clone(), dtype))
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
@@ -755,49 +1267,16 @@ impl BackendDevice for MetalDevice {
}
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
- let option = metal::MTLResourceOptions::StorageModeManaged;
let buffer = match storage {
- CpuStorage::U8(storage) => self.device.new_buffer_with_data(
- storage.as_ptr() as *const core::ffi::c_void,
- (storage.len() * mem::size_of::<u8>()) as NSUInteger,
- option,
- ),
- CpuStorage::U32(storage) => self.device.new_buffer_with_data(
- storage.as_ptr() as *const core::ffi::c_void,
- (storage.len() * mem::size_of::<u32>()) as NSUInteger,
- option,
- ),
- CpuStorage::I64(storage) => self.device.new_buffer_with_data(
- storage.as_ptr() as *const core::ffi::c_void,
- (storage.len() * mem::size_of::<i64>()) as NSUInteger,
- option,
- ),
- CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
- storage.as_ptr() as *const core::ffi::c_void,
- (storage.len() * mem::size_of::<bf16>()) as NSUInteger,
- option,
- ),
- CpuStorage::F16(storage) => self.device.new_buffer_with_data(
- storage.as_ptr() as *const core::ffi::c_void,
- (storage.len() * mem::size_of::<f16>()) as NSUInteger,
- option,
- ),
- CpuStorage::F32(storage) => self.device.new_buffer_with_data(
- storage.as_ptr() as *const core::ffi::c_void,
- (storage.len() * mem::size_of::<f32>()) as NSUInteger,
- option,
- ),
- CpuStorage::F64(storage) => self.device.new_buffer_with_data(
- storage.as_ptr() as *const core::ffi::c_void,
- (storage.len() * mem::size_of::<f64>()) as NSUInteger,
- option,
- ),
- };
- Ok(Self::Storage {
- buffer,
- device: self.clone(),
- dtype: storage.dtype(),
- })
+ CpuStorage::U8(storage) => self.new_buffer_with_data(storage),
+ CpuStorage::U32(storage) => self.new_buffer_with_data(storage),
+ CpuStorage::I64(storage) => self.new_buffer_with_data(storage),
+ CpuStorage::BF16(storage) => self.new_buffer_with_data(storage),
+ CpuStorage::F16(storage) => self.new_buffer_with_data(storage),
+ CpuStorage::F32(storage) => self.new_buffer_with_data(storage),
+ CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
+ }?;
+ Ok(Self::Storage::new(buffer, self.clone(), storage.dtype()))
}
fn rand_uniform(
@@ -824,3 +1303,10 @@ impl BackendDevice for MetalDevice {
self.storage_from_cpu_storage(&cpu_storage)
}
}
+
+fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
+ let ptr = buffer.contents() as *const T;
+ assert!(!ptr.is_null());
+ let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
+ slice.to_vec()
+}
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index e6e7b415..f15f8c1c 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1877,10 +1877,7 @@ impl Tensor {
Storage::Metal(metal.storage_from_cpu_storage(storage)?)
}
(Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
- (Storage::Metal(storage), Device::Cpu) => {
- println!("{storage:?} - {:?}", storage.to_cpu_storage()?);
- Storage::Cpu(storage.to_cpu_storage()?)
- }
+ (Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
(Storage::Cuda(storage), Device::Cuda(cuda)) => {
// TODO: Avoid passing through the cpu storage here, especially if the gpu ids
// are the same.
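A hedged end-to-end sketch of the path this cleanup covers, assuming a Metal-capable build: moving a tensor from the Metal device back to the CPU goes through MetalStorage::to_cpu_storage, i.e. the blit-to-managed-buffer plus wait_until_completed route added earlier in this diff.

use candle_core::{DType, Device, Tensor};

fn roundtrip() -> candle_core::Result<()> {
    let metal = Device::new_metal(0)?;
    let t = Tensor::zeros((2, 3), DType::F32, &metal)?;
    // Metal -> CPU transfer; no more debug printing on this path.
    let t_cpu = t.to_device(&Device::Cpu)?;
    println!("{t_cpu}");
    Ok(())
}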
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index 508f75f5..0c4bf20e 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -57,6 +57,7 @@ flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"]
+metal = ["candle/metal", "candle-nn/metal"]
[[example]]
name = "llama_multiprocess"
diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml
index c0e019f4..7ab45a90 100644
--- a/candle-metal-kernels/Cargo.toml
+++ b/candle-metal-kernels/Cargo.toml
@@ -10,7 +10,7 @@ categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
-metal = { version = "0.27.1", features = ["mps"], package="candle-metal" }
+metal = { version = "0.27.0", features = ["mps"]}
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"
diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal
index e5f0a841..4166d811 100644
--- a/candle-metal-kernels/src/affine.metal
+++ b/candle-metal-kernels/src/affine.metal
@@ -29,15 +29,96 @@ kernel void FN_NAME( \
if (id >= dim) { \
return; \
} \
- const TYPENAME m = TYPENAME(mul); \
- const TYPENAME a = TYPENAME(add); \
- output[id] = input[id] * m + a; \
+ output[id] = TYPENAME(float(input[id]) * mul + add); \
} \
+kernel void FN_NAME##_strided( \
+ constant size_t &dim, \
+ constant size_t &num_dims, \
+ constant size_t *dims, \
+ constant size_t *strides, \
+ constant float &mul, \
+ constant float &add, \
+ device const TYPENAME *input, \
+ device TYPENAME *output, \
+ uint id [[ thread_position_in_grid ]] \
+) { \
+ if (id >= dim) { \
+ return; \
+ } \
+ output[id] = TYPENAME(float(input[get_strided_index(id, num_dims, dims, strides)]) * mul + add); \
+}
+
+#define POWF(FN_NAME, TYPENAME) \
+kernel void FN_NAME( \
+ constant size_t &dim, \
+ constant float &mul, \
+ device const TYPENAME *input, \
+ device TYPENAME *output, \
+ uint id [[ thread_position_in_grid ]] \
+) { \
+ if (id >= dim) { \
+ return; \
+ } \
+ output[id] = TYPENAME(pow(input[id], TYPENAME(mul))); \
+} \
+kernel void FN_NAME##_strided( \
+ constant size_t &dim, \
+ constant size_t &num_dims, \
+ constant size_t *dims, \
+ constant size_t *strides, \
+ constant float &mul, \
+ device const TYPENAME *input, \
+ device TYPENAME *output, \
+ uint id [[ thread_position_in_grid ]] \
+) { \
+ if (id >= dim) { \
+ return; \
+ } \
+ output[id] = TYPENAME(pow(input[get_strided_index(id, num_dims, dims, strides)], TYPENAME(mul))); \
+}
+
+#define ELU(FN_NAME, TYPENAME) \
+kernel void FN_NAME( \
+ constant size_t &dim, \
+ constant float &mul, \
+ device const TYPENAME *input, \
+ device TYPENAME *output, \
+ uint id [[ thread_position_in_grid ]] \
+) { \
+ if (id >= dim) { \
+ return; \
+ } \
+ const TYPENAME x = input[id]; \
+    output[id] = TYPENAME((x > 0) ? x : mul * (exp(x) - 1)); \
+} \
+kernel void FN_NAME##_strided( \
+ constant size_t &dim, \
+ constant size_t &num_dims, \
+ constant size_t *dims, \
+ constant size_t *strides, \
+ constant float &mul, \
+ device const TYPENAME *input, \
+ device TYPENAME *output, \
+ uint id [[ thread_position_in_grid ]] \
+) { \
+ if (id >= dim) { \
+ return; \
+ } \
+ const TYPENAME x = input[get_strided_index(id, num_dims, dims, strides)]; \
+    output[id] = TYPENAME((x > 0) ? x : mul * (exp(x) - 1)); \
+} \
+
-AFFINE(affine_float, float)
-AFFINE(affine_half, half)
+AFFINE(affine_f32, float)
+AFFINE(affine_f16, half)
+POWF(powf_f32, float)
+POWF(powf_f16, half)
+ELU(elu_f32, float)
+ELU(elu_f16, half)
#if __METAL_VERSION__ >= 310
-AFFINE(affine_bfloat, bfloat);
+AFFINE(affine_bf16, bfloat);
+POWF(powf_bf16, bfloat);
+ELU(elu_bf16, bfloat);
#endif
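
The new `_strided` kernel variants above all rely on the `get_strided_index` helper defined near the top of these .metal files: the flat thread id is decoded against the logical dims and re-encoded through the source strides. A CPU-side Rust sketch of the same arithmetic, using the affine op as the example (names and values are illustrative):

    /// Map a contiguous element index to a storage offset for a strided layout.
    fn get_strided_index(mut idx: usize, dims: &[usize], strides: &[usize]) -> usize {
        let mut strided = 0;
        for (dim, stride) in dims.iter().zip(strides.iter()).rev() {
            strided += (idx % dim) * stride;
            idx /= dim;
        }
        strided
    }

    /// CPU reference of `affine_f32_strided`: y = x * mul + add over a strided view.
    fn affine_strided(input: &[f32], dims: &[usize], strides: &[usize], mul: f32, add: f32) -> Vec<f32> {
        let n: usize = dims.iter().product();
        (0..n)
            .map(|i| input[get_strided_index(i, dims, strides)] * mul + add)
            .collect()
    }

    fn main() {
        // A 2x3 logical tensor viewed over column-major storage (strides [1, 2]).
        let input = [0.0f32, 3.0, 1.0, 4.0, 2.0, 5.0];
        let out = affine_strided(&input, &[2, 3], &[1, 2], 2.0, 1.0);
        assert_eq!(out, vec![1.0, 3.0, 5.0, 7.0, 9.0, 11.0]);
    }
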
diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal
index f18cdbb0..8c3b4a8c 100644
--- a/candle-metal-kernels/src/binary.metal
+++ b/candle-metal-kernels/src/binary.metal
@@ -1,5 +1,8 @@
#include <metal_stdlib>
+#define MAX(x, y) ((x) > (y) ? (x) : (y))
+#define MIN(x, y) ((x) < (y) ? (x) : (y))
+
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
@@ -22,15 +25,15 @@ kernel void FN_NAME( \
constant size_t &dim, \
device const TYPENAME *left, \
device const TYPENAME *right, \
- device TYPENAME *output, \
- uint thread_position_in_grid [[ thread_position_in_grid ]] \
+ device OUT_TYPENAME *output, \
+ uint tid [[ thread_position_in_grid ]] \
) { \
- if (thread_position_in_grid >= dim) { \
+ if (tid >= dim) { \
return; \
} \
- TYPENAME x = left[thread_position_in_grid]; \
- TYPENAME y = right[thread_position_in_grid]; \
- output[thread_position_in_grid] = OUT_TYPENAME(FN); \
+ TYPENAME x = left[tid]; \
+ TYPENAME y = right[tid]; \
+ output[tid] = OUT_TYPENAME(FN); \
}\
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
@@ -40,33 +43,48 @@ kernel void FN_NAME_STRIDED( \
constant size_t *right_strides, \
device const TYPENAME *left, \
device const TYPENAME *right, \
- device TYPENAME *output, \
- uint thread_position_in_grid [[ thread_position_in_grid ]] \
+ device OUT_TYPENAME *output, \
+ uint tid [[ thread_position_in_grid ]] \
) { \
- if (thread_position_in_grid >= dim) { \
+ if (tid >= dim) { \
return; \
} \
- TYPENAME x = left[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \
- TYPENAME y = right[get_strided_index(thread_position_in_grid, num_dims, dims, right_strides)]; \
- output[thread_position_in_grid] = OUT_TYPENAME(FN); \
+ TYPENAME x = left[get_strided_index(tid, num_dims, dims, left_strides)]; \
+ TYPENAME y = right[get_strided_index(tid, num_dims, dims, right_strides)]; \
+ output[tid] = OUT_TYPENAME(FN); \
}
#define BINARY_OP(FN, NAME) \
-BINARY(FN, float, float, NAME##_float, NAME##_float_strided); \
-BINARY(FN, half, half, NAME##_half, NAME##_half_strided);
+BINARY(FN, float, float, NAME##_f32, NAME##_f32_strided); \
+BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided);
#define BFLOAT_BINARY_OP(FN, NAME) \
-BINARY(FN, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
+BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);
+
+#define BINARY_OP_OUT(NAME, FN) \
+BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \
+BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided);
BINARY_OP(x + y, add)
BINARY_OP(x - y, sub)
BINARY_OP(x * y, mul)
BINARY_OP(x / y, div)
+BINARY_OP(MIN(x, y), min)
+BINARY_OP(MAX(x, y), max)
+
+BINARY_OP_OUT(eq, x == y)
+BINARY_OP_OUT(ne, x != y)
+BINARY_OP_OUT(le, x <= y)
+BINARY_OP_OUT(lt, x < y)
+BINARY_OP_OUT(ge, x >= y)
+BINARY_OP_OUT(gt, x > y)
#if __METAL_VERSION__ >= 310
BFLOAT_BINARY_OP(x + y, add)
BFLOAT_BINARY_OP(x - y, sub)
BFLOAT_BINARY_OP(x * y, mul)
BFLOAT_BINARY_OP(x / y, div)
+BFLOAT_BINARY_OP(MIN(x, y), min)
+BFLOAT_BINARY_OP(MAX(x, y), max)
#endif
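
The new `BINARY_OP_OUT` macro gives the comparison kernels (eq, ne, le, lt, ge, gt) a `uint8_t` output buffer while the inputs stay float/half, which is how candle represents boolean masks. A CPU-side sketch of that contract (the function name is illustrative):

    /// CPU reference of e.g. `le_f32`: elementwise comparison producing a u8 mask.
    fn le_f32(left: &[f32], right: &[f32]) -> Vec<u8> {
        left.iter()
            .zip(right.iter())
            .map(|(x, y)| u8::from(x <= y))
            .collect()
    }

    fn main() {
        assert_eq!(le_f32(&[1.0, 2.0, 3.0], &[2.0, 2.0, 2.0]), vec![1, 1, 0]);
    }
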
diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal
index d1788253..8481389d 100644
--- a/candle-metal-kernels/src/cast.metal
+++ b/candle-metal-kernels/src/cast.metal
@@ -23,12 +23,12 @@ kernel void FN_NAME( \
constant size_t &dim, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
- uint thread_position_in_grid [[ thread_position_in_grid ]] \
+ uint tid [[ thread_position_in_grid ]] \
) { \
- if (thread_position_in_grid >= dim) { \
+ if (tid >= dim) { \
return; \
} \
- output[thread_position_in_grid] = RIGHT_TYPENAME(input[thread_position_in_grid]); \
+ output[tid] = RIGHT_TYPENAME(input[tid]); \
} \
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
@@ -37,15 +37,20 @@ kernel void FN_NAME_STRIDED( \
constant size_t *strides, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
- uint i [[ thread_position_in_grid ]] \
+ uint tid [[ thread_position_in_grid ]] \
) { \
- if (i >= dim) { \
+ if (tid >= dim) { \
return; \
} \
- output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
+ output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \
} \
-CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float)
+CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
+CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
+CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
+CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
+CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
+CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
#if __METAL_VERSION__ >= 310
#endif
diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal
index 444fa322..63357428 100644
--- a/candle-metal-kernels/src/indexing.metal
+++ b/candle-metal-kernels/src/indexing.metal
@@ -1,6 +1,34 @@
#include <metal_stdlib>
using namespace metal;
+template<typename TYPENAME, typename INDEX_TYPENAME>
+METAL_FUNC void index(
+ constant size_t &dst_size,
+ constant size_t &left_size,
+ constant size_t &src_dim_size,
+ constant size_t &right_size,
+ constant size_t &ids_size,
+ const device TYPENAME *input,
+ const device INDEX_TYPENAME *input_ids,
+ device TYPENAME *output,
+ uint tid [[ thread_position_in_grid ]]
+) {
+ if (tid >= dst_size) {
+ return;
+ }
+ const size_t id_i = (tid / right_size) % ids_size;
+ const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1));
+ const size_t right_rank_i = tid % right_size;
+ const size_t left_rank_i = tid / right_size / ids_size;
+    /*
+    // The index is clamped above to forcibly prevent out-of-bounds access,
+    // since there doesn't seem to be a good way to force a crash from a kernel.
+    // No need to check for negative values: only unsigned index types are allowed.
+    */
+ const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i;
+ output[tid] = input[src_i];
+}
+
# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
constant size_t &dst_size, \
@@ -11,92 +39,160 @@ kernel void NAME( \
const device TYPENAME *input, \
const device INDEX_TYPENAME *input_ids, \
device TYPENAME *output, \
- uint gid [[ thread_position_in_grid ]] \
+ uint tid [[ thread_position_in_grid ]] \
) { \
- if (gid >= dst_size) { \
- return; \
- } \
- const size_t id_i = gid / right_size / left_size; \
- const size_t right_rank_i = gid % right_size; \
- const size_t left_rank_i = gid % left_size; \
- /* \
- // Force prevent out of bounds indexing \
- // since there doesn't seem to be a good way to force crash \
- // No need to check for zero we're only allowing unsized. \
- */ \
- const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
- const size_t src_i = ((input_i * right_size) + right_rank_i) * left_size + left_rank_i; \
- output[gid] = input[src_i]; \
+ index<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \
}
+template<typename TYPENAME, typename INDEX_TYPENAME>
+METAL_FUNC void gather(
+ constant size_t &dst_size,
+ constant size_t &left_size,
+ constant size_t &src_dim_size,
+ constant size_t &right_size,
+ constant size_t &ids_size,
+ const device TYPENAME *input,
+ const device INDEX_TYPENAME *input_ids,
+ device TYPENAME *output,
+ uint tid [[ thread_position_in_grid ]]
+) {
+ if (tid >= dst_size) {
+ return;
+ }
+ const INDEX_TYPENAME input_i = input_ids[tid];
+ const size_t right_rank_i = tid % right_size;
+ const size_t left_rank_i = tid / right_size / ids_size;
+ const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i;
+ output[tid] = input[src_i];
+}
-template <typename T, typename I>
-void index_add(
- device I *ids [[buffer(0)]],
- device T *inp [[buffer(1)]],
- device T *out [[buffer(2)]],
-
- constant uint &ids_dim_size,
- constant uint &left_size,
- constant uint &dst_dim_size,
- constant uint &right_size,
-
- uint gid [[ thread_position_in_grid ]] \
-) {
+# define GATHER_OP(NAME, INDEX_TYPENAME, TYPENAME) \
+kernel void NAME( \
+ constant size_t &dst_size, \
+ constant size_t &left_size, \
+ constant size_t &src_dim_size, \
+ constant size_t &right_size, \
+ constant size_t &ids_size, \
+ const device TYPENAME *input, \
+ const device INDEX_TYPENAME *input_ids, \
+ device TYPENAME *output, \
+ uint tid [[ thread_position_in_grid ]] \
+) { \
+ gather<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \
+}
- if (gid >= left_size * right_size) {
- return;
+template<typename TYPENAME, typename INDEX_TYPENAME>
+METAL_FUNC void scatter_add(
+ constant size_t &dst_size,
+ constant size_t &left_size,
+ constant size_t &src_dim_size,
+ constant size_t &right_size,
+ constant size_t &dst_dim_size,
+ const device TYPENAME *input,
+ const device INDEX_TYPENAME *input_ids,
+ device TYPENAME *output,
+ uint tid [[ thread_position_in_grid ]]
+) {
+ if (tid >= dst_size) {
+ return;
+ }
+ const size_t right_rank_i = tid % right_size;
+ const size_t left_rank_i = tid / right_size;
+ for (unsigned int j = 0; j < src_dim_size; ++j) {
+ const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
+ const INDEX_TYPENAME idx = input_ids[src_i];
+ const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
+ output[dst_i] += input[src_i];
}
+}
- const uint i = gid;
- const uint pre = i / right_size;
- const uint post = i % right_size;
+# define SCATTER_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \
+kernel void NAME( \
+ constant size_t &dst_size, \
+ constant size_t &left_size, \
+ constant size_t &src_dim_size, \
+ constant size_t &right_size, \
+ constant size_t &dst_dim_size, \
+ const device TYPENAME *input, \
+ const device INDEX_TYPENAME *input_ids, \
+ device TYPENAME *output, \
+ uint tid [[ thread_position_in_grid ]] \
+) { \
+ scatter_add<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, dst_dim_size, input, input_ids, output, tid); \
+}
- for (uint j = 0; j < ids_dim_size; j++) {
- const uint idx = ids[j];
- const uint src_i = (pre * ids_dim_size + j) * right_size + post;
- const uint dst_i = (pre * dst_dim_size + idx) * right_size + post;
- out[dst_i] += inp[src_i];
+template<typename TYPENAME, typename INDEX_TYPENAME>
+METAL_FUNC void index_add(
+ constant size_t &dst_size,
+ constant size_t &left_size,
+ constant size_t &src_dim_size,
+ constant size_t &right_size,
+ constant size_t &dst_dim_size,
+ constant size_t &ids_dim_size,
+ const device TYPENAME *input,
+ const device INDEX_TYPENAME *input_ids,
+ device TYPENAME *output,
+ uint tid [[ thread_position_in_grid ]]
+) {
+ if (tid >= dst_size) {
+ return;
+ }
+ const size_t right_rank_i = tid % right_size;
+ const size_t left_rank_i = tid / right_size;
+ for (unsigned int j = 0; j < ids_dim_size; ++j) {
+ const INDEX_TYPENAME idx = input_ids[j];
+ const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
+ const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
+ output[dst_i] += input[src_i];
}
}
-#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
-kernel void FN_NAME( \
- device INDEX_TYPENAME *ids [[buffer(0)]], \
- device TYPENAME *inp [[buffer(1)]], \
- device TYPENAME *out [[buffer(2)]], \
- constant uint &ids_dim_size, \
- constant uint &left_size, \
- constant uint &dst_dim_size, \
- constant uint &right_size, \
- uint gid [[ thread_position_in_grid ]] \
-) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, gid); } \
+# define INDEX_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \
+kernel void NAME( \
+ constant size_t &dst_size, \
+ constant size_t &left_size, \
+ constant size_t &src_dim_size, \
+ constant size_t &right_size, \
+ constant size_t &dst_dim_size, \
+ constant size_t &ids_dim_size, \
+ const device TYPENAME *input, \
+ const device INDEX_TYPENAME *input_ids, \
+ device TYPENAME *output, \
+ uint tid [[ thread_position_in_grid ]] \
+) { \
+ index_add<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, dst_dim_size, ids_dim_size, input, input_ids, output, tid); \
+}
INDEX_OP(is_u32_f32, uint, float)
+INDEX_OP(is_u32_f16, uint, half)
+GATHER_OP(gather_u32_f32, uint, float)
+GATHER_OP(gather_u32_f16, uint, half)
+SCATTER_ADD_OP(sa_u32_f32, uint, float)
+SCATTER_ADD_OP(sa_u32_f16, uint, half)
#if __METAL_VERSION__ >= 310
-IA_OP(bfloat, int64_t, ia_i64_bf16)
-IA_OP(bfloat, uint32_t, ia_u32_bf16)
-IA_OP(bfloat, uint8_t, ia_u8_bf16)
+INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
+INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
+INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
#endif
-IA_OP(half, uint32_t, ia_u32_f16)
-IA_OP(half, uint8_t, ia_u8_f16)
+INDEX_ADD_OP(ia_u32_f16, uint32_t, half)
+INDEX_ADD_OP(ia_u8_f16, uint8_t, half)
-IA_OP(float, int64_t, ia_i64_f32)
-IA_OP(uint8_t, int64_t, ia_i64_u8)
-IA_OP(int64_t, int64_t, ia_i64_i64)
-IA_OP(uint32_t, int64_t, ia_i64_u32)
+INDEX_ADD_OP(ia_i64_f32, int64_t, float)
+INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t)
+INDEX_ADD_OP(ia_i64_i64, int64_t, int64_t)
+INDEX_ADD_OP(ia_i64_u32, int64_t, uint32_t)
-IA_OP(float, uint32_t, ia_u32_f32)
-IA_OP(uint8_t, uint32_t, ia_u32_u8)
-IA_OP(int64_t, uint32_t, ia_u32_i64)
-IA_OP(uint32_t, uint32_t, ia_u32_u32)
+INDEX_ADD_OP(ia_u32_f32, uint32_t, float)
+INDEX_ADD_OP(ia_u32_u8, uint32_t, uint8_t)
+INDEX_ADD_OP(ia_u32_i64, uint32_t, int64_t)
+INDEX_ADD_OP(ia_u32_u32, uint32_t, uint32_t)
-IA_OP(float, uint8_t, ia_u8_f32)
-IA_OP(uint8_t, uint8_t, ia_u8_u8)
-IA_OP(uint32_t, uint8_t, ia_u8_u32)
-IA_OP(int64_t, uint8_t, ia_u8_i64)
+INDEX_ADD_OP(ia_u8_f32, uint8_t, float)
+INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t)
+INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t)
+INDEX_ADD_OP(ia_u8_i64, uint8_t, int64_t)
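
All of the indexing kernels above share the same decomposition of the flat destination id into (left, selected, right) coordinates. A CPU-side sketch of the `index` (index-select) path, including the clamp that stands in for bounds checking (names and values are illustrative):

    /// CPU reference of the `index` kernel: select `ids` along one dimension.
    /// `left_size` / `right_size` are the products of the dims before / after that dimension.
    fn index_select(
        input: &[f32],
        ids: &[u32],
        left_size: usize,
        src_dim_size: usize,
        right_size: usize,
    ) -> Vec<f32> {
        let dst_size = left_size * ids.len() * right_size;
        (0..dst_size)
            .map(|tid| {
                let id_i = (tid / right_size) % ids.len();
                // Clamp instead of trapping, mirroring the kernel's out-of-bounds guard.
                let input_i = (ids[id_i] as usize).min(src_dim_size - 1);
                let right_rank_i = tid % right_size;
                let left_rank_i = tid / right_size / ids.len();
                let src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i;
                input[src_i]
            })
            .collect()
    }

    fn main() {
        // A 3x2 source, selecting rows [2, 0] along dim 0.
        let input = [0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
        let out = index_select(&input, &[2, 0], 1, 3, 2);
        assert_eq!(out, vec![4.0, 5.0, 0.0, 1.0]);
    }
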
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 5a6bd41b..0418c96c 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1,6 +1,6 @@
use metal::{
- Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor,
- ComputePipelineState, Device, Function, Library, MTLSize,
+ Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
+ Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
};
use std::collections::HashMap;
use std::ffi::c_void;
@@ -13,7 +13,12 @@ const BINARY: &str = include_str!("binary.metal");
const TERNARY: &str = include_str!("ternary.metal");
const CAST: &str = include_str!("cast.metal");
const REDUCE: &str = include_str!("reduce.metal");
+const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
+/// Most kernels apply similarly across the tensors.
+/// This creates a dispatch strategy that uses the maximum number of threads per threadgroup
+/// (capped at the actual total buffer length).
+/// Each kernel invocation then only has to apply its op to its single point in the buffer.
fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
let size = length as u64;
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
@@ -35,6 +40,10 @@ fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTL
fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
<P as EncoderParam>::set_param(encoder, position, data)
}
+
+/// Helper functions to set the various arguments on the compute command encoder
+/// on a single line.
+/// Prevents getting an argument index wrong and mixing up element counts with sizes in bytes.
trait EncoderParam {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}
@@ -59,8 +68,8 @@ impl<T> EncoderParam for &[T] {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_bytes(
position,
- (core::mem::size_of::<T>() * data.len()) as u64,
- data.as_ptr() as *const T as *const c_void,
+ core::mem::size_of_val(data) as u64,
+ data.as_ptr() as *const c_void,
);
}
}
@@ -105,54 +114,59 @@ pub enum Source {
Ternary,
Cast,
Reduce,
+ Mfa,
}
macro_rules! ops{
($($name:ident),+) => {
pub mod contiguous {
- #[derive(Clone, Copy)]
- pub struct Kernel(pub(crate) &'static str);
- impl std::fmt::Display for Kernel {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "{}", self.0)
- }
- }
+ pub struct Kernel(pub &'static str);
$(
pub mod $name {
use super::Kernel;
- pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float"));
- pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half"));
- pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat"));
+ pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32"));
+ pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16"));
+ pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16"));
}
)+
+ pub mod copy {
+ use super::Kernel;
+ pub const FLOAT: Kernel = Kernel("copy_f32");
+ pub const HALF: Kernel = Kernel("copy_f16");
+ pub const BFLOAT: Kernel = Kernel("copy_bf16");
+ pub const U32: Kernel = Kernel("copy_u32");
+ pub const U8: Kernel = Kernel("copy_u8");
+ }
}
pub mod strided {
- #[derive(Clone, Copy)]
- pub struct Kernel(pub(crate) &'static str);
- impl std::fmt::Display for Kernel {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "{}", self.0)
- }
- }
+ pub struct Kernel(pub &'static str);
$(
pub mod $name {
use super::Kernel;
- pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float_strided"));
- pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half_strided"));
- pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided"));
+ pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_strided"));
+ pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided"));
+ pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided"));
}
)+
+ pub mod copy {
+ use super::Kernel;
+ pub const FLOAT: Kernel = Kernel("copy_f32_strided");
+ pub const HALF: Kernel = Kernel("copy_f16_strided");
+ pub const BFLOAT: Kernel = Kernel("copy_bf16_strided");
+ pub const U32: Kernel = Kernel("copy_u32_strided");
+ pub const U8: Kernel = Kernel("copy_u8_strided");
+ }
}
};
}
pub mod unary {
- ops!(cos, sin, exp, sqr, sqrt, neg, copy, log);
+ ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh);
}
pub mod binary {
- ops!(add, sub, mul, div);
+ ops!(add, sub, mul, div, min, max, eq, ne, le, lt, ge, gt);
}
#[derive(thiserror::Error, Debug)]
@@ -161,8 +175,18 @@ pub enum MetalKernelError {
LockError(String),
#[error("Error while loading library: {0}")]
LoadLibraryError(String),
- #[error("Error while loading function: {0}")]
+ #[error("Error while loading function: {0:?}")]
LoadFunctionError(String),
+ #[error("Failed to create compute function")]
+ FailedToCreateComputeFunction,
+ #[error("Failed to create pipeline")]
+ FailedToCreatePipeline(String),
+ #[error("Invalid matmul arguments {lhs_stride:?} {rhs_stride:?} {mnk:?}")]
+ MatMulNonContiguous {
+ lhs_stride: Vec<usize>,
+ rhs_stride: Vec<usize>,
+ mnk: (usize, usize, usize),
+ },
}
impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
@@ -171,21 +195,25 @@ impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
}
}
-type KernelMap<T> = HashMap<&'static str, T>;
type Libraries = HashMap<Source, Library>;
-type Functions = KernelMap<Function>;
+type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipelineState>;
-#[derive(Debug, Default)]
+#[derive(Debug)]
pub struct Kernels {
libraries: RwLock<Libraries>,
- funcs: RwLock<Functions>,
+ pipelines: RwLock<Pipelines>,
+ fence: metal::Fence,
}
impl Kernels {
- pub fn new() -> Self {
+ pub fn new(fence: metal::Fence) -> Self {
let libraries = RwLock::new(Libraries::new());
- let funcs = RwLock::new(Functions::new());
- Self { libraries, funcs }
+ let pipelines = RwLock::new(Pipelines::new());
+ Self {
+ libraries,
+ pipelines,
+ fence,
+ }
}
fn get_library_source(&self, source: Source) -> &'static str {
@@ -197,9 +225,12 @@ impl Kernels {
Source::Indexing => INDEXING,
Source::Cast => CAST,
Source::Reduce => REDUCE,
+ Source::Mfa => panic!("Invalid lib"),
}
}
+    /// Load the given library from its [`source`].
+    /// If it has been loaded previously, it is simply fetched from the cache.
pub fn load_library(
&self,
device: &Device,
@@ -209,33 +240,83 @@ impl Kernels {
if let Some(lib) = libraries.get(&source) {
Ok(lib.clone())
} else {
- let source_content = self.get_library_source(source);
- let lib = device
- .new_library_with_source(source_content, &CompileOptions::new())
- .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?;
+ let lib = match source {
+ Source::Mfa => {
+ let source_data = MFA;
+ device.new_library_with_data(source_data).map_err(|e| {
+ MetalKernelError::LoadLibraryError(format!(
+                        "Candle metal requires macOS 13.0 or higher, cannot load mfa: {e}"
+ ))
+ })?
+ }
+ source => {
+ let source_content = self.get_library_source(source);
+ device
+ .new_library_with_source(source_content, &CompileOptions::new())
+ .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
+ }
+ };
libraries.insert(source, lib.clone());
Ok(lib)
}
}
- pub fn load_function(
+ fn load_function(
&self,
device: &Device,
source: Source,
name: &'static str,
+ constants: Option<FunctionConstantValues>,
) -> Result<Function, MetalKernelError> {
- let mut funcs = self.funcs.write()?;
- if let Some(func) = funcs.get(name) {
- Ok(func.clone())
+ let func = self
+ .load_library(device, source)?
+ .get_function(name, constants)
+ .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
+ Ok(func)
+ }
+
+    /// Load the given pipeline:
+    /// loads the library from source, then gets the function [`name`] from
+    /// that source and builds a compute pipeline for it.
+ fn load_pipeline_with_constants(
+ &self,
+ device: &Device,
+ source: Source,
+ name: &'static str,
+ constants: Option<ConstantValues>,
+ ) -> Result<ComputePipelineState, MetalKernelError> {
+ let mut pipelines = self.pipelines.write()?;
+ let key = (name, constants);
+ if let Some(pipeline) = pipelines.get(&key) {
+ Ok(pipeline.clone())
} else {
- let func = self
- .load_library(device, source)?
- .get_function(name, None)
- .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
- funcs.insert(name, func.clone());
- Ok(func)
+ let (name, constants) = key;
+ let func = self.load_function(
+ device,
+ source,
+ name,
+ constants.as_ref().map(|c| c.function_constant_values()),
+ )?;
+ let pipeline = device
+ .new_compute_pipeline_state_with_function(&func)
+ .map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?;
+ pipelines.insert((name, constants), pipeline.clone());
+
+ Ok(pipeline)
}
}
+
+    /// Load the given pipeline:
+    /// loads the library from source, then gets the function [`name`] from
+    /// that source (without function constants).
+ pub fn load_pipeline(
+ &self,
+ device: &Device,
+ source: Source,
+ name: &'static str,
+ ) -> Result<ComputePipelineState, MetalKernelError> {
+ self.load_pipeline_with_constants(device, source, name, None)
+ }
}
#[allow(clippy::too_many_arguments)]
@@ -246,25 +327,20 @@ pub fn call_unary_contiguous(
kernel_name: unary::contiguous::Kernel,
length: usize,
input: &Buffer,
- output: &mut Buffer,
+ output: &Buffer,
) -> Result<(), MetalKernelError> {
- let func = kernels.load_function(device, Source::Unary, kernel_name.0)?;
- let pipeline_state_descriptor = ComputePipelineDescriptor::new();
- pipeline_state_descriptor.set_compute_function(Some(&func));
-
- let pipeline = device
- .new_compute_pipeline_state_with_function(
- pipeline_state_descriptor.compute_function().unwrap(),
- )
- .unwrap();
-
+ let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -279,21 +355,14 @@ pub fn call_unary_strided(
input: &Buffer,
strides: &[usize],
offset: usize,
- output: &mut Buffer,
+ output: &Buffer,
output_offset: usize,
) -> Result<(), MetalKernelError> {
- let func = kernels.load_function(device, Source::Unary, name.0)?;
- let pipeline_state_descriptor = ComputePipelineDescriptor::new();
- pipeline_state_descriptor.set_compute_function(Some(&func));
-
- let pipeline = device
- .new_compute_pipeline_state_with_function(
- pipeline_state_descriptor.compute_function().unwrap(),
- )
- .unwrap();
+ let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
@@ -312,7 +381,10 @@ pub fn call_unary_strided(
let width: usize = shape.iter().product();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -326,26 +398,23 @@ pub fn call_binary_contiguous(
length: usize,
left: &Buffer,
right: &Buffer,
- output: &mut Buffer,
+ output: &Buffer,
) -> Result<(), MetalKernelError> {
- let func = kernels.load_function(device, Source::Binary, kernel_name.0)?;
- let pipeline_state_descriptor = ComputePipelineDescriptor::new();
- pipeline_state_descriptor.set_compute_function(Some(&func));
-
- let pipeline = device
- .new_compute_pipeline_state_with_function(
- pipeline_state_descriptor.compute_function().unwrap(),
- )
- .unwrap();
+ let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, left, right, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
+ encoder.use_resource(left, metal::MTLResourceUsage::Read);
+ encoder.use_resource(right, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -363,21 +432,14 @@ pub fn call_binary_strided(
right_input: &Buffer,
right_strides: &[usize],
right_offset: usize,
- output: &mut Buffer,
+ output: &Buffer,
) -> Result<(), MetalKernelError> {
- let func = kernels.load_function(device, Source::Binary, name.0)?;
- let pipeline_state_descriptor = ComputePipelineDescriptor::new();
- pipeline_state_descriptor.set_compute_function(Some(&func));
-
- let pipeline = device
- .new_compute_pipeline_state_with_function(
- pipeline_state_descriptor.compute_function().unwrap(),
- )
- .unwrap();
+ let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder();
let width: usize = shape.iter().product();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
@@ -398,7 +460,11 @@ pub fn call_binary_strided(
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
+ encoder.use_resource(left_input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(right_input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -411,31 +477,68 @@ pub fn call_cast_contiguous(
kernel_name: &'static str,
length: usize,
input: &Buffer,
- output: &mut Buffer,
+ input_offset: usize,
+ output: &Buffer,
) -> Result<(), MetalKernelError> {
- let func = kernels.load_function(device, Source::Cast, kernel_name)?;
- let pipeline_state_descriptor = ComputePipelineDescriptor::new();
- pipeline_state_descriptor.set_compute_function(Some(&func));
+ let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
- let pipeline = device
- .new_compute_pipeline_state_with_function(
- pipeline_state_descriptor.compute_function().unwrap(),
- )
- .unwrap();
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(encoder, (length, (input, input_offset), output));
+
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
+ encoder.end_encoding();
+ Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
+pub fn call_cast_strided(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ kernel_name: &'static str,
+ shape: &[usize],
+ input: &Buffer,
+ input_strides: &[usize],
+ input_offset: usize,
+ output: &Buffer,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
- set_params!(encoder, (length, input, output));
+ let length: usize = shape.iter().product();
+
+ set_params!(
+ encoder,
+ (
+ length,
+ shape.len(),
+ shape,
+ input_strides,
+ (input, input_offset),
+ output
+ )
+ );
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
-#[allow(clippy::too_many_arguments)]
pub fn call_reduce_contiguous(
device: &Device,
command_buffer: &CommandBufferRef,
@@ -444,24 +547,78 @@ pub fn call_reduce_contiguous(
length: usize,
out_length: usize,
input: &Buffer,
- output: &mut Buffer,
+ input_offset: usize,
+ output: &Buffer,
) -> Result<(), MetalKernelError> {
- let func = kernels.load_function(device, Source::Reduce, kernel_name)?;
- let pipeline_state_descriptor = ComputePipelineDescriptor::new();
- pipeline_state_descriptor.set_compute_function(Some(&func));
+ let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
+ let elements_to_sum = length / out_length;
- let pipeline = device
- .new_compute_pipeline_state_with_function(
- pipeline_state_descriptor.compute_function().unwrap(),
- )
- .unwrap();
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(
+ encoder,
+ (length, elements_to_sum, (input, input_offset), output)
+ );
+
+ let thread_group_count = MTLSize {
+ width: out_length as u64,
+ height: 1,
+ depth: 1,
+ };
+
+ let width = std::cmp::min(
+ pipeline.max_total_threads_per_threadgroup(),
+ (elements_to_sum as u64 + 2 - 1) / 2,
+ )
+ .next_power_of_two();
+ let thread_group_size = MTLSize {
+ width,
+ height: 1,
+ depth: 1,
+ };
+
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
+ encoder.end_encoding();
+ Ok(())
+}
+
+pub fn call_reduce_strided(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ kernel_name: &'static str,
+ shape: &[usize],
+ strides: &[usize],
+ out_length: usize,
+ input: &Buffer,
+ input_offset: usize,
+ output: &Buffer,
+) -> Result<(), MetalKernelError> {
+ let length: usize = shape.iter().product();
+ let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let elements_to_sum = length / out_length;
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
- set_params!(encoder, (length, elements_to_sum, input, output));
+ set_params!(
+ encoder,
+ (
+ shape.len(),
+ shape,
+ strides,
+ elements_to_sum,
+ (input, input_offset),
+ output
+ )
+ );
let thread_group_count = MTLSize {
width: out_length as u64,
@@ -471,7 +628,7 @@ pub fn call_reduce_contiguous(
let width = std::cmp::min(
pipeline.max_total_threads_per_threadgroup(),
- (elements_to_sum as u64 + 2 - 1) / 2,
+ elements_to_sum as u64,
)
.next_power_of_two();
@@ -481,7 +638,10 @@ pub fn call_reduce_contiguous(
depth: 1,
};
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -495,22 +655,18 @@ pub fn call_last_softmax(
length: usize,
elements_to_sum: usize,
input: &Buffer,
- output: &mut Buffer,
+ input_offset: usize,
+ output: &Buffer,
) -> Result<(), MetalKernelError> {
- let func = kernels.load_function(device, Source::Reduce, kernel_name)?;
- let pipeline_state_descriptor = ComputePipelineDescriptor::new();
- pipeline_state_descriptor.set_compute_function(Some(&func));
-
- let pipeline = device
- .new_compute_pipeline_state_with_function(
- pipeline_state_descriptor.compute_function().unwrap(),
- )
- .unwrap();
-
+ let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
- set_params!(encoder, (length, elements_to_sum, input, output));
+ set_params!(
+ encoder,
+ (length, elements_to_sum, (input, input_offset), output)
+ );
let out_length = length / elements_to_sum;
@@ -532,7 +688,10 @@ pub fn call_last_softmax(
depth: 1,
};
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -542,34 +701,214 @@ pub fn call_affine(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
+ name: &'static str,
size: usize,
input: &Buffer,
- output: &mut Buffer,
+ output: &Buffer,
+ mul: f32,
+ add: f32,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
+
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(encoder, (size, mul, add, input, output));
+
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
+ encoder.end_encoding();
+ Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
+pub fn call_affine_strided(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: &'static str,
+ shape: &[usize],
+ input: &Buffer,
+ input_stride: &[usize],
+ input_offset: usize,
+ output: &Buffer,
mul: f32,
add: f32,
) -> Result<(), MetalKernelError> {
- let func = kernels.load_function(device, Source::Affine, "affine_float")?;
- let pipeline_state_descriptor = ComputePipelineDescriptor::new();
- pipeline_state_descriptor.set_compute_function(Some(&func));
+ let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
+ let size: usize = shape.iter().product();
+
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
+ encoder.set_compute_pipeline_state(&pipeline);
- let pipeline = device
- .new_compute_pipeline_state_with_function(
- pipeline_state_descriptor.compute_function().unwrap(),
+ set_params!(
+ encoder,
+ (
+ size,
+ shape.len(),
+ shape,
+ input_stride,
+ mul,
+ add,
+ (input, input_offset),
+ output
)
- .unwrap();
+ );
+
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
+ encoder.end_encoding();
+ Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
+pub fn call_powf(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: &'static str,
+ size: usize,
+ input: &Buffer,
+ output: &Buffer,
+ mul: f32,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
- set_params!(encoder, (size, mul, add, input, output));
+ set_params!(encoder, (size, mul, input, output));
+
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
+ encoder.end_encoding();
+ Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
+pub fn call_powf_strided(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: &'static str,
+ shape: &[usize],
+ input: &Buffer,
+ input_stride: &[usize],
+ input_offset: usize,
+ output: &Buffer,
+ mul: f32,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
+ let size: usize = shape.iter().product();
+
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(
+ encoder,
+ (
+ size,
+ shape.len(),
+ shape,
+ input_stride,
+ mul,
+ (input, input_offset),
+ output
+ )
+ );
+
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
+ encoder.end_encoding();
+ Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
+pub fn call_elu(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: &'static str,
+ size: usize,
+ input: &Buffer,
+ output: &Buffer,
+ mul: f32,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
+
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(encoder, (size, mul, input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
+pub fn call_elu_strided(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: &'static str,
+ shape: &[usize],
+ input: &Buffer,
+ input_stride: &[usize],
+ input_offset: usize,
+ output: &Buffer,
+ mul: f32,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
+ let size: usize = shape.iter().product();
+
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(
+ encoder,
+ (
+ size,
+ shape.len(),
+ shape,
+ input_stride,
+ mul,
+ (input, input_offset),
+ output
+ )
+ );
+
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
+ encoder.end_encoding();
+ Ok(())
+}
+
pub fn call_where_cond_strided(
device: &Device,
command_buffer: &CommandBufferRef,
@@ -582,19 +921,12 @@ pub fn call_where_cond_strided(
(left_stride, left_offset): (&[usize], usize),
right: &Buffer,
(right_stride, right_offset): (&[usize], usize),
- output: &mut Buffer,
+ output: &Buffer,
) -> Result<(), MetalKernelError> {
- let func = kernels.load_function(device, Source::Ternary, name)?;
- let pipeline_state_descriptor = ComputePipelineDescriptor::new();
- pipeline_state_descriptor.set_compute_function(Some(&func));
-
- let pipeline = device
- .new_compute_pipeline_state_with_function(
- pipeline_state_descriptor.compute_function().unwrap(),
- )
- .unwrap();
+ let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
let size: usize = shape.iter().product();
@@ -618,7 +950,12 @@ pub fn call_where_cond_strided(
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
+ encoder.use_resource(cond, metal::MTLResourceUsage::Read);
+ encoder.use_resource(left, metal::MTLResourceUsage::Read);
+ encoder.use_resource(right, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -634,20 +971,18 @@ pub fn call_index_select(
dim: usize,
input: &Buffer,
ids: &Buffer,
- output: &mut Buffer,
+ output: &Buffer,
) -> Result<(), MetalKernelError> {
let left_size: usize = shape[..dim].iter().product();
let right_size: usize = shape[dim + 1..].iter().product();
let src_dim_size = shape[dim];
let dst_el = ids_size * left_size * right_size;
- let func = kernels.load_function(device, Source::Indexing, name)?;
- let pipeline = device
- .new_compute_pipeline_state_with_function(&func)
- .unwrap();
+ let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -666,10 +1001,426 @@ pub fn call_index_select(
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(ids, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
+#[allow(clippy::too_many_arguments)]
+pub fn call_gather(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: &'static str,
+ shape: &[usize],
+ ids_size: usize,
+ dim: usize,
+ input: &Buffer,
+ input_offset: usize,
+ ids: &Buffer,
+ ids_offset: usize,
+ output: &Buffer,
+) -> Result<(), MetalKernelError> {
+ let left_size: usize = shape[..dim].iter().product();
+ let right_size: usize = shape[dim + 1..].iter().product();
+ let src_dim_size = shape[dim];
+ let dst_el = ids_size * left_size * right_size;
+
+ let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
+
+ let encoder = command_buffer.new_compute_command_encoder();
+
+ encoder.wait_for_fence(&kernels.fence);
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(
+ encoder,
+ (
+ dst_el,
+ left_size,
+ src_dim_size,
+ right_size,
+ ids_size,
+ (input, input_offset),
+ (ids, ids_offset),
+ output
+ )
+ );
+
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
+
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(ids, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
+ encoder.end_encoding();
+ Ok(())
+}
+
+pub fn call_scatter_add(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: &'static str,
+ src_shape: &[usize],
+ dst_shape: &[usize],
+ dim: usize,
+ input: &Buffer,
+ input_offset: usize,
+ ids: &Buffer,
+ ids_offset: usize,
+ output: &Buffer,
+) -> Result<(), MetalKernelError> {
+ let left_size: usize = src_shape[..dim].iter().product();
+ let right_size: usize = src_shape[dim + 1..].iter().product();
+ let src_dim_size = src_shape[dim];
+ let dst_el = left_size * right_size;
+ let dst_dim_size = dst_shape[dim];
+
+ let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
+
+ let encoder = command_buffer.new_compute_command_encoder();
+
+ encoder.wait_for_fence(&kernels.fence);
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(
+ encoder,
+ (
+ dst_el,
+ left_size,
+ src_dim_size,
+ right_size,
+ dst_dim_size,
+ (input, input_offset),
+ (ids, ids_offset),
+ output
+ )
+ );
+
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
+
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(ids, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
+ encoder.end_encoding();
+ Ok(())
+}
+
+pub fn call_index_add(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: &'static str,
+ src_shape: &[usize],
+ dst_shape: &[usize],
+ ids_shape: &[usize],
+ dim: usize,
+ input: &Buffer,
+ input_offset: usize,
+ ids: &Buffer,
+ ids_offset: usize,
+ output: &Buffer,
+) -> Result<(), MetalKernelError> {
+ let left_size: usize = src_shape[..dim].iter().product();
+ let right_size: usize = src_shape[dim + 1..].iter().product();
+ let src_dim_size = src_shape[dim];
+ let dst_el = left_size * right_size;
+ let dst_dim_size = dst_shape[dim];
+ let ids_dim_size = ids_shape[0];
+
+ let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
+ let encoder = command_buffer.new_compute_command_encoder();
+
+ encoder.wait_for_fence(&kernels.fence);
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(
+ encoder,
+ (
+ dst_el,
+ left_size,
+ src_dim_size,
+ right_size,
+ dst_dim_size,
+ ids_dim_size,
+ (input, input_offset),
+ (ids, ids_offset),
+ output
+ )
+ );
+
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
+
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(ids, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
+ encoder.end_encoding();
+ Ok(())
+}
+
+#[derive(Debug, PartialEq)]
+pub enum Value {
+ USize(usize),
+ Bool(bool),
+ F32(f32),
+ U16(u16),
+}
+
+impl std::hash::Hash for Value {
+ fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+ match self {
+ Value::F32(v) => v.to_bits().hash(state),
+ Value::USize(v) => v.hash(state),
+ Value::U16(v) => v.hash(state),
+ Value::Bool(v) => v.hash(state),
+ }
+ }
+}
+
+impl Value {
+ fn data_type(&self) -> MTLDataType {
+ match self {
+ Value::USize(_) => MTLDataType::UInt,
+ Value::F32(_) => MTLDataType::Float,
+ Value::U16(_) => MTLDataType::UShort,
+ Value::Bool(_) => MTLDataType::Bool,
+ }
+ }
+}
+
+/// Not strictly true (f32 is not `Eq`), but good enough for our purposes as a cache key.
+impl Eq for Value {}
+
+#[derive(Debug, Eq, PartialEq, Hash)]
+struct ConstantValues(Vec<(usize, Value)>);
+
+impl ConstantValues {
+ pub fn new(values: Vec<(usize, Value)>) -> Self {
+ Self(values)
+ }
+
+ fn function_constant_values(&self) -> FunctionConstantValues {
+ let f = FunctionConstantValues::new();
+ for (index, value) in &self.0 {
+ let ty = value.data_type();
+ match value {
+ Value::USize(v) => {
+ f.set_constant_value_at_index(
+ v as *const usize as *const c_void,
+ ty,
+ *index as u64,
+ );
+ }
+ Value::F32(v) => {
+ f.set_constant_value_at_index(
+ v as *const f32 as *const c_void,
+ ty,
+ *index as u64,
+ );
+ }
+ Value::U16(v) => {
+ f.set_constant_value_at_index(
+ v as *const u16 as *const c_void,
+ ty,
+ *index as u64,
+ );
+ }
+ Value::Bool(v) => {
+ f.set_constant_value_at_index(
+ v as *const bool as *const c_void,
+ ty,
+ *index as u64,
+ );
+ }
+ }
+ }
+ f
+ }
+}
+
+#[allow(clippy::too_many_arguments)]
+pub fn call_gemm(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: &'static str,
+ (b, m, n, k): (usize, usize, usize, usize),
+ lhs_stride: &[usize],
+ lhs_offset: usize,
+ lhs_buffer: &Buffer,
+ rhs_stride: &[usize],
+ rhs_offset: usize,
+ rhs_buffer: &Buffer,
+ output: &Buffer,
+) -> Result<(), MetalKernelError> {
+ assert!(rhs_stride.len() >= 2);
+ assert!(lhs_stride.len() >= 2);
+ let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
+ let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
+ let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
+ let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
+ let a_trans = if lhs_m1 == 1 && lhs_m2 == k {
+ false
+ } else if lhs_m1 == m && lhs_m2 == 1 {
+ true
+ } else {
+ return Err(MetalKernelError::MatMulNonContiguous {
+ lhs_stride: lhs_stride.to_vec(),
+ rhs_stride: rhs_stride.to_vec(),
+ mnk: (m, n, k),
+ })?;
+ };
+ let b_trans = if rhs_m1 == 1 && rhs_m2 == n {
+ false
+ } else if rhs_m1 == k && rhs_m2 == 1 {
+ true
+ } else {
+ return Err(MetalKernelError::MatMulNonContiguous {
+ lhs_stride: lhs_stride.to_vec(),
+ rhs_stride: rhs_stride.to_vec(),
+ mnk: (m, n, k),
+ })?;
+ };
+ let d_trans = false;
+ let alpha = 1.0f32;
+ let beta = 0.0f32;
+ let batched = b > 1;
+ let fused_activation = false;
+ let fused_bias = false;
+ let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
+ let m_simd = 16;
+ let n_simd = 8;
+ let k_simd = 64;
+ let m_splits = 1;
+ let n_splits = 1;
+ (m_simd, n_simd, k_simd, m_splits, n_splits)
+ } else {
+ let m_simd = 40;
+ let n_simd = 40;
+ let k_simd = 8;
+ let m_splits = 1;
+ let n_splits = 1;
+ (m_simd, n_simd, k_simd, m_splits, n_splits)
+ };
+ let constants = Some(ConstantValues::new(vec![
+ (0, Value::USize(m)),
+ (1, Value::USize(n)),
+ (2, Value::USize(k)),
+ (10, Value::Bool(a_trans)),
+ (11, Value::Bool(b_trans)),
+ (13, Value::Bool(d_trans)),
+ (20, Value::F32(alpha)),
+ (21, Value::F32(beta)),
+ (100, Value::Bool(batched)),
+ (101, Value::Bool(fused_activation)),
+ // Garbage
+ (102, Value::Bool(false)),
+ (103, Value::Bool(false)),
+ (113, Value::Bool(false)),
+ (50_000, Value::Bool(false)),
+ // End garbage
+ (200, Value::U16(m_simd)),
+ (201, Value::U16(n_simd)),
+ (202, Value::U16(k_simd)),
+ (210, Value::U16(m_splits)),
+ (211, Value::U16(n_splits)),
+ (50_001, Value::Bool(fused_bias)),
+ ]));
+ let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?;
+ let m_group = m_simd * m_splits;
+ let n_group = n_simd * n_splits;
+
+ let a_block_length = m_group * k_simd;
+ let b_block_length = k_simd * n_group;
+
+ let mut block_elements = a_block_length + b_block_length;
+ if (m % 8 != 0) && (n % 8 != 0) {
+ let c_block_length = m_group * n_group;
+ block_elements = std::cmp::max(c_block_length, block_elements)
+ }
+ if fused_bias {
+ if d_trans {
+ block_elements = std::cmp::max(block_elements, m_group);
+ } else {
+ block_elements = std::cmp::max(block_elements, n_group);
+ }
+ }
+ let bytes = match name {
+ "sgemm" => 4,
+ "hgemm" => 2,
+ other => {
+ return Err(MetalKernelError::LoadLibraryError(format!(
+ "{other} is not a valid kernel for gemm"
+ )));
+ }
+ };
+ let block_bytes = block_elements * bytes;
+
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.wait_for_fence(&kernels.fence);
+ encoder.set_compute_pipeline_state(&pipeline);
+ encoder.set_threadgroup_memory_length(0, block_bytes.into());
+ encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
+ encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
+ encoder.set_buffer(2, Some(output), 0);
+ // TODO Tensor D
+
+ let grid_z = b;
+ if batched {
+ let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
+ let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
+ let byte_stride_c = m * n * bytes as usize;
+ // TODO byte_stride_d
+ let byte_stride_d = 0;
+
+ let mut buffer: Vec<u64> = Vec::with_capacity(b * 4);
+ for i in 0..b {
+ buffer.push((i * byte_stride_a) as u64);
+ buffer.push((i * byte_stride_b) as u64);
+ buffer.push((i * byte_stride_c) as u64);
+ buffer.push((i * byte_stride_d) as u64);
+ }
+ encoder.set_bytes(
+ 10,
+ (buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
+ buffer.as_ptr() as *const NSUInteger as *const c_void,
+ );
+ }
+
+ let grid_size = MTLSize {
+ width: divide(n, n_group.into()),
+ height: divide(m, m_group.into()),
+ depth: grid_z as NSUInteger,
+ };
+ let group_size = MTLSize {
+ width: 32 * (m_splits as u64) * (n_splits as u64),
+ height: 1,
+ depth: 1,
+ };
+ // println!("grid size {grid_size:?} group size {group_size:?}");
+ encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(grid_size, group_size);
+ encoder.update_fence(&kernels.fence);
+ encoder.end_encoding();
+
+ Ok(())
+}
+
+fn divide(m: usize, b: usize) -> NSUInteger {
+ ((m + b - 1) / b) as NSUInteger
+}
+
#[cfg(test)]
mod tests;
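Note on the dispatch geometry in call_gemm above: each threadgroup computes one m_group x n_group output tile (m_group = m_simd * m_splits, n_group = n_simd * n_splits), the grid is the ceiling division of n and m by those tile sizes with the batch in the depth dimension, and the threadgroup holds 32 threads per SIMD-group split. A minimal standalone sketch of that arithmetic, in plain Rust with illustrative sizes (not part of the patch):

```rust
// Standalone sketch of the call_gemm dispatch arithmetic (illustrative only,
// no Metal types). Tile sizes are taken from the 40x40 single-split branch
// shown in the hunk above.
fn divide(m: usize, b: usize) -> u64 {
    ((m + b - 1) / b) as u64 // ceiling division
}

fn main() {
    let (m_simd, n_simd, m_splits, n_splits) = (40u64, 40u64, 1u64, 1u64);
    let (m, n, batch) = (2048usize, 2048usize, 4usize);
    let m_group = m_simd * m_splits; // rows of C covered by one threadgroup
    let n_group = n_simd * n_splits; // columns of C covered by one threadgroup
    // One threadgroup per output tile; the batch index goes in the depth dimension.
    let grid = (
        divide(n, n_group as usize),
        divide(m, m_group as usize),
        batch as u64,
    );
    // 32 threads per SIMD-group, one SIMD-group per (m_split, n_split) pair.
    let threads_per_group = 32 * m_splits * n_splits;
    println!("grid = {grid:?}, threads per threadgroup = {threads_per_group}");
}
```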
diff --git a/candle-metal-kernels/src/libMetalFlashAttention.metallib b/candle-metal-kernels/src/libMetalFlashAttention.metallib
new file mode 100644
index 00000000..f5116ca6
--- /dev/null
+++ b/candle-metal-kernels/src/libMetalFlashAttention.metallib
Binary files differ
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
index c6984474..2d584917 100644
--- a/candle-metal-kernels/src/reduce.metal
+++ b/candle-metal-kernels/src/reduce.metal
@@ -1,6 +1,9 @@
#include <metal_stdlib>
using namespace metal;
+#define MAX(x, y) ((x) > (y) ? (x) : (y))
+#define MIN(x, y) ((x) < (y) ? (x) : (y))
+
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
@@ -16,39 +19,160 @@ METAL_FUNC uint get_strided_index(
return strided_i;
}
-constant int THREADGROUP_SIZE = 256;
+constant int THREADGROUP_SIZE = 2048;
+
+
+#define ARGMIN(NAME, T, MAXVALUE) \
+kernel void NAME( \
+ constant size_t &num_dims, \
+ constant size_t *dims, \
+ constant size_t *strides, \
+ constant size_t &el_to_sum_per_block, \
+ device const T *src, \
+ device uint *dst, \
+ uint id [[ thread_position_in_grid ]], \
+ uint tid [[ thread_index_in_threadgroup ]], \
+ uint dst_id [[ threadgroup_position_in_grid ]], \
+ uint block_dim [[ threads_per_threadgroup ]] \
+) { \
+ \
+ threadgroup T shared_memory[THREADGROUP_SIZE]; \
+ threadgroup uint shared_indices[THREADGROUP_SIZE]; \
+ \
+ shared_memory[tid] = MAXVALUE; \
+ shared_indices[tid] = 0xFFFFFFFF; \
+ bool notset = true; \
+ /* \
+ // Elements reduced in this block range from dst_id * el_to_sum_per_block \
+ // to (dst_id + 1) * el_to_sum_per_block. \
+ */ \
+ size_t start_idx = dst_id * el_to_sum_per_block; \
+ size_t stop_idx = start_idx + el_to_sum_per_block; \
+ size_t idx = start_idx + tid; \
+ while (idx < stop_idx) { \
+ /* \
+ // TODO: Fast version for the contiguous case. \
+ */ \
+ size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
+ if (notset || src[strided_i] < shared_memory[tid]) { \
+ shared_memory[tid] = src[strided_i]; \
+ /* Assume that the reduction takes place over the last dimension which is contiguous. */ \
+ shared_indices[tid] = idx % dims[num_dims - 1]; \
+ notset = false; \
+ } \
+ idx += block_dim; \
+ } \
+ \
+ threadgroup_barrier(mem_flags::mem_none); \
+ \
+ /* \
+ // reduction in shared memory \
+ */ \
+ for (uint s = block_dim / 2; s > 0; s >>= 1) { \
+ if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \
+ shared_indices[tid] = shared_indices[tid + s]; \
+ shared_memory[tid] = shared_memory[tid + s]; \
+ } \
+ threadgroup_barrier(mem_flags::mem_none); \
+ } \
+ \
+ if (tid == 0){ \
+ dst[dst_id] = shared_indices[0]; \
+ } \
+} \
+
-# define REDUCE(FN, NAME, TYPENAME) \
+#define ARGMAX(NAME, T, MINVALUE) \
kernel void NAME( \
- constant size_t &src_numel, \
+ constant size_t &num_dims, \
+ constant size_t *dims, \
+ constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
- device const TYPENAME *src, \
- device TYPENAME *dst, \
+ device const T *src, \
+ device uint *dst, \
+ uint id [[ thread_position_in_grid ]], \
+ uint tid [[ thread_index_in_threadgroup ]], \
+ uint dst_id [[ threadgroup_position_in_grid ]], \
+ uint block_dim [[ threads_per_threadgroup ]] \
+) { \
+ \
+ threadgroup T shared_memory[THREADGROUP_SIZE]; \
+ threadgroup uint shared_indices[THREADGROUP_SIZE]; \
+ \
+ shared_memory[tid] = MINVALUE; \
+ shared_indices[tid] = 0xFFFFFFFF; \
+ /* \
+ // Elements reduced in this block range from dst_id * el_to_sum_per_block \
+ // to (dst_id + 1) * el_to_sum_per_block. \
+ */ \
+ size_t start_idx = dst_id * el_to_sum_per_block; \
+ size_t stop_idx = start_idx + el_to_sum_per_block; \
+ size_t idx = start_idx + tid; \
+ bool notset = true; \
+ while (idx < stop_idx) { \
+ /* \
+ // TODO: Fast version for the contiguous case. \
+ */ \
+ size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
+ if (notset || shared_memory[tid] < src[strided_i]) { \
+ shared_memory[tid] = src[strided_i]; \
+ shared_indices[tid] = idx % dims[num_dims - 1]; \
+ notset = false; \
+ } \
+ idx += block_dim; \
+ } \
+ \
+ threadgroup_barrier(mem_flags::mem_none); \
+ \
+ /* \
+ // reduction in shared memory \
+ */ \
+ for (uint s = block_dim / 2; s > 0; s >>= 1) { \
+ if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \
+ shared_indices[tid] = shared_indices[tid + s]; \
+ shared_memory[tid] = shared_memory[tid + s]; \
+ } \
+ threadgroup_barrier(mem_flags::mem_none); \
+ } \
+ \
+ if (tid == 0){ \
+ dst[dst_id] = shared_indices[0]; \
+ } \
+} \
+
+#define REDUCE(FN, NAME, T, START) \
+kernel void NAME( \
+ constant size_t &num_dims, \
+ constant size_t *dims, \
+ constant size_t *strides, \
+ constant size_t &el_to_sum_per_block, \
+ device const T *src, \
+ device T *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
- uint blockDim [[ threads_per_threadgroup ]] \
+ uint block_dim [[ threads_per_threadgroup ]] \
) { \
\
- threadgroup float shared_memory[THREADGROUP_SIZE]; \
+ threadgroup T shared_memory[THREADGROUP_SIZE]; \
\
- shared_memory[tid] = 0; \
+ shared_memory[tid] = START; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
- size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
+ size_t stop_idx = start_idx + el_to_sum_per_block; \
size_t idx = start_idx + tid; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
- // size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
*/ \
- TYPENAME x = shared_memory[tid]; \
- TYPENAME y = src[idx]; \
+ size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
+ T x = shared_memory[tid]; \
+ T y = src[strided_i]; \
shared_memory[tid] = FN; \
- idx += blockDim; \
+ idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
@@ -56,10 +180,10 @@ kernel void NAME( \
/* \
// reduction in shared memory \
*/ \
- for (uint s = blockDim / 2; s > 0; s >>= 1) { \
+ for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
- TYPENAME x = shared_memory[tid]; \
- TYPENAME y = shared_memory[tid + s]; \
+ T x = shared_memory[tid]; \
+ T y = shared_memory[tid + s]; \
shared_memory[tid] = FN; \
} \
threadgroup_barrier(mem_flags::mem_none); \
@@ -68,72 +192,101 @@ kernel void NAME( \
dst[dst_id] = shared_memory[0]; \
} \
-kernel void softmax_float(
- constant size_t &src_numel,
- constant size_t &el_to_sum_per_block,
- device const float *src,
- device float *dst,
- uint id [[ thread_position_in_grid ]],
- uint tid [[ thread_index_in_threadgroup ]],
- uint dst_id [[ threadgroup_position_in_grid ]],
- uint blockDim [[ threads_per_threadgroup ]]
-) {
-
- threadgroup float shared_memory[THREADGROUP_SIZE];
-
- shared_memory[tid] = -INFINITY;
- // Elements summed in this block range from dst_id * el_to_sum_per_block
- // to (dst_id + 1) * el_to_sum_per_block.
- size_t start_idx = dst_id * el_to_sum_per_block;
- size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
- size_t idx = start_idx + tid;
-
- while (idx < stop_idx) {
- // TODO: Fast version for the contiguous case.
- shared_memory[tid] = max(shared_memory[tid], src[idx]);
- idx += blockDim;
- }
-
- threadgroup_barrier(mem_flags::mem_none);
-
- // reduction in shared memory
- for (uint s = blockDim / 2; s > 0; s >>= 1) {
- if (tid < s) {
- shared_memory[tid] = max(shared_memory[tid], shared_memory[tid + s]);
- }
- threadgroup_barrier(mem_flags::mem_none);
- }
-
- float max = shared_memory[0];
-
- shared_memory[tid] = 0;
- // Restart
- idx = start_idx + tid;
- while (idx < stop_idx) {
- // TODO: Fast version for the contiguous case.
- const float val = exp(src[idx] - max);
- dst[idx] = val;
- shared_memory[tid] += val;
- idx += blockDim;
- }
- // reduction in shared memory
- for (uint s = blockDim / 2; s > 0; s >>= 1) {
- if (tid < s) {
- shared_memory[tid] += shared_memory[tid + s];
- }
- threadgroup_barrier(mem_flags::mem_none);
- }
-
- const float inv_acc = 1/shared_memory[0];
- idx = start_idx + tid;
- while (idx < stop_idx) {
- dst[idx] *= inv_acc;
- idx += blockDim;
- }
-}
+#define SOFTMAX(NAME, T) \
+kernel void NAME( \
+ constant size_t &src_numel, \
+ constant size_t &el_to_sum_per_block, \
+ device const T *src, \
+ device T *dst, \
+ \
+ uint id [[ thread_position_in_grid ]], \
+ uint tid [[ thread_index_in_threadgroup ]], \
+ uint dst_id [[ threadgroup_position_in_grid ]], \
+ uint block_dim [[ threads_per_threadgroup ]] \
+) { \
+ threadgroup float shared_memory[THREADGROUP_SIZE]; \
+ shared_memory[tid] = -INFINITY; \
+ size_t start_idx = dst_id * el_to_sum_per_block; \
+ size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
+ size_t idx = start_idx + tid; \
+ \
+ \
+ float tmp = -INFINITY; \
+ while (idx < stop_idx) { \
+ tmp = MAX(tmp, float(src[idx])); \
+ idx += block_dim; \
+ } \
+ shared_memory[tid] = tmp; \
+ \
+ threadgroup_barrier(mem_flags::mem_threadgroup); \
+ \
+ for (uint s = block_dim / 2; s > 0; s >>= 1) { \
+ if (tid < s) { \
+ shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
+ } \
+ threadgroup_barrier(mem_flags::mem_threadgroup); \
+ } \
+ \
+ /* wait for shared_memory[0] to be filled */ \
+ threadgroup_barrier(mem_flags::mem_threadgroup); \
+ \
+ float _max = shared_memory[0]; \
+ \
+ /* prevent tid=0 from overwriting _max before other threads have written it */ \
+ threadgroup_barrier(mem_flags::mem_threadgroup); \
+ shared_memory[tid] = 0; \
+ \
+ idx = start_idx + tid; \
+ while (idx < stop_idx) { \
+ const float val = exp(float(src[idx]) - _max); \
+ dst[idx] = T(val); \
+ shared_memory[tid] += val; \
+ idx += block_dim; \
+ } \
+ threadgroup_barrier(mem_flags::mem_threadgroup); \
+ for (uint s = block_dim / 2; s > 0; s >>= 1) { \
+ if (tid < s) { \
+ shared_memory[tid] += shared_memory[tid + s]; \
+ } \
+ threadgroup_barrier(mem_flags::mem_threadgroup); \
+ } \
+ \
+ const T inv_acc = T(1.0/shared_memory[0]); \
+ idx = start_idx + tid; \
+ while (idx < stop_idx) { \
+ dst[idx] *= inv_acc; \
+ idx += block_dim; \
+ } \
+} \
+REDUCE(x + y, fast_sum_f32_strided, float, 0)
+REDUCE(x + y, fast_sum_u32_strided, uint, 0)
+REDUCE(x + y, fast_sum_f16_strided, half, 0)
+REDUCE(x * y, fast_mul_f32_strided, float, 1)
+REDUCE(x * y, fast_mul_u32_strided, uint, 1)
+REDUCE(x * y, fast_mul_f16_strided, half, 1)
+REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF)
+REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0)
+REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH)
+REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF)
+REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF)
+REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH)
+ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
+ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
+ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
+ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF)
+ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH)
+ARGMAX(fast_argmax_u32_strided, uint, 0)
-REDUCE(x + y, fast_sum_float, float)
-REDUCE(x * y, fast_mul_float, float)
-REDUCE(max(x, y), fast_max_float, float)
+SOFTMAX(softmax_f32, float)
+SOFTMAX(softmax_f16, half)
+#if __METAL_VERSION__ >= 310
+REDUCE(x + y, fast_sum_bf16, bfloat, 0)
+REDUCE(x * y, fast_mul_bf16, bfloat, 1)
+REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
+REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF)
+ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
+ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
+SOFTMAX(softmax_bf16, bfloat)
+#endif
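The rewritten reduction kernels now take dims/strides, and the new SOFTMAX macro is numerically stabilized: each block finds its max with a parallel tree reduction in threadgroup memory, exponentiates the shifted inputs while accumulating their sum, and finally normalizes. A rough CPU reference of that three-phase algorithm, for intuition only (not part of the patch):

```rust
// CPU reference for the numerically stable softmax that the SOFTMAX kernel
// above computes per block: find the max, exponentiate shifted values while
// accumulating their sum, then normalize.
fn softmax_last_dim_ref(v: &[f32], last_dim: usize) -> Vec<f32> {
    let mut out = Vec::with_capacity(v.len());
    for row in v.chunks(last_dim) {
        let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = row.iter().map(|&x| (x - max).exp()).collect();
        let sum: f32 = exps.iter().sum();
        out.extend(exps.iter().map(|&e| e / sum));
    }
    out
}

fn main() {
    let v = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    // Should land close to the expected values in the softmax test further down.
    println!("{:?}", softmax_last_dim_ref(&v, 6));
}
```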
diff --git a/candle-metal-kernels/src/ternary.metal b/candle-metal-kernels/src/ternary.metal
index 0945b355..1f9cb38a 100644
--- a/candle-metal-kernels/src/ternary.metal
+++ b/candle-metal-kernels/src/ternary.metal
@@ -32,6 +32,9 @@ kernel void FN_NAME( \
device TYPENAME *out ,\
uint i [[ thread_position_in_grid ]] \
) { \
+ if (i >= numel){ \
+ return; \
+ } \
uint strided_i = get_strided_index(i, num_dims, dims, strides); \
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \
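The early return added to the ternary kernel guards threads whose index falls past numel, which happens whenever the host rounds the dispatch up to whole threadgroups. A small illustration of that arithmetic with assumed numbers (not from the patch):

```rust
// Why the bounds check matters: with a grid rounded up to whole threadgroups,
// some thread ids land past numel (numbers below are made up).
fn main() {
    let numel = 1000usize;
    let threads_per_group = 256usize;
    let groups = (numel + threads_per_group - 1) / threads_per_group; // 4
    let launched = groups * threads_per_group; // 1024
    assert_eq!(launched - numel, 24); // these 24 trailing threads must return early
}
```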
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 2330d48d..1b3153b1 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -1,7 +1,14 @@
use super::*;
-use half::f16;
+use half::{bf16, f16};
use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger};
+fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
+ let ptr = buffer.contents() as *const T;
+ assert!(!ptr.is_null());
+ let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
+ slice.to_vec()
+}
+
fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
let options = MTLResourceOptions::StorageModeManaged;
let ptr = data.as_ptr() as *const core::ffi::c_void;
@@ -23,13 +30,19 @@ fn approx_f16(v: Vec<f16>, digits: i32) -> Vec<f32> {
v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
}
+fn approx_bf16(v: Vec<bf16>, digits: i32) -> Vec<f32> {
+ let b = 10f32.powi(digits);
+ v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
+}
+
fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
let device = device();
- let kernels = Kernels::new();
+ let fence = device.new_fence();
+ let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
- let mut output = new_buffer(&device, v);
+ let output = new_buffer(&device, v);
call_unary_contiguous(
&device,
command_buffer,
@@ -37,23 +50,24 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
name,
v.len(),
&input,
- &mut output,
+ &output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
- output.read_to_vec::<T>(v.len())
+ read_to_vec(&output, v.len())
}
fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> {
let device = device();
- let kernels = Kernels::new();
+ let fence = device.new_fence();
+ let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
let left = new_buffer(&device, x);
let right = new_buffer(&device, y);
- let mut output = device.new_buffer(std::mem::size_of_val(x) as u64, options);
+ let output = device.new_buffer(std::mem::size_of_val(x) as u64, options);
call_binary_contiguous(
&device,
command_buffer,
@@ -62,12 +76,12 @@ fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> V
x.len(),
&left,
&right,
- &mut output,
+ &output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
- output.read_to_vec::<T>(x.len())
+ read_to_vec(&output, x.len())
}
fn run_strided<T: Clone>(
@@ -81,8 +95,9 @@ fn run_strided<T: Clone>(
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
- let mut output = new_buffer(&device, v);
- let kernels = Kernels::new();
+ let output = new_buffer(&device, v);
+ let fence = device.new_fence();
+ let kernels = Kernels::new(fence);
call_unary_strided(
&device,
command_buffer,
@@ -92,13 +107,13 @@ fn run_strided<T: Clone>(
&input,
strides,
offset,
- &mut output,
+ &output,
0,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
- output.read_to_vec::<T>(v.len())
+ read_to_vec(&output, v.len())
}
#[test]
@@ -201,6 +216,25 @@ fn cos_strided_random() {
}
#[test]
+fn gelu_f16() {
+ let v: Vec<f16> = [-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0]
+ .iter()
+ .map(|v| f16::from_f32(*v))
+ .collect();
+ let expected: Vec<f32> = vec![-0.0, -0.16, 0.0, 0.84, 1.96, 3.0, 10.0, 20.0];
+ let results = run(&v, unary::contiguous::gelu::HALF);
+ assert_eq!(approx_f16(results, 2), expected);
+}
+
+#[test]
+fn gelu_f32() {
+ let v: Vec<f32> = vec![-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0];
+ let expected: Vec<f32> = vec![-0.0, -0.159, 0.0, 0.841, 1.955, 2.996, 10.0, 20.0];
+ let results = run(&v, unary::contiguous::gelu::FLOAT);
+ assert_eq!(approx(results, 3), expected);
+}
+
+#[test]
fn binary_add_f32() {
let left = vec![1.0f32, 2.0, 3.0];
let right = vec![2.0f32, 3.1, 4.2];
@@ -216,11 +250,14 @@ fn binary_add_f32() {
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
let device = device();
- let kernels = Kernels::new();
+ let fence = device.new_fence();
+ let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
- let mut output = new_buffer(&device, v);
+ let options = MTLResourceOptions::StorageModeManaged;
+ let size = (v.len() * std::mem::size_of::<U>()) as u64;
+ let output = device.new_buffer(size, options);
call_cast_contiguous(
&device,
@@ -229,12 +266,13 @@ fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
name,
v.len(),
&input,
- &mut output,
+ 0,
+ &output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
- output.read_to_vec::<U>(v.len())
+ read_to_vec(&output, v.len())
}
#[test]
@@ -245,21 +283,28 @@ fn cast_u32_f32() {
assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]);
assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]);
+ let v = vec![1.0f32, 2.0, 3.0];
+ let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
+ let results: Vec<f32> = cast(&input, "cast_f16_f32");
+ assert_eq!(results, vec![1.0f32, 2.0, 3.0]);
+
let v = vec![1.0f32; 10_000];
- let results = run(&v, unary::contiguous::cos::FLOAT);
- let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
- assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
- assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
+ let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
+ let results: Vec<f32> = cast(&input, "cast_f16_f32");
+ assert_eq!(results.len(), 10_000);
+ assert_eq!(&results[..10], vec![1.0f32; 10]);
+ assert_eq!(results, vec![1.0f32; 10_000]);
}
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
let device = device();
- let kernels = Kernels::new();
+ let fence = device.new_fence();
+ let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
- let mut output = new_buffer(&device, v);
+ let output = new_buffer(&device, v);
let size = v.len();
@@ -267,9 +312,46 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
&device,
command_buffer,
&kernels,
+ "affine_f32",
size,
&input,
- &mut output,
+ &output,
+ mul as f32,
+ add as f32,
+ )
+ .unwrap();
+ command_buffer.commit();
+ command_buffer.wait_until_completed();
+
+ read_to_vec(&output, v.len())
+}
+
+fn run_affine_strided<T: Clone>(
+ v: &[T],
+ shape: &[usize],
+ strides: &[usize],
+ mul: f64,
+ add: f64,
+) -> Vec<T> {
+ let device = device();
+ let fence = device.new_fence();
+ let kernels = Kernels::new(fence);
+ let command_queue = device.new_command_queue();
+ let command_buffer = command_queue.new_command_buffer();
+
+ let input = new_buffer(&device, v);
+ let output = new_buffer(&device, v);
+
+ call_affine_strided(
+ &device,
+ command_buffer,
+ &kernels,
+ "affine_f32_strided",
+ shape,
+ &input,
+ strides,
+ 0,
+ &output,
mul as f32,
add as f32,
)
@@ -277,7 +359,8 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
command_buffer.commit();
command_buffer.wait_until_completed();
- output.read_to_vec::<T>(v.len())
+ let len: usize = shape.iter().product();
+ read_to_vec(&output, len)
}
#[test]
@@ -296,6 +379,18 @@ fn affine() {
}
#[test]
+fn affine_strided() {
+ let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
+ let mul = 1.5;
+ let add = 1.1;
+ let shape = [4];
+ let strides = [2];
+ let result = run_affine_strided(&input, &shape, &strides, mul, add);
+ // Stride of 2: take every other input element.
+ assert_eq!(result, vec![2.6, 5.6, 8.6, 11.6]);
+}
+
+#[test]
fn index_select() {
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [5, 2];
@@ -313,7 +408,26 @@ fn index_select() {
result,
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
);
+}
+#[test]
+fn index_select_f16() {
+ let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
+ .into_iter()
+ .map(|x| f16::from_f32(x))
+ .collect();
+ let shape = [5, 2];
+ let ids = [0u32, 4, 2];
+ let dim = 0;
+ let result = run_index_select(&embedding, &shape, &ids, dim);
+ assert_eq!(
+ approx_f16(result, 4),
+ vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
+ );
+}
+
+#[test]
+fn index_select_dim1() {
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [5, 2];
let ids = [0u32, 1, 0];
@@ -321,7 +435,7 @@ fn index_select() {
let result = run_index_select(&embedding, &shape, &ids, dim);
assert_eq!(
result,
- vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
+ vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0]
);
}
@@ -341,27 +455,34 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
let left_size: usize = shape[..dim].iter().product();
let right_size: usize = shape[dim + 1..].iter().product();
let dst_el = ids.len() * left_size * right_size;
- let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
+ let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
+
+ let name = match core::mem::size_of::<T>() {
+ 4 => "is_u32_f32",
+ 2 => "is_u32_f16",
+ _ => unimplemented!(),
+ };
- let kernels = Kernels::new();
+ let fence = device.new_fence();
+ let kernels = Kernels::new(fence);
call_index_select(
&device,
&command_buffer,
&kernels,
- "is_u32_f32",
+ name,
shape,
ids.len(),
dim,
&embeddings_buffer,
&ids_buffer,
- &mut dst_buffer,
+ &dst_buffer,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
- dst_buffer.read_to_vec::<T>(dst_el)
+ read_to_vec(&dst_buffer, dst_el)
}
#[test]
@@ -427,7 +548,7 @@ fn index_add() {
let expected = vec![
2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
];
- let result = outputs_buffer.read_to_vec::<f32>(right.len());
+ let result: Vec<f32> = read_to_vec(&outputs_buffer, right.len());
assert_eq!(result, expected);
}
@@ -439,43 +560,49 @@ fn cos_f16() {
.collect();
let results = run(&v, unary::contiguous::cos::HALF);
let expected: Vec<f16> = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect();
- assert_eq!(approx_f16(results, 4), vec![0.5405, -0.4163, -0.9902]);
- assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]);
+ assert_eq!(approx_f16(results, 2), vec![0.54, -0.42, -0.99]);
+ assert_eq!(approx_f16(expected, 2), vec![0.54, -0.42, -0.99]);
}
fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
let device = device();
- let kernels = Kernels::new();
+ let fence = device.new_fence();
+ let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let options = MTLResourceOptions::StorageModeManaged;
- let mut output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
- call_reduce_contiguous(
+ let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
+ let dims = vec![v.len()];
+ let strides = vec![1];
+ call_reduce_strided(
&device,
command_buffer,
&kernels,
name,
- v.len(),
+ &dims,
+ &strides,
out_length,
&input,
- &mut output,
+ 0,
+ &output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
- output.read_to_vec::<T>(out_length)
+ read_to_vec(&output, out_length)
}
fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'static str) -> Vec<T> {
let device = device();
- let kernels = Kernels::new();
+ let fence = device.new_fence();
+ let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
- let mut output = new_buffer(&device, v);
+ let output = new_buffer(&device, v);
call_last_softmax(
&device,
command_buffer,
@@ -484,13 +611,14 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'sta
v.len(),
last_dim,
&input,
- &mut output,
+ 0,
+ &output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
- output.read_to_vec::<T>(v.len())
+ read_to_vec(&output, v.len())
}
#[test]
@@ -498,7 +626,7 @@ fn reduce_sum() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let out_length = 1;
- let results = run_reduce(&v, out_length, "fast_sum_float");
+ let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
assert_eq!(approx(results, 4), vec![21.0]);
}
@@ -507,7 +635,7 @@ fn reduce_sum2() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let out_length = 2;
- let results = run_reduce(&v, out_length, "fast_sum_float");
+ let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
assert_eq!(approx(results, 4), vec![6.0, 15.0]);
}
@@ -515,15 +643,33 @@ fn reduce_sum2() {
fn softmax() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let last_dim = 6;
- let results = run_softmax(&v, last_dim, "softmax_float");
+ let results = run_softmax(&v, last_dim, "softmax_f32");
assert_eq!(
approx(results, 4),
vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
);
+ let last_dim = 4096;
+ let n = 200;
+ let mut v = vec![0.0; n * last_dim];
+ for i in 0..n {
+ v[i * last_dim] = 20.0;
+ }
+ let results = run_softmax(&v, last_dim, "softmax_f32");
+ let results = approx(results, 4);
+ println!("{results:?}");
+ assert_eq!(
+ results.iter().map(|&s| s.round() as usize).sum::<usize>(),
+ n
+ );
+ assert_eq!(results[0], 1.0);
+ assert_eq!(results[1], 0.0);
+ assert_eq!(results[last_dim], 1.0);
+ assert_eq!(results[2 * last_dim], 1.0);
+
let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
let last_dim = 6;
- let results = run_softmax(&v, last_dim, "softmax_float");
+ let results = run_softmax(&v, last_dim, "softmax_f32");
assert_eq!(
approx(results, 4),
vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
@@ -531,11 +677,33 @@ fn softmax() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let last_dim = 3;
- let results = run_softmax(&v, last_dim, "softmax_float");
+ let results = run_softmax(&v, last_dim, "softmax_f32");
assert_eq!(
approx(results, 4),
vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
);
+
+ let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
+ .iter()
+ .map(|v| f16::from_f32(*v))
+ .collect::<Vec<_>>();
+ let last_dim = 6;
+ let results = run_softmax(&v, last_dim, "softmax_f16");
+ assert_eq!(
+ approx_f16(results, 4),
+ vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338]
+ );
+
+ let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
+ .iter()
+ .map(|v| bf16::from_f32(*v))
+ .collect::<Vec<_>>();
+ let last_dim = 6;
+ let results = run_softmax(&v, last_dim, "softmax_bf16");
+ assert_eq!(
+ approx_bf16(results, 4),
+ vec![0.0043, 0.0116, 0.0315, 0.0859, 0.2324, 0.6328]
+ );
}
fn run_where_cond<I: Clone, T: Clone>(
@@ -549,7 +717,8 @@ fn run_where_cond<I: Clone, T: Clone>(
name: &'static str,
) -> Vec<T> {
let device = device();
- let kernels = Kernels::new();
+ let fence = device.new_fence();
+ let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
@@ -571,7 +740,7 @@ fn run_where_cond<I: Clone, T: Clone>(
options,
);
- let mut output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
+ let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
call_where_cond_strided(
&device,
command_buffer,
@@ -584,13 +753,13 @@ fn run_where_cond<I: Clone, T: Clone>(
(&left_stride, left_offset),
&right,
(&cond_stride, cond_offset),
- &mut output,
+ &output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
- output.read_to_vec::<T>(length)
+ read_to_vec(&output, length)
}
#[test]
@@ -614,3 +783,93 @@ fn where_cond() {
);
assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
}
+
+fn run_gemm<T: Clone>(
+ (b, m, n, k): (usize, usize, usize, usize),
+ lhs: &[T],
+ lhs_stride: Vec<usize>,
+ lhs_offset: usize,
+ rhs: &[T],
+ rhs_stride: Vec<usize>,
+ rhs_offset: usize,
+) -> Vec<T> {
+ let device = device();
+ let fence = device.new_fence();
+ let kernels = Kernels::new(fence);
+ let command_queue = device.new_command_queue();
+ let command_buffer = command_queue.new_command_buffer();
+ let options = MTLResourceOptions::StorageModeManaged;
+
+ let lhs = device.new_buffer_with_data(
+ lhs.as_ptr() as *const core::ffi::c_void,
+ std::mem::size_of_val(lhs) as u64,
+ options,
+ );
+ let rhs = device.new_buffer_with_data(
+ rhs.as_ptr() as *const core::ffi::c_void,
+ std::mem::size_of_val(rhs) as u64,
+ options,
+ );
+ let length = b * m * n;
+ let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
+ call_gemm(
+ &device,
+ command_buffer,
+ &kernels,
+ "sgemm",
+ (b, m, n, k),
+ &lhs_stride,
+ lhs_offset,
+ &lhs,
+ &rhs_stride,
+ rhs_offset,
+ &rhs,
+ &output,
+ )
+ .unwrap();
+ command_buffer.commit();
+ command_buffer.wait_until_completed();
+
+ read_to_vec(&output, length)
+}
+
+#[test]
+fn gemm() {
+ let (b, m, n, k) = (1, 2, 4, 3);
+ let lhs_stride = vec![m * k, k, 1];
+ let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
+ let rhs_stride = vec![n * k, n, 1];
+ let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
+ let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0);
+ assert_eq!(
+ approx(results, 4),
+ vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
+ );
+
+ let (b, m, n, k) = (2, 2, 4, 3);
+ let lhs_stride = vec![m * k, k, 1];
+ let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
+ let rhs_stride = vec![n * k, n, 1];
+ let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
+ let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0);
+ assert_eq!(
+ approx(results, 4),
+ vec![
+ 20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0, 488.0,
+ 518.0, 548.0, 578.0
+ ]
+ );
+
+ // OFFSET
+ let (b, m, n, k) = (2, 2, 4, 3);
+ let lhs_stride = vec![m * k, k, 1];
+ let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
+ let rhs_stride = vec![n * k, n, 1];
+ let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
+ // Manually set batch_size=1 and offset the rhs by 12 elements * 4 bytes (the size of an f32).
+ let results = run_gemm((1, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 12 * 4);
+ assert_eq!(
+ approx(results, 4),
+ vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
+ );
+}
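The last gemm test above drives the raw offset path: each rhs batch holds n * k = 12 f32 values, so offsetting by 12 * 4 bytes makes the single-batch call read the second rhs batch against the first lhs batch. Spelled out (illustrative, mirroring that test's shapes):

```rust
// Arithmetic behind the 12 * 4 byte offset used in the gemm offset test.
fn main() {
    let (n, k) = (4usize, 3usize);
    let elems_per_rhs_batch = n * k; // 12 f32 values per batch
    let rhs_offset_bytes = elems_per_rhs_batch * std::mem::size_of::<f32>(); // 48
    assert_eq!(rhs_offset_bytes, 12 * 4);
    // Passing 48 as rhs_offset makes the batch_size=1 call start reading the
    // rhs at its second batch while the lhs still starts at its first batch.
}
```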
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index eb6424e8..553bc506 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -1,4 +1,7 @@
#include <metal_stdlib>
+#include <metal_math>
+#
+using namespace metal;
METAL_FUNC uint get_strided_index(
uint idx,
@@ -17,10 +20,44 @@ METAL_FUNC uint get_strided_index(
template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
template <typename T> METAL_FUNC T neg(T in){ return -in; }
-template <typename T> METAL_FUNC T id(T in){ return in; }
+template <typename T> METAL_FUNC T erf(T in){
+ float x = (float) in;
+ // constants
+ float a1 = 0.254829592;
+ float a2 = -0.284496736;
+ float a3 = 1.421413741;
+ float a4 = -1.453152027;
+ float a5 = 1.061405429;
+ float p = 0.3275911;
+
+ // Save the sign of x
+ int sign = 1;
+ if (x < 0)
+ sign = -1;
+ x = fabs(x);
+
+ // A&S formula 7.1.26
+ float t = 1.0/(1.0 + p*x);
+ float y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x);
+
+ return T(sign*y);
+}
+template <typename T> METAL_FUNC T id(T in) { return in; }
+template <typename T> METAL_FUNC T gelu_erf(T x) {
+ return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2);
+}
+template <typename T> METAL_FUNC T gelu(T x) {
+ if (x > 5) {
+ return x;
+ }
+ T x_sq = x * x;
+ T x_cube = x_sq * x;
+ T alpha = x + static_cast<T>(0.044715) * x_cube;
+ T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
+ return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
+}
-using namespace metal;
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
@@ -32,7 +69,7 @@ kernel void FN_NAME( \
if (thread_position_in_grid >= dim) { \
return; \
} \
- output[thread_position_in_grid] = TYPENAME(FN(input[thread_position_in_grid])); \
+ output[thread_position_in_grid] = TYPENAME(FN(float(input[thread_position_in_grid]))); \
}\
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
@@ -46,15 +83,15 @@ kernel void FN_NAME_STRIDED( \
if (thread_position_in_grid >= dim) { \
return; \
} \
- output[thread_position_in_grid] = TYPENAME(FN(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)])); \
+ output[thread_position_in_grid] = TYPENAME(FN(float(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)]))); \
}
#define UNARY_OP(NAME) \
-UNARY(NAME, float, NAME##_float, NAME##_float_strided); \
-UNARY(NAME, half, NAME##_half, NAME##_half_strided);
+UNARY(NAME, float, NAME##_f32, NAME##_f32_strided); \
+UNARY(NAME, half, NAME##_f16, NAME##_f16_strided);
#define BFLOAT_UNARY_OP(NAME) \
-UNARY(NAME, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
+UNARY(NAME, bfloat, NAME##_bf16, NAME##_bf16_strided);
UNARY_OP(cos)
@@ -64,8 +101,17 @@ UNARY_OP(sqrt)
UNARY_OP(neg)
UNARY_OP(exp)
UNARY_OP(log)
-UNARY(id, float, copy_float, copy_float_strided)
-UNARY(id, half, copy_half, copy_half_strided)
+UNARY_OP(gelu)
+UNARY_OP(ceil)
+UNARY_OP(floor)
+UNARY_OP(round)
+UNARY_OP(gelu_erf)
+UNARY_OP(erf)
+UNARY_OP(tanh)
+UNARY(id, float, copy_f32, copy_f32_strided)
+UNARY(id, half, copy_f16, copy_f16_strided)
+UNARY(id, uint8_t, copy_u8, copy_u8_strided)
+UNARY(id, uint32_t, copy_u32, copy_u32_strided)
#if __METAL_VERSION__ >= 310
BFLOAT_UNARY_OP(cos)
@@ -75,6 +121,13 @@ BFLOAT_UNARY_OP(sqrt)
BFLOAT_UNARY_OP(neg)
BFLOAT_UNARY_OP(exp)
BFLOAT_UNARY_OP(log)
+BFLOAT_UNARY_OP(gelu)
+BFLOAT_UNARY_OP(ceil)
+BFLOAT_UNARY_OP(floor)
+BFLOAT_UNARY_OP(round)
+BFLOAT_UNARY_OP(gelu_erf)
+BFLOAT_UNARY_OP(erf)
+BFLOAT_UNARY_OP(tanh)
-UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
+UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
#endif
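For reference, the new gelu kernel uses the usual tanh approximation, 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), with a passthrough for large x, and erf follows the Abramowitz & Stegun 7.1.26 polynomial noted in the code. A plain-Rust transcription of the gelu approximation (illustrative; constants mirror the Metal code above):

```rust
// CPU transcription of the gelu kernel's tanh approximation, including the
// passthrough for large x used in the Metal code.
fn gelu_tanh(x: f32) -> f32 {
    if x > 5.0 {
        return x;
    }
    let x_cube = x * x * x;
    // M_2_SQRTPI * M_SQRT1_2 == sqrt(2 / pi)
    let beta = (2.0 / std::f32::consts::PI).sqrt() * (x + 0.044715 * x_cube);
    0.5 * x * (1.0 + beta.tanh())
}

fn main() {
    // Values should land close to the expected outputs of the gelu_f32 test above.
    for x in [-1.0f32, 0.0, 1.0, 2.0, 3.0] {
        println!("gelu({x}) = {:.3}", gelu_tanh(x));
    }
}
```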
diff --git a/candle-metal-kernels/examples/affine.rs b/candle-metal-kernels/tmp/affine.rs
index b8005dc0..cd019056 100644
--- a/candle-metal-kernels/examples/affine.rs
+++ b/candle-metal-kernels/tmp/affine.rs
@@ -50,6 +50,7 @@ fn run_affine_bench<T: Clone>(device: &Device, kernels: &Kernels, v: &[T]) {
&device,
command_buffer,
&kernels,
+ "affine_float",
v.len(),
&input,
&mut output,
diff --git a/candle-metal-kernels/examples/binary.rs b/candle-metal-kernels/tmp/binary.rs
index af5a8bdc..af5a8bdc 100644
--- a/candle-metal-kernels/examples/binary.rs
+++ b/candle-metal-kernels/tmp/binary.rs
diff --git a/candle-metal-kernels/examples/cast.rs b/candle-metal-kernels/tmp/cast.rs
index 090f510d..090f510d 100644
--- a/candle-metal-kernels/examples/cast.rs
+++ b/candle-metal-kernels/tmp/cast.rs
diff --git a/candle-metal-kernels/examples/unary.rs b/candle-metal-kernels/tmp/unary.rs
index 7039c098..66cf25c0 100644
--- a/candle-metal-kernels/examples/unary.rs
+++ b/candle-metal-kernels/tmp/unary.rs
@@ -147,7 +147,7 @@ fn run_unary_bench<T: Clone>(
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
- kernel_name.to_string(),
+ kernel_name.0,
v.len(),
iterations,
total_time,
@@ -159,7 +159,7 @@ fn run_unary_bench<T: Clone>(
let shape = vec![2, 5_000];
let strides = vec![2, 1];
let offset = 0;
- for kernel_name in strided {
+ for kernel_name in &strided {
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
@@ -187,7 +187,7 @@ fn run_unary_bench<T: Clone>(
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
- kernel_name.to_string(),
+ kernel_name.0,
v.len(),
iterations,
total_time,
diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml
index ffbe0ca1..e0daabef 100644
--- a/candle-nn/Cargo.toml
+++ b/candle-nn/Cargo.toml
@@ -19,6 +19,8 @@ num-traits = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
+metal = { workspace = true, optional = true }
+candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
[dev-dependencies]
anyhow = { workspace = true }
@@ -29,3 +31,4 @@ default = []
accelerate = ["dep:accelerate-src", "candle/accelerate"]
cuda = ["candle/cuda"]
mkl = ["dep:intel-mkl-src", "candle/mkl"]
+metal = ["candle/metal", "dep:candle-metal-kernels", "dep:metal"]
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index a0269e59..abe33350 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -201,6 +201,47 @@ impl candle::CustomOp1 for SoftmaxLastDim {
};
Ok((dst, layout.shape().clone()))
}
+
+ #[cfg(feature = "metal")]
+ fn metal_fwd(
+ &self,
+ storage: &candle::MetalStorage,
+ layout: &Layout,
+ ) -> Result<(candle::MetalStorage, Shape)> {
+ use candle::{backend::BackendStorage, DType};
+ let device = storage.device();
+ let command_buffer = device.command_buffer()?;
+ let kernels = device.kernels();
+ let name = match storage.dtype() {
+ DType::F32 => "softmax_f32",
+ DType::F16 => "softmax_f16",
+ DType::BF16 => "softmax_bf16",
+ dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"),
+ };
+
+ let n = layout.stride().len();
+ if !(layout.is_contiguous() && layout.stride()[n - 1] == 1) {
+ candle::bail!("Non contiguous softmax-last-dim is not implemented");
+ }
+
+ let last_dim = layout.dims()[layout.shape().rank() - 1];
+ let elem_count = layout.shape().elem_count();
+ let output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
+ candle_metal_kernels::call_last_softmax(
+ device.metal_device(),
+ &command_buffer,
+ kernels,
+ name,
+ elem_count,
+ last_dim,
+ storage.buffer(),
+ layout.start_offset() * storage.dtype().size_in_bytes(),
+ &output,
+ )
+ .unwrap();
+ let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
+ Ok((newstorage, layout.shape().clone()))
+ }
}
pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
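With the metal feature wired into candle-nn, softmax_last_dim should now route through the metal_fwd implementation above when the tensor lives on a Metal device. A hedged usage sketch (crate paths assume the published crate names candle-core / candle-nn; Device::new_metal is the constructor added elsewhere in this patch):

```rust
// Hedged usage sketch, not part of the patch: with the `metal` features
// enabled, softmax_last_dim dispatches to the Metal kernel added above.
use candle_core::{Device, Tensor};
use candle_nn::ops::softmax_last_dim;

fn main() -> candle_core::Result<()> {
    let device = Device::new_metal(0)?;
    let xs = Tensor::rand(0f32, 1.0, (2, 4096), &device)?;
    let probs = softmax_last_dim(&xs)?;
    println!("{:?}", probs.dims());
    Ok(())
}
```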
diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml
index 37e5fa65..000702f9 100644
--- a/candle-transformers/Cargo.toml
+++ b/candle-transformers/Cargo.toml
@@ -31,3 +31,4 @@ accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda"]
flash-attn = ["cuda", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"]
+metal = ["candle/metal", "candle-nn/metal"]