author    Nicolas Patry <patry.nicolas@protonmail.com>  2023-11-11 01:02:15 +0100
committer Nicolas Patry <patry.nicolas@protonmail.com>  2023-11-30 11:30:31 +0100
commit    4349ff1fc29a1a25b2ccdf56fbf68a98f5364c0a (patch)
tree      78a6b3533670a33f7bc2f75851fac24307a46fed
parent    7c3cfd1086ecdc08a0b350f30f1fbedf2f00c269 (diff)
Starting to fix some tests.

Few fixes.
Going back on remote metal-rs.
Reusing a single buffer (for now) to speed things up.
Adding some half kernels.
All tests are panicking instead of random failure.
Putting back f16 index select.
Add erf.
Working version for llama2-c.
Fixes + cache compute_pipeline_state.
BF16 metal fix.
Remove some prints.
new_owned -> new()..to_owned().
Better batched matmul.
Metal operational.
Reuse buffers on our own reference counts.
Tmp gemm.
Revert "Tmp gemm."
This reverts commit c65f68e98814b65daa596696bda076a73303dd82.
Interleave committing.
Speeding up copies using blit.
Fmt.
Fmt.
Remove the assert!
Fmt all.
Fixes after big rebase.
Add softmax for half and bfloat + tests
Fixing Llama example + accumulate softmax in float.
 candle-core/src/metal_backend.rs                    | 702
 candle-examples/Cargo.toml                          |   1
 candle-metal-kernels/src/affine.metal               |  18
 candle-metal-kernels/src/cast.metal                 |  18
 candle-metal-kernels/src/indexing.metal             |   9
 candle-metal-kernels/src/lib.rs                     | 303
 candle-metal-kernels/src/reduce.metal               | 156
 candle-metal-kernels/src/ternary.metal              |   3
 candle-metal-kernels/src/tests.rs                   | 158
 candle-metal-kernels/src/unary.metal                |  48
 candle-metal-kernels/tmp/affine.rs (renamed from candle-metal-kernels/examples/affine.rs) | 1
 candle-metal-kernels/tmp/binary.rs (renamed from candle-metal-kernels/examples/binary.rs) | 0
 candle-metal-kernels/tmp/cast.rs (renamed from candle-metal-kernels/examples/cast.rs)     | 0
 candle-metal-kernels/tmp/unary.rs (renamed from candle-metal-kernels/examples/unary.rs)   | 6
 candle-nn/Cargo.toml                                |   2
 candle-nn/src/ops.rs                                |  40
 16 files changed, 988 insertions(+), 477 deletions(-)
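
The headline change in metal_backend.rs is the buffer cache on MetalDevice: allocations are keyed by (size, resource options), and a slot counts as free again once only the pool itself still references it, tracked through Arc strong counts. A minimal sketch of the idea with plain types (illustrative names, not the crate's API; like the real pool, this never frees buffers, it only reuses them):

    use std::collections::HashMap;
    use std::sync::{Arc, RwLock};

    // Stand-in for a GPU buffer allocation.
    struct Buffer;

    #[derive(Default)]
    struct BufferPool {
        // One bucket of reusable buffers per size.
        buckets: RwLock<HashMap<usize, Vec<Arc<Buffer>>>>,
    }

    impl BufferPool {
        fn get(&self, size: usize) -> Arc<Buffer> {
            let mut buckets = self.buckets.write().unwrap();
            let bucket = buckets.entry(size).or_insert_with(Vec::new);
            // strong_count == 1 means only the pool still holds this buffer,
            // so no tensor is using it anymore and it is safe to hand out again.
            for buf in bucket.iter() {
                if Arc::strong_count(buf) == 1 {
                    return buf.clone();
                }
            }
            let buf = Arc::new(Buffer);
            bucket.push(buf.clone());
            buf
        }
    }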
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 0b72f080..12f56d50 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -4,11 +4,13 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels;
use candle_metal_kernels::Kernels;
-use core::mem;
-use half::{bf16, f16};
+use half::f16;
use metal;
-use metal::{Buffer, CommandQueue, MTLResourceOptions, NSUInteger};
-use std::sync::Arc;
+use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
+use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
+use std::collections::HashMap;
+use std::path::Path;
+use std::sync::{Arc, RwLock};
/// Metal related errors
#[derive(thiserror::Error, Debug)]
@@ -36,7 +38,9 @@ impl From<String> for MetalError {
pub struct MetalDevice {
device: metal::Device,
command_queue: metal::CommandQueue,
+ command_buffer: Arc<RwLock<metal::CommandBuffer>>,
kernels: Arc<candle_metal_kernels::Kernels>,
+ buffers: Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>,
}
impl std::fmt::Debug for MetalDevice {
@@ -58,10 +62,48 @@ impl MetalDevice {
self.registry_id()
}
+ pub fn metal_device(&self) -> &metal::Device {
+ &self.device
+ }
+
pub fn command_queue(&self) -> &CommandQueue {
&self.command_queue
}
+ pub fn command_buffer(&self) -> std::sync::RwLockReadGuard<CommandBuffer> {
+ self.command_buffer.try_read().unwrap()
+ }
+
+ pub fn commit(&self) {
+ let mut old = self.command_buffer.try_write().unwrap();
+ match old.status() {
+ metal::MTLCommandBufferStatus::NotEnqueued
+ | metal::MTLCommandBufferStatus::Enqueued => {
+ old.commit();
+ let command_buffer = self.command_queue.new_command_buffer().to_owned();
+ *old = command_buffer;
+ }
+ _ => {}
+ }
+ }
+
+ pub fn wait_until_completed(&self) {
+ let mut old = self.command_buffer.try_write().unwrap();
+ match old.status() {
+ metal::MTLCommandBufferStatus::NotEnqueued
+ | metal::MTLCommandBufferStatus::Enqueued => {
+ old.commit();
+ old.wait_until_completed();
+ }
+ metal::MTLCommandBufferStatus::Committed | metal::MTLCommandBufferStatus::Scheduled => {
+ old.wait_until_completed();
+ }
+ _ => {}
+ }
+ let command_buffer = self.command_queue.new_command_buffer().to_owned();
+ *old = command_buffer;
+ }
+
pub fn kernels(&self) -> &Kernels {
&self.kernels
}
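
Together, commit() and wait_until_completed() implement the "interleave committing" scheme from the commit message: every op encodes into one long-lived command buffer instead of creating and synchronously waiting on its own. A hedged usage sketch with the methods above:

    // Ops encode kernels into device.command_buffer() as they run, then:
    device.commit();                // flush: commit the current buffer, swap in a fresh one
    device.wait_until_completed();  // block until the GPU is done (needed before CPU readback)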
@@ -70,16 +112,107 @@ impl MetalDevice {
&self.device
}
- pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
+ pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Arc<Buffer> {
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
- self.device
- .new_buffer(size, MTLResourceOptions::StorageModeManaged)
+ self._new_buffer(size, MTLResourceOptions::StorageModePrivate)
+ }
+
+ fn _new_buffer(&self, size: NSUInteger, option: MTLResourceOptions) -> Arc<Buffer> {
+ let mut buffers = self.buffers.try_write().unwrap();
+ let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
+
+ for sub in &mut *subbuffers {
+ if Arc::strong_count(sub) == 1 {
+ return sub.clone();
+ }
+ }
+ let new_buffer = self.device.new_buffer(size as NSUInteger, option);
+ let new_buffer = Arc::new(new_buffer);
+ subbuffers.push(new_buffer.clone());
+ new_buffer
+ }
+
+ pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> {
+ self._new_buffer(size, MTLResourceOptions::StorageModeManaged)
+ }
+
+ pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> {
+ let size = core::mem::size_of_val(data) as NSUInteger;
+ let tmp = self.device.new_buffer_with_data(
+ data.as_ptr() as *const core::ffi::c_void,
+ size,
+ metal::MTLResourceOptions::StorageModeManaged,
+ );
+ let real = self._new_buffer(size, metal::MTLResourceOptions::StorageModePrivate);
+ {
+ let command = self.command_buffer();
+ let blit = command.new_blit_command_encoder();
+ blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
+ blit.end_encoding();
+ }
+ // This is necessary for mmapped safetensors
+ // because of the unsafe slice cast we're doing:
+ // the slice might not live long enough for Metal
+ // to actually fill the GPU buffer.
+ // Putting this wait forces the GPU buffer to be filled
+ // with the actual data, allowing the CPU storage to
+ // deallocate properly.
+ self.wait_until_completed();
+ real
+ }
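
The final wait here is load-bearing, as the comment says: new_buffer_with_data stages the slice through a StorageModeManaged buffer and blit-copies it into a StorageModePrivate one, so the slice must outlive that copy. An illustrative hazard (mmap_bytes is a hypothetical helper standing in for an mmapped safetensors slice, not crate API):

    // HYPOTHETICAL: mmap_bytes() stands in for an mmapped safetensors slice.
    let data: &[f32] = mmap_bytes();
    let buf = device.new_buffer_with_data(data); // schedules a GPU-side blit from `data`
    // Without wait_until_completed() inside new_buffer_with_data, the mapping
    // behind `data` could be dropped while the blit is still reading from it.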
+
+ pub fn new_matrix(
+ &self,
+ (b, m, n): (NSUInteger, NSUInteger, NSUInteger),
+ size: NSUInteger,
+ type_id: u32,
+ dtype: DType,
+ ) -> Result<(Matrix, Arc<Buffer>)> {
+ let elem_count = (b * m * n) as usize;
+ let out_buffer = self.new_buffer(elem_count, dtype);
+
+ let result_descriptor =
+ MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id);
+ let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, 0, &result_descriptor)
+ .ok_or_else(|| {
+ MetalError::from("Failed to create matrix multiplication kernel".to_string())
+ })?;
+ Ok((result_matrix, out_buffer))
+ }
+
+ pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> {
+ let capture = metal::CaptureManager::shared();
+ let descriptor = metal::CaptureDescriptor::new();
+ descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
+ descriptor.set_capture_device(&self);
+ descriptor.set_output_url(path);
+
+ capture
+ .start_capture(&descriptor)
+ .map_err(MetalError::from)?;
+ Ok(())
}
}
#[derive(Debug, Clone)]
pub struct MetalStorage {
- buffer: metal::Buffer,
+ buffer: Arc<metal::Buffer>,
+ matrices: Arc<
+ RwLock<
+ HashMap<
+ (
+ NSUInteger,
+ NSUInteger,
+ NSUInteger,
+ bool,
+ NSUInteger,
+ NSUInteger,
+ u32,
+ ),
+ Matrix,
+ >,
+ >,
+ >,
device: MetalDevice,
dtype: DType,
}
@@ -108,14 +241,23 @@ impl BackendStorage for MetalStorage {
self.dtype
);
}
+
+ let buffer = self.device.new_buffer_managed(self.buffer.length());
+ let command_buffer = self.device.command_buffer();
+ let blit = command_buffer.new_blit_command_encoder();
+ blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
+ blit.end_encoding();
+ drop(command_buffer);
+ self.device.wait_until_completed();
+
match self.dtype {
- DType::U8 => Ok(CpuStorage::U8(self.buffer.read_to_vec(length / size))),
- DType::U32 => Ok(CpuStorage::U32(self.buffer.read_to_vec(length / size))),
- DType::I64 => Ok(CpuStorage::I64(self.buffer.read_to_vec(length / size))),
- DType::F16 => Ok(CpuStorage::F16(self.buffer.read_to_vec(length / size))),
- DType::BF16 => Ok(CpuStorage::BF16(self.buffer.read_to_vec(length / size))),
- DType::F32 => Ok(CpuStorage::F32(self.buffer.read_to_vec(length / size))),
- DType::F64 => Ok(CpuStorage::F64(self.buffer.read_to_vec(length / size))),
+ DType::U8 => Ok(CpuStorage::U8(buffer.read_to_vec(length / size))),
+ DType::U32 => Ok(CpuStorage::U32(buffer.read_to_vec(length / size))),
+ DType::I64 => Ok(CpuStorage::I64(buffer.read_to_vec(length / size))),
+ DType::F16 => Ok(CpuStorage::F16(buffer.read_to_vec(length / size))),
+ DType::BF16 => Ok(CpuStorage::BF16(buffer.read_to_vec(length / size))),
+ DType::F32 => Ok(CpuStorage::F32(buffer.read_to_vec(length / size))),
+ DType::F64 => Ok(CpuStorage::F64(buffer.read_to_vec(length / size))),
}
}
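
The new blit in to_cpu follows from the switch to StorageModePrivate above: private buffers are GPU-only and cannot be read from the CPU, so the contents are first blit-copied into a temporary StorageModeManaged buffer, the device waits for completion, and only then is read_to_vec called on the managed copy.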
@@ -126,30 +268,48 @@ impl BackendStorage for MetalStorage {
let el = shape.elem_count();
let dtype = self.dtype;
- if layout.is_contiguous() || layout.start_offset() != 0 || dtype != DType::F32 {
- crate::bail!("Not contiguous, non-f32 affine is not implemented yet.");
+ let buffer = device.new_buffer(el, self.dtype);
+ let command_buffer = self.device.command_buffer();
+ if layout.is_contiguous() && layout.start_offset() == 0 {
+ let name = match self.dtype {
+ DType::F32 => "affine_float",
+ DType::F16 => "affine_half",
+ dtype => crate::bail!("Affine {dtype:?}"),
+ };
+ candle_metal_kernels::call_affine(
+ &device.device,
+ &command_buffer,
+ &device.kernels,
+ name,
+ el,
+ &self.buffer,
+ &buffer,
+ mul as f32,
+ add as f32,
+ )
+ .map_err(MetalError::from)?;
+ } else {
+ let name = match self.dtype {
+ DType::F32 => "affine_float_strided",
+ DType::F16 => "affine_half_strided",
+ dtype => crate::bail!("Affine {dtype:?}"),
+ };
+ candle_metal_kernels::call_affine_strided(
+ &device.device,
+ &command_buffer,
+ &device.kernels,
+ name,
+ layout.dims(),
+ &self.buffer,
+ layout.stride(),
+ layout.start_offset() * dtype.size_in_bytes(),
+ &buffer,
+ mul as f32,
+ add as f32,
+ )
+ .map_err(MetalError::from)?;
}
-
- let mut buffer = device.new_buffer(el, self.dtype);
- let command_buffer = self.device.command_queue.new_command_buffer();
- candle_metal_kernels::call_affine(
- &device.device,
- &command_buffer,
- &device.kernels,
- el,
- &self.buffer,
- &mut buffer,
- mul as f32,
- add as f32,
- )
- .map_err(MetalError::from)?;
- command_buffer.commit();
- command_buffer.wait_until_completed();
- return Ok(Self {
- buffer,
- device: device.clone(),
- dtype,
- });
+ Ok(Self::new(buffer, device.clone(), dtype))
}
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
@@ -163,11 +323,11 @@ impl BackendStorage for MetalStorage {
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
if !(sum_dims.len() == 1
&& sum_dims[0] == layout.shape().rank() - 1
- && layout.is_contiguous()
- && layout.start_offset() == 0)
+ && layout.stride()[sum_dims[0]] == 1)
{
- crate::bail!("Non contiguous reduce op not supported yet");
+ crate::bail!("Non last dim reduce op not supported yet");
}
+
let device = self.device.clone();
let src_stride = layout.stride();
let src_dims = layout.shape().dims();
@@ -202,8 +362,11 @@ impl BackendStorage for MetalStorage {
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
}
let dtype = if return_index { DType::U32 } else { self.dtype };
- let mut buffer = device.new_buffer(dst_el, dtype);
- let command_buffer = self.device.command_queue.new_command_buffer();
+ if dtype == DType::U32 {
+ crate::bail!("Implement return index reduce op");
+ }
+ let buffer = device.new_buffer(dst_el, dtype);
+ let command_buffer = self.device.command_buffer();
candle_metal_kernels::call_reduce_contiguous(
&device.device,
&command_buffer,
@@ -212,17 +375,12 @@ impl BackendStorage for MetalStorage {
src_el,
dst_el,
&self.buffer,
- &mut buffer,
+ layout.start_offset() * self.dtype.size_in_bytes(),
+ &buffer,
)
.map_err(MetalError::from)?;
- command_buffer.commit();
- command_buffer.wait_until_completed();
- Ok(Self {
- buffer,
- device,
- dtype,
- })
+ Ok(Self::new(buffer, device, dtype))
}
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
@@ -233,11 +391,15 @@ impl BackendStorage for MetalStorage {
let device = self.device();
let shape = layout.shape();
let el_count = shape.elem_count();
- let mut buffer = device.new_buffer(el_count, dtype);
- let command_buffer = device.command_queue.new_command_buffer();
+ let buffer = device.new_buffer(el_count, dtype);
+ let command_buffer = device.command_buffer();
if layout.is_contiguous() {
let kernel_name = match (self.dtype, dtype) {
(DType::U32, DType::F32) => "cast_u32_f32",
+ (DType::U32, DType::U8) => "cast_u32_u8",
+ (DType::U8, DType::U32) => "cast_u8_u32",
+ (DType::F32, DType::F16) => "cast_f32_f16",
+ (DType::F16, DType::F32) => "cast_f16_f32",
(left, right) => crate::bail!("to dtype {left:?} - {right:?}"),
};
candle_metal_kernels::call_cast_contiguous(
@@ -247,24 +409,34 @@ impl BackendStorage for MetalStorage {
kernel_name,
el_count,
&self.buffer,
- &mut buffer,
+ layout.start_offset() * self.dtype.size_in_bytes(),
+ &buffer,
)
.map_err(MetalError::from)?;
} else {
- crate::bail!(
- "TODO Implement the kernel calling cast {:?}-{:?}",
- self.dtype,
- dtype
- );
+ let kernel_name = match (self.dtype, dtype) {
+ (DType::U32, DType::F32) => "cast_u32_f32_strided",
+ (DType::U32, DType::U8) => "cast_u32_u8_strided",
+ (DType::U8, DType::U32) => "cast_u8_u32_strided",
+ (DType::F32, DType::F16) => "cast_f32_f16_strided",
+ (DType::F16, DType::F32) => "cast_f16_f32_strided",
+ (left, right) => crate::bail!("to dtype {left:?} - {right:?}"),
+ };
+ candle_metal_kernels::call_cast_strided(
+ &device.device,
+ &command_buffer,
+ &device.kernels,
+ kernel_name,
+ layout.dims(),
+ &self.buffer,
+ layout.stride(),
+ layout.start_offset() * self.dtype.size_in_bytes(),
+ &buffer,
+ )
+ .map_err(MetalError::from)?;
}
- command_buffer.commit();
- command_buffer.wait_until_completed();
- Ok(Self {
- buffer,
- device: device.clone(),
- dtype,
- })
+ Ok(Self::new(buffer, device.clone(), dtype))
}
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
@@ -272,8 +444,8 @@ impl BackendStorage for MetalStorage {
let dtype = self.dtype;
let shape = layout.shape();
let el_count = shape.elem_count();
- let mut buffer = device.new_buffer(el_count, dtype);
- let command_buffer = device.command_queue.new_command_buffer();
+ let buffer = device.new_buffer(el_count, dtype);
+ let command_buffer = device.command_buffer();
if layout.is_contiguous() && layout.start_offset() == 0 {
use candle_metal_kernels::unary::contiguous;
@@ -285,6 +457,25 @@ impl BackendStorage for MetalStorage {
("uneg", DType::F32) => contiguous::neg::FLOAT,
("uexp", DType::F32) => contiguous::exp::FLOAT,
("ulog", DType::F32) => contiguous::log::FLOAT,
+ ("ugelu", DType::F32) => contiguous::gelu::FLOAT,
+ ("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
+ ("uerf", DType::F32) => contiguous::erf::FLOAT,
+ ("uceil", DType::F32) => contiguous::ceil::FLOAT,
+ ("ufloor", DType::F32) => contiguous::floor::FLOAT,
+ ("uround", DType::F32) => contiguous::round::FLOAT,
+ ("ucos", DType::F16) => contiguous::cos::HALF,
+ ("usin", DType::F16) => contiguous::sin::HALF,
+ ("usqr", DType::F16) => contiguous::sqr::HALF,
+ ("usqrt", DType::F16) => contiguous::sqrt::HALF,
+ ("uneg", DType::F16) => contiguous::neg::HALF,
+ ("uexp", DType::F16) => contiguous::exp::HALF,
+ ("ulog", DType::F16) => contiguous::log::HALF,
+ ("ugelu", DType::F16) => contiguous::gelu::HALF,
+ ("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
+ ("uerf", DType::F16) => contiguous::erf::HALF,
+ ("uceil", DType::F16) => contiguous::ceil::HALF,
+ ("ufloor", DType::F16) => contiguous::floor::HALF,
+ ("uround", DType::F16) => contiguous::round::HALF,
(name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_unary_contiguous(
@@ -294,20 +485,58 @@ impl BackendStorage for MetalStorage {
kernel_name,
el_count,
&self.buffer,
- &mut buffer,
+ &buffer,
)
.map_err(MetalError::from)?;
} else {
- crate::bail!("TODO Implement the kernel calling {}", B::KERNEL);
+ use candle_metal_kernels::unary::strided;
+ let kernel_name = match (B::KERNEL, dtype) {
+ ("ucos", DType::F32) => strided::cos::FLOAT,
+ ("usin", DType::F32) => strided::sin::FLOAT,
+ ("usqr", DType::F32) => strided::sqr::FLOAT,
+ ("usqrt", DType::F32) => strided::sqrt::FLOAT,
+ ("uneg", DType::F32) => strided::neg::FLOAT,
+ ("uexp", DType::F32) => strided::exp::FLOAT,
+ ("ulog", DType::F32) => strided::log::FLOAT,
+ ("ugelu", DType::F32) => strided::gelu::FLOAT,
+ ("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
+ ("uerf", DType::F32) => strided::erf::FLOAT,
+ ("uceil", DType::F32) => strided::ceil::FLOAT,
+ ("ufloor", DType::F32) => strided::floor::FLOAT,
+ ("uround", DType::F32) => strided::round::FLOAT,
+ ("ucos", DType::F16) => strided::cos::HALF,
+ ("usin", DType::F16) => strided::sin::HALF,
+ ("usqr", DType::F16) => strided::sqr::HALF,
+ ("usqrt", DType::F16) => strided::sqrt::HALF,
+ ("uneg", DType::F16) => strided::neg::HALF,
+ ("uexp", DType::F16) => strided::exp::HALF,
+ ("ulog", DType::F16) => strided::log::HALF,
+ ("ugelu", DType::F16) => strided::gelu::HALF,
+ ("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
+ ("uerf", DType::F16) => strided::erf::HALF,
+ ("uceil", DType::F16) => strided::ceil::HALF,
+ ("ufloor", DType::F16) => strided::floor::HALF,
+ ("uround", DType::F16) => strided::round::HALF,
+ (name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
+ };
+ candle_metal_kernels::call_unary_strided(
+ &device.device,
+ &command_buffer,
+ &device.kernels,
+ kernel_name,
+ layout.dims(),
+ &self.buffer,
+ layout.stride(),
+ layout.start_offset() * self.dtype.size_in_bytes(),
+ &buffer,
+ 0,
+ )
+ .map_err(MetalError::from)?;
}
- command_buffer.commit();
- command_buffer.wait_until_completed();
-
- Ok(Self {
- buffer,
- device: device.clone(),
- dtype,
- })
+ command_buffer.set_label("unary");
+ drop(command_buffer);
+ self.device.commit();
+ Ok(Self::new(buffer, device.clone(), dtype))
}
fn binary_impl<B: BinaryOpT>(
@@ -320,8 +549,8 @@ impl BackendStorage for MetalStorage {
let dtype = self.dtype;
let shape = lhs_l.shape();
let el_count = shape.elem_count();
- let mut buffer = device.new_buffer(el_count, dtype);
- let command_buffer = device.command_queue.new_command_buffer();
+ let buffer = device.new_buffer(el_count, dtype);
+ let command_buffer = device.command_buffer();
if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
&& (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
{
@@ -336,6 +565,14 @@ impl BackendStorage for MetalStorage {
("bmul", DType::F32) => contiguous::mul::FLOAT,
("div", DType::F32) => contiguous::div::FLOAT,
("bdiv", DType::F32) => contiguous::div::FLOAT,
+ ("add", DType::F16) => contiguous::add::HALF,
+ ("badd", DType::F16) => contiguous::add::HALF,
+ ("sub", DType::F16) => contiguous::sub::HALF,
+ ("bsub", DType::F16) => contiguous::sub::HALF,
+ ("mul", DType::F16) => contiguous::mul::HALF,
+ ("bmul", DType::F16) => contiguous::mul::HALF,
+ ("div", DType::F16) => contiguous::div::HALF,
+ ("bdiv", DType::F16) => contiguous::div::HALF,
(name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_binary_contiguous(
@@ -346,7 +583,7 @@ impl BackendStorage for MetalStorage {
el_count,
&self.buffer,
&rhs.buffer,
- &mut buffer,
+ &buffer,
)
.map_err(MetalError::from)?;
} else {
@@ -357,6 +594,10 @@ impl BackendStorage for MetalStorage {
("bsub", DType::F32) => strided::sub::FLOAT,
("bmul", DType::F32) => strided::mul::FLOAT,
("bdiv", DType::F32) => strided::div::FLOAT,
+ ("badd", DType::F16) => strided::add::HALF,
+ ("bsub", DType::F16) => strided::sub::HALF,
+ ("bmul", DType::F16) => strided::mul::HALF,
+ ("bdiv", DType::F16) => strided::div::HALF,
(name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_binary_strided(
@@ -366,23 +607,19 @@ impl BackendStorage for MetalStorage {
kernel_name,
lhs_l.dims(),
&self.buffer,
- &lhs_l.stride(),
+ lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&rhs.buffer,
- &rhs_l.stride(),
+ rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
- &mut buffer,
+ &buffer,
)
.map_err(MetalError::from)?;
}
- command_buffer.commit();
- command_buffer.wait_until_completed();
-
- Ok(Self {
- buffer,
- device: device.clone(),
- dtype,
- })
+ command_buffer.set_label("binary");
+ drop(command_buffer);
+ self.device.commit();
+ Ok(Self::new(buffer, device.clone(), dtype))
}
fn where_cond(
@@ -398,14 +635,22 @@ impl BackendStorage for MetalStorage {
let dims = shape.dims();
let el = shape.elem_count();
let dtype = t.dtype;
- let mut buffer = self.device.new_buffer(el, dtype);
- let command_buffer = self.device.command_queue.new_command_buffer();
+ let buffer = self.device.new_buffer(el, dtype);
+ let command_buffer = self.device.command_buffer();
+ if t.dtype() != f.dtype() {
+ crate::bail!("Invalid ternary different dtypes for values");
+ }
+ let name = match (self.dtype, t.dtype()) {
+ (DType::U8, DType::F32) => "where_u8_f32",
+ (DType::U8, DType::F16) => "where_u8_f16",
+ (left, right) => crate::bail!("Ternary {left:?} - {right:?} not implemented"),
+ };
candle_metal_kernels::call_where_cond_strided(
&device.device,
&command_buffer,
&device.kernels,
- "where_u8_f32",
- &dims,
+ name,
+ dims,
&self.buffer,
(
layout.stride(),
@@ -415,16 +660,10 @@ impl BackendStorage for MetalStorage {
(&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
&f.buffer,
(&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
- &mut buffer,
+ &buffer,
)
.map_err(MetalError::from)?;
- command_buffer.commit();
- command_buffer.wait_until_completed();
- Ok(Self {
- buffer,
- device,
- dtype,
- })
+ Ok(Self::new(buffer, device, dtype))
}
fn conv1d(
@@ -513,12 +752,13 @@ impl BackendStorage for MetalStorage {
let dst_el = ids_el * left_size * right_size;
let dtype = self.dtype;
let device = self.device();
- let mut buffer = device.new_buffer(dst_el, dtype);
+ let buffer = device.new_buffer(dst_el, dtype);
let name = match (ids.dtype, self.dtype) {
(DType::U32, DType::F32) => "is_u32_f32",
+ (DType::U32, DType::F16) => "is_u32_f16",
(left, right) => crate::bail!("index select metal {left:?} {right:?}"),
};
- let command_buffer = self.device.command_queue.new_command_buffer();
+ let command_buffer = self.device.command_buffer();
candle_metal_kernels::call_index_select(
&device.device,
&command_buffer,
@@ -529,16 +769,10 @@ impl BackendStorage for MetalStorage {
dim,
&self.buffer,
&ids.buffer,
- &mut buffer,
+ &buffer,
)
.map_err(MetalError::from)?;
- command_buffer.commit();
- command_buffer.wait_until_completed();
- Ok(Self {
- buffer,
- device: device.clone(),
- dtype,
- })
+ Ok(Self::new(buffer, device.clone(), dtype))
}
fn index_add(
@@ -561,11 +795,18 @@ impl BackendStorage for MetalStorage {
rhs_l: &Layout,
) -> Result<Self> {
// Create descriptors
- use metal::mps::matrix::*;
- let type_id = metal::mps::MPS_FLOATBIT_ENCODING | 32;
- let size = core::mem::size_of::<f32>() as NSUInteger;
- let elem_count = b * m * n;
+ let (type_id, size) = match self.dtype {
+ DType::F32 => (
+ metal::mps::MPS_FLOATBIT_ENCODING | 32,
+ core::mem::size_of::<f32>() as NSUInteger,
+ ),
+ DType::F16 => (
+ metal::mps::MPS_FLOATBIT_ENCODING | 16,
+ core::mem::size_of::<f16>() as NSUInteger,
+ ),
+ dtype => todo!("Dtype for matmul {dtype:?} is not supported"),
+ };
let lhs_stride = lhs_l.stride();
let rhs_stride = rhs_l.stride();
@@ -596,39 +837,30 @@ impl BackendStorage for MetalStorage {
mnk: (m, n, k),
})?
};
-
let b = b as NSUInteger;
let m = m as NSUInteger;
let n = n as NSUInteger;
let k = k as NSUInteger;
- let left_descriptor = if transpose_left {
- MatrixDescriptor::init_single(k, m, m * size, type_id)
- } else {
- MatrixDescriptor::init_single(m, k, k * size, type_id)
- };
- let right_descriptor = if transpose_right {
- MatrixDescriptor::init_single(n, k, k * size, type_id)
- } else {
- MatrixDescriptor::init_single(k, n, n * size, type_id)
- };
- let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
-
- // Create matrix objects
- let left_matrix = Matrix::init_with_buffer_descriptor(&self.buffer, 0, &left_descriptor)
- .ok_or_else(|| {
- MetalError::from("Failed to create matrix multiplication kernel".to_string())
- })?;
- let right_matrix = Matrix::init_with_buffer_descriptor(&rhs.buffer, 0, &right_descriptor)
- .ok_or_else(|| {
- MetalError::from("Failed to create matrix multiplication kernel".to_string())
- })?;
+ let left_matrix = self.matrix(
+ (b, m, k),
+ transpose_left,
+ size,
+ lhs_l.start_offset() as NSUInteger * size,
+ type_id,
+ )?;
+ let right_matrix = rhs.matrix(
+ (b, k, n),
+ transpose_right,
+ size,
+ rhs_l.start_offset() as NSUInteger * size,
+ type_id,
+ )?;
+ let (result_matrix, out_buffer) =
+ self.device
+ .new_matrix((b, m, n), size, type_id, self.dtype)?;
- let out_buffer = self.device.new_buffer(elem_count, self.dtype);
- let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, 0, &result_descriptor)
- .ok_or_else(|| {
- MetalError::from("Failed to create matrix multiplication kernel".to_string())
- })?;
+ let command_buffer = self.device.command_buffer();
let alpha = 1.0f64;
let beta = 0.0f64;
@@ -647,70 +879,112 @@ impl BackendStorage for MetalStorage {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
- matrix_multiplication.set_batch_size(b);
-
// Encode kernel to command buffer
- let command_buffer = self.device.command_queue.new_command_buffer();
matrix_multiplication.encode_to_command_buffer(
- command_buffer,
+ &command_buffer,
&left_matrix,
&right_matrix,
&result_matrix,
);
- command_buffer.commit();
- command_buffer.wait_until_completed();
+ command_buffer.set_label("matmul");
+ drop(command_buffer);
+ self.device.commit();
- Ok(Self {
- buffer: out_buffer,
- device: self.device.clone(),
- dtype: self.dtype(),
- })
+ Ok(Self::new(out_buffer, self.device.clone(), self.dtype()))
}
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
- let src_shape = src_l.shape();
- let el_count = src_shape.elem_count();
- if el_count == 0 {
- return Ok(());
+ let command_buffer = self.device.command_buffer();
+ if src_l.is_contiguous() && self.dtype == dst.dtype() {
+ command_buffer.set_label("copy_contiguous");
+ let blit = command_buffer.new_blit_command_encoder();
+ let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger;
+ let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger;
+ blit.copy_from_buffer(
+ &self.buffer,
+ src_offset,
+ dst.buffer(),
+ dst_offset,
+ self.buffer.length() - src_offset,
+ );
+ blit.end_encoding();
+ } else {
+ let src_shape = src_l.shape();
+ let el_count = src_shape.elem_count();
+ if el_count == 0 {
+ return Ok(());
+ }
+ let kernel_name = match self.dtype {
+ DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
+ DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
+ DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
+ DType::U32 => candle_metal_kernels::unary::strided::copy::U32,
+ DType::U8 => candle_metal_kernels::unary::strided::copy::U8,
+ dtype => crate::bail!("copy_strided not implemented for {dtype:?}"),
+ };
+ candle_metal_kernels::call_unary_strided(
+ &self.device.device,
+ &command_buffer,
+ &self.device.kernels,
+ kernel_name,
+ src_l.dims(),
+ &self.buffer,
+ src_l.stride(),
+ src_l.start_offset() * self.dtype.size_in_bytes(),
+ &dst.buffer,
+ dst_offset * dst.dtype.size_in_bytes(),
+ )
+ .map_err(MetalError::from)?;
+ command_buffer.set_label("copy_strided");
}
- let command_buffer = self.device.command_queue.new_command_buffer();
- let kernel_name = match self.dtype {
- DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
- DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
- DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
- dtype => crate::bail!("copy_strided not implemented for {dtype:?}"),
- };
- candle_metal_kernels::call_unary_strided(
- &self.device.device,
- &command_buffer,
- &self.device.kernels,
- kernel_name,
- src_l.dims(),
- &self.buffer,
- &src_l.stride(),
- src_l.start_offset() * self.dtype.size_in_bytes(),
- &mut dst.buffer,
- dst_offset,
- )
- .map_err(MetalError::from)?;
- command_buffer.commit();
- command_buffer.wait_until_completed();
+ drop(command_buffer);
+ self.device.commit();
Ok(())
}
}
impl MetalStorage {
- pub fn new(buffer: Buffer, device: MetalDevice, dtype: DType) -> Self {
+ pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: DType) -> Self {
+ let matrices = Arc::new(RwLock::new(HashMap::new()));
Self {
buffer,
device,
dtype,
+ matrices,
}
}
pub fn buffer(&self) -> &Buffer {
&self.buffer
}
+
+ fn matrix(
+ &self,
+ (b, m, n): (NSUInteger, NSUInteger, NSUInteger),
+ transpose: bool,
+ size: NSUInteger,
+ offset: NSUInteger,
+ type_id: u32,
+ ) -> Result<Matrix> {
+ let key = (b, m, n, transpose, size, offset, type_id);
+
+ let mut matrices = self.matrices.try_write().unwrap();
+ if let Some(matrix) = matrices.get(&key) {
+ Ok(matrix.clone())
+ } else {
+ let descriptor = if transpose {
+ MatrixDescriptor::init_multiple(n, m, b, m * size, m * n * size, type_id)
+ } else {
+ MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id)
+ };
+ let matrix = Matrix::init_with_buffer_descriptor(&self.buffer, offset, &descriptor)
+ .ok_or_else(|| {
+ MetalError::from("Failed to create matrix multiplication kernel".to_string())
+ })?;
+ matrices.insert(key, matrix.clone());
+ Ok(matrix)
+ }
+ }
}
impl BackendDevice for MetalDevice {
@@ -720,10 +994,14 @@ impl BackendDevice for MetalDevice {
let device = metal::Device::all().swap_remove(ordinal);
let command_queue = device.new_command_queue();
+ let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned()));
let kernels = Arc::new(Kernels::new());
+ let buffers = Arc::new(RwLock::new(HashMap::new()));
Ok(Self {
device,
command_queue,
+ command_buffer,
+ buffers,
kernels,
})
}
@@ -743,9 +1021,8 @@ impl BackendDevice for MetalDevice {
}
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
- // TODO Is there a faster way ?
- let cpu_storage = crate::cpu_backend::CpuDevice.zeros_impl(shape, dtype)?;
- self.storage_from_cpu_storage(&cpu_storage)
+ let buffer = self.new_buffer(shape.elem_count(), dtype);
+ Ok(MetalStorage::new(buffer, self.clone(), dtype))
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
@@ -755,49 +1032,20 @@ impl BackendDevice for MetalDevice {
}
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
- let option = metal::MTLResourceOptions::StorageModeManaged;
let buffer = match storage {
- CpuStorage::U8(storage) => self.device.new_buffer_with_data(
- storage.as_ptr() as *const core::ffi::c_void,
- (storage.len() * mem::size_of::<u8>()) as NSUInteger,
- option,
- ),
- CpuStorage::U32(storage) => self.device.new_buffer_with_data(
- storage.as_ptr() as *const core::ffi::c_void,
- (storage.len() * mem::size_of::<u32>()) as NSUInteger,
- option,
- ),
- CpuStorage::I64(storage) => self.device.new_buffer_with_data(
- storage.as_ptr() as *const core::ffi::c_void,
- (storage.len() * mem::size_of::<i64>()) as NSUInteger,
- option,
- ),
- CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
- storage.as_ptr() as *const core::ffi::c_void,
- (storage.len() * mem::size_of::<bf16>()) as NSUInteger,
- option,
- ),
- CpuStorage::F16(storage) => self.device.new_buffer_with_data(
- storage.as_ptr() as *const core::ffi::c_void,
- (storage.len() * mem::size_of::<f16>()) as NSUInteger,
- option,
- ),
- CpuStorage::F32(storage) => self.device.new_buffer_with_data(
- storage.as_ptr() as *const core::ffi::c_void,
- (storage.len() * mem::size_of::<f32>()) as NSUInteger,
- option,
- ),
- CpuStorage::F64(storage) => self.device.new_buffer_with_data(
- storage.as_ptr() as *const core::ffi::c_void,
- (storage.len() * mem::size_of::<f64>()) as NSUInteger,
- option,
- ),
+ CpuStorage::U8(storage) => self.new_buffer_with_data(storage),
+ CpuStorage::U32(storage) => self.new_buffer_with_data(storage),
+ CpuStorage::I64(storage) => self.new_buffer_with_data(storage),
+ CpuStorage::BF16(storage) => self.new_buffer_with_data(storage),
+ CpuStorage::F16(storage) => self.new_buffer_with_data(storage),
+ CpuStorage::F32(storage) => self.new_buffer_with_data(storage),
+ CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
};
- Ok(Self::Storage {
- buffer,
- device: self.clone(),
- dtype: storage.dtype(),
- })
+ Ok(Self::Storage::new(
+ buffer.into(),
+ self.clone(),
+ storage.dtype(),
+ ))
}
fn rand_uniform(
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index 38d26ead..adfa529e 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -57,6 +57,7 @@ flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"]
+metal = ["candle/metal", "candle-nn/metal"]
[[example]]
name = "llama_multiprocess"
diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal
index e5f0a841..a08bfbc0 100644
--- a/candle-metal-kernels/src/affine.metal
+++ b/candle-metal-kernels/src/affine.metal
@@ -33,6 +33,24 @@ kernel void FN_NAME( \
const TYPENAME a = TYPENAME(add); \
output[id] = input[id] * m + a; \
} \
+kernel void FN_NAME##_strided( \
+ constant size_t &dim, \
+ constant size_t &num_dims, \
+ constant size_t *dims, \
+ constant size_t *strides, \
+ constant float &mul, \
+ constant float &add, \
+ device const TYPENAME *input, \
+ device TYPENAME *output, \
+ uint id [[ thread_position_in_grid ]] \
+) { \
+ if (id >= dim) { \
+ return; \
+ } \
+ const TYPENAME m = TYPENAME(mul); \
+ const TYPENAME a = TYPENAME(add); \
+ output[id] = input[get_strided_index(id, num_dims, dims, strides)] * m + a; \
+} \
AFFINE(affine_float, float)
AFFINE(affine_half, half)
diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal
index d1788253..4398e9d4 100644
--- a/candle-metal-kernels/src/cast.metal
+++ b/candle-metal-kernels/src/cast.metal
@@ -23,12 +23,12 @@ kernel void FN_NAME( \
constant size_t &dim, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
- uint thread_position_in_grid [[ thread_position_in_grid ]] \
+ uint tid [[ thread_position_in_grid ]] \
) { \
- if (thread_position_in_grid >= dim) { \
+ if (tid >= dim) { \
return; \
} \
- output[thread_position_in_grid] = RIGHT_TYPENAME(input[thread_position_in_grid]); \
+ output[tid] = RIGHT_TYPENAME(input[tid]); \
} \
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
@@ -37,15 +37,19 @@ kernel void FN_NAME_STRIDED( \
constant size_t *strides, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
- uint i [[ thread_position_in_grid ]] \
+ uint tid [[ thread_position_in_grid ]] \
) { \
- if (i >= dim) { \
+ if (tid >= dim) { \
return; \
} \
- output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
+ output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \
} \
-CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float)
+CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
+CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
+CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
+CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
+CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
#if __METAL_VERSION__ >= 310
#endif
diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal
index 444fa322..312b27c7 100644
--- a/candle-metal-kernels/src/indexing.metal
+++ b/candle-metal-kernels/src/indexing.metal
@@ -16,16 +16,16 @@ kernel void NAME( \
if (gid >= dst_size) { \
return; \
} \
- const size_t id_i = gid / right_size / left_size; \
+ const size_t id_i = (gid / right_size) % ids_size; \
+ const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
const size_t right_rank_i = gid % right_size; \
- const size_t left_rank_i = gid % left_size; \
+ const size_t left_rank_i = gid / right_size / ids_size; \
/* \
// Force prevent out of bounds indexing \
// since there doesn't seem to be a good way to force crash \
// No need to check for zero we're only allowing unsized. \
*/ \
- const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
- const size_t src_i = ((input_i * right_size) + right_rank_i) * left_size + left_rank_i; \
+ const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; \
output[gid] = input[src_i]; \
}
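
The corrected decomposition treats the destination as a [left_size, ids_size, right_size] row-major block; the previous version derived id_i and left_rank_i from the wrong divisors. A serial Rust reference of the same arithmetic (a sketch to check the math against, not crate code):

    fn index_select_ref(src: &[f32], ids: &[u32], left_size: usize,
                        src_dim_size: usize, right_size: usize) -> Vec<f32> {
        let ids_size = ids.len();
        let mut dst = vec![0.0; left_size * ids_size * right_size];
        for (gid, out) in dst.iter_mut().enumerate() {
            let right_rank_i = gid % right_size;
            let id_i = (gid / right_size) % ids_size;
            let left_rank_i = gid / right_size / ids_size;
            // Clamped like the kernel, to force in-bounds reads.
            let input_i = (ids[id_i] as usize).min(src_dim_size - 1);
            *out = src[left_rank_i * src_dim_size * right_size
                + input_i * right_size
                + right_rank_i];
        }
        dst
    }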
@@ -75,6 +75,7 @@ kernel void FN_NAME( \
INDEX_OP(is_u32_f32, uint, float)
+INDEX_OP(is_u32_f16, uint, half)
#if __METAL_VERSION__ >= 310
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 5a6bd41b..a0b852a4 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1,6 +1,6 @@
use metal::{
- Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor,
- ComputePipelineState, Device, Function, Library, MTLSize,
+ Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
+ Device, Function, Library, MTLSize,
};
use std::collections::HashMap;
use std::ffi::c_void;
@@ -59,8 +59,8 @@ impl<T> EncoderParam for &[T] {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_bytes(
position,
- (core::mem::size_of::<T>() * data.len()) as u64,
- data.as_ptr() as *const T as *const c_void,
+ core::mem::size_of_val(data) as u64,
+ data.as_ptr() as *const c_void,
);
}
}
@@ -111,13 +111,7 @@ macro_rules! ops{
($($name:ident),+) => {
pub mod contiguous {
- #[derive(Clone, Copy)]
- pub struct Kernel(pub(crate) &'static str);
- impl std::fmt::Display for Kernel {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "{}", self.0)
- }
- }
+ pub struct Kernel(pub &'static str);
$(
pub mod $name {
use super::Kernel;
@@ -126,16 +120,18 @@ macro_rules! ops{
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat"));
}
)+
+ pub mod copy {
+ use super::Kernel;
+ pub const FLOAT: Kernel = Kernel("copy_float");
+ pub const HALF: Kernel = Kernel("copy_half");
+ pub const BFLOAT: Kernel = Kernel("copy_bfloat");
+ pub const U32: Kernel = Kernel("copy_u32");
+ pub const U8: Kernel = Kernel("copy_u8");
+ }
}
pub mod strided {
- #[derive(Clone, Copy)]
- pub struct Kernel(pub(crate) &'static str);
- impl std::fmt::Display for Kernel {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "{}", self.0)
- }
- }
+ pub struct Kernel(pub &'static str);
$(
pub mod $name {
use super::Kernel;
@@ -144,12 +140,20 @@ macro_rules! ops{
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided"));
}
)+
+ pub mod copy {
+ use super::Kernel;
+ pub const FLOAT: Kernel = Kernel("copy_float_strided");
+ pub const HALF: Kernel = Kernel("copy_half_strided");
+ pub const BFLOAT: Kernel = Kernel("copy_bfloat_strided");
+ pub const U32: Kernel = Kernel("copy_u32_strided");
+ pub const U8: Kernel = Kernel("copy_u8_strided");
+ }
}
};
}
pub mod unary {
- ops!(cos, sin, exp, sqr, sqrt, neg, copy, log);
+ ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf);
}
pub mod binary {
ops!(add, sub, mul, div);
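
For reference, ops! expands every listed name into FLOAT/HALF/BFLOAT kernel-name constants; copy moves out of the list into an explicit submodule because it additionally needs U32 and U8 variants. The values follow directly from the macro above:

    use candle_metal_kernels::unary;
    assert_eq!(unary::contiguous::gelu::FLOAT.0, "gelu_float");
    assert_eq!(unary::strided::gelu::HALF.0, "gelu_half_strided");
    assert_eq!(unary::strided::copy::BFLOAT.0, "copy_bfloat_strided");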
@@ -161,8 +165,12 @@ pub enum MetalKernelError {
LockError(String),
#[error("Error while loading library: {0}")]
LoadLibraryError(String),
- #[error("Error while loading function: {0}")]
+ #[error("Error while loading function: {0:?}")]
LoadFunctionError(String),
+ #[error("Failed to create compute function")]
+ FailedToCreateComputeFunction,
+ #[error("Failed to create pipeline")]
+ FailedToCreatePipeline(String),
}
impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
@@ -173,19 +181,22 @@ impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
type KernelMap<T> = HashMap<&'static str, T>;
type Libraries = HashMap<Source, Library>;
-type Functions = KernelMap<Function>;
+type Pipelines = KernelMap<ComputePipelineState>;
#[derive(Debug, Default)]
pub struct Kernels {
libraries: RwLock<Libraries>,
- funcs: RwLock<Functions>,
+ pipelines: RwLock<Pipelines>,
}
impl Kernels {
pub fn new() -> Self {
let libraries = RwLock::new(Libraries::new());
- let funcs = RwLock::new(Functions::new());
- Self { libraries, funcs }
+ let pipelines = RwLock::new(Pipelines::new());
+ Self {
+ libraries,
+ pipelines,
+ }
}
fn get_library_source(&self, source: Source) -> &'static str {
@@ -218,22 +229,43 @@ impl Kernels {
}
}
- pub fn load_function(
+ fn load_function(
&self,
device: &Device,
source: Source,
name: &'static str,
) -> Result<Function, MetalKernelError> {
- let mut funcs = self.funcs.write()?;
- if let Some(func) = funcs.get(name) {
- Ok(func.clone())
+ let func = self
+ .load_library(device, source)?
+ .get_function(name, None)
+ .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
+ Ok(func)
+ // let mut funcs = self.funcs.write()?;
+ // if let Some(func) = funcs.get(name) {
+ // Ok(func.clone())
+ // } else {
+ // funcs.insert(name, func.clone());
+ // Ok(func)
+ // }
+ }
+
+ pub fn load_pipeline(
+ &self,
+ device: &Device,
+ source: Source,
+ name: &'static str,
+ ) -> Result<ComputePipelineState, MetalKernelError> {
+ let mut pipelines = self.pipelines.write()?;
+ if let Some(pipeline) = pipelines.get(name) {
+ Ok(pipeline.clone())
} else {
- let func = self
- .load_library(device, source)?
- .get_function(name, None)
- .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
- funcs.insert(name, func.clone());
- Ok(func)
+ let func = self.load_function(device, source, name)?;
+ let pipeline = device
+ .new_compute_pipeline_state_with_function(&func)
+ .map_err(|e| MetalKernelError::FailedToCreatePipeline(e.to_string()))?;
+ pipelines.insert(name, pipeline.clone());
+
+ Ok(pipeline)
}
}
}
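
This is the "cache compute_pipeline_state" part of the commit: function lookup is now uncached (the old funcs map is left commented out), while the expensive new_compute_pipeline_state_with_function result is memoized per kernel name. A hedged usage sketch:

    // The first call compiles the pipeline; later calls for "cos_float" hit the cache.
    let pipeline = kernels.load_pipeline(&device, Source::Unary, "cos_float")?;

One caveat visible in the types: Pipelines is keyed by the kernel name alone, which works as long as names stay unique across all Sources.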
@@ -246,18 +278,9 @@ pub fn call_unary_contiguous(
kernel_name: unary::contiguous::Kernel,
length: usize,
input: &Buffer,
- output: &mut Buffer,
+ output: &Buffer,
) -> Result<(), MetalKernelError> {
- let func = kernels.load_function(device, Source::Unary, kernel_name.0)?;
- let pipeline_state_descriptor = ComputePipelineDescriptor::new();
- pipeline_state_descriptor.set_compute_function(Some(&func));
-
- let pipeline = device
- .new_compute_pipeline_state_with_function(
- pipeline_state_descriptor.compute_function().unwrap(),
- )
- .unwrap();
-
+ let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
@@ -279,18 +302,10 @@ pub fn call_unary_strided(
input: &Buffer,
strides: &[usize],
offset: usize,
- output: &mut Buffer,
+ output: &Buffer,
output_offset: usize,
) -> Result<(), MetalKernelError> {
- let func = kernels.load_function(device, Source::Unary, name.0)?;
- let pipeline_state_descriptor = ComputePipelineDescriptor::new();
- pipeline_state_descriptor.set_compute_function(Some(&func));
-
- let pipeline = device
- .new_compute_pipeline_state_with_function(
- pipeline_state_descriptor.compute_function().unwrap(),
- )
- .unwrap();
+ let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder();
@@ -326,17 +341,9 @@ pub fn call_binary_contiguous(
length: usize,
left: &Buffer,
right: &Buffer,
- output: &mut Buffer,
+ output: &Buffer,
) -> Result<(), MetalKernelError> {
- let func = kernels.load_function(device, Source::Binary, kernel_name.0)?;
- let pipeline_state_descriptor = ComputePipelineDescriptor::new();
- pipeline_state_descriptor.set_compute_function(Some(&func));
-
- let pipeline = device
- .new_compute_pipeline_state_with_function(
- pipeline_state_descriptor.compute_function().unwrap(),
- )
- .unwrap();
+ let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
@@ -363,17 +370,9 @@ pub fn call_binary_strided(
right_input: &Buffer,
right_strides: &[usize],
right_offset: usize,
- output: &mut Buffer,
+ output: &Buffer,
) -> Result<(), MetalKernelError> {
- let func = kernels.load_function(device, Source::Binary, name.0)?;
- let pipeline_state_descriptor = ComputePipelineDescriptor::new();
- pipeline_state_descriptor.set_compute_function(Some(&func));
-
- let pipeline = device
- .new_compute_pipeline_state_with_function(
- pipeline_state_descriptor.compute_function().unwrap(),
- )
- .unwrap();
+ let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder();
@@ -411,22 +410,52 @@ pub fn call_cast_contiguous(
kernel_name: &'static str,
length: usize,
input: &Buffer,
- output: &mut Buffer,
+ input_offset: usize,
+ output: &Buffer,
) -> Result<(), MetalKernelError> {
- let func = kernels.load_function(device, Source::Cast, kernel_name)?;
- let pipeline_state_descriptor = ComputePipelineDescriptor::new();
- pipeline_state_descriptor.set_compute_function(Some(&func));
+ let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
- let pipeline = device
- .new_compute_pipeline_state_with_function(
- pipeline_state_descriptor.compute_function().unwrap(),
- )
- .unwrap();
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(encoder, (length, (input, input_offset), output));
+
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.end_encoding();
+ Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
+pub fn call_cast_strided(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ kernel_name: &'static str,
+ shape: &[usize],
+ input: &Buffer,
+ input_strides: &[usize],
+ input_offset: usize,
+ output: &Buffer,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
- set_params!(encoder, (length, input, output));
+ let length: usize = shape.iter().product();
+
+ set_params!(
+ encoder,
+ (
+ length,
+ shape.len(),
+ shape,
+ input_strides,
+ (input, input_offset),
+ output
+ )
+ );
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
@@ -435,7 +464,6 @@ pub fn call_cast_contiguous(
Ok(())
}
-#[allow(clippy::too_many_arguments)]
pub fn call_reduce_contiguous(
device: &Device,
command_buffer: &CommandBufferRef,
@@ -444,24 +472,19 @@ pub fn call_reduce_contiguous(
length: usize,
out_length: usize,
input: &Buffer,
- output: &mut Buffer,
+ input_offset: usize,
+ output: &Buffer,
) -> Result<(), MetalKernelError> {
- let func = kernels.load_function(device, Source::Reduce, kernel_name)?;
- let pipeline_state_descriptor = ComputePipelineDescriptor::new();
- pipeline_state_descriptor.set_compute_function(Some(&func));
-
- let pipeline = device
- .new_compute_pipeline_state_with_function(
- pipeline_state_descriptor.compute_function().unwrap(),
- )
- .unwrap();
-
+ let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let elements_to_sum = length / out_length;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
- set_params!(encoder, (length, elements_to_sum, input, output));
+ set_params!(
+ encoder,
+ (length, elements_to_sum, (input, input_offset), output)
+ );
let thread_group_count = MTLSize {
width: out_length as u64,
@@ -495,18 +518,9 @@ pub fn call_last_softmax(
length: usize,
elements_to_sum: usize,
input: &Buffer,
- output: &mut Buffer,
+ output: &Buffer,
) -> Result<(), MetalKernelError> {
- let func = kernels.load_function(device, Source::Reduce, kernel_name)?;
- let pipeline_state_descriptor = ComputePipelineDescriptor::new();
- pipeline_state_descriptor.set_compute_function(Some(&func));
-
- let pipeline = device
- .new_compute_pipeline_state_with_function(
- pipeline_state_descriptor.compute_function().unwrap(),
- )
- .unwrap();
-
+ let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
@@ -542,21 +556,14 @@ pub fn call_affine(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
+ name: &'static str,
size: usize,
input: &Buffer,
- output: &mut Buffer,
+ output: &Buffer,
mul: f32,
add: f32,
) -> Result<(), MetalKernelError> {
- let func = kernels.load_function(device, Source::Affine, "affine_float")?;
- let pipeline_state_descriptor = ComputePipelineDescriptor::new();
- pipeline_state_descriptor.set_compute_function(Some(&func));
-
- let pipeline = device
- .new_compute_pipeline_state_with_function(
- pipeline_state_descriptor.compute_function().unwrap(),
- )
- .unwrap();
+ let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
@@ -570,6 +577,45 @@ pub fn call_affine(
}
#[allow(clippy::too_many_arguments)]
+pub fn call_affine_strided(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: &'static str,
+ shape: &[usize],
+ input: &Buffer,
+ input_stride: &[usize],
+ input_offset: usize,
+ output: &Buffer,
+ mul: f32,
+ add: f32,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
+ let size: usize = shape.iter().product();
+
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(
+ encoder,
+ (
+ size,
+ shape.len(),
+ shape,
+ input_stride,
+ mul,
+ add,
+ (input, input_offset),
+ output
+ )
+ );
+
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.end_encoding();
+ Ok(())
+}
+
pub fn call_where_cond_strided(
device: &Device,
command_buffer: &CommandBufferRef,
@@ -582,17 +628,9 @@ pub fn call_where_cond_strided(
(left_stride, left_offset): (&[usize], usize),
right: &Buffer,
(right_stride, right_offset): (&[usize], usize),
- output: &mut Buffer,
+ output: &Buffer,
) -> Result<(), MetalKernelError> {
- let func = kernels.load_function(device, Source::Ternary, name)?;
- let pipeline_state_descriptor = ComputePipelineDescriptor::new();
- pipeline_state_descriptor.set_compute_function(Some(&func));
-
- let pipeline = device
- .new_compute_pipeline_state_with_function(
- pipeline_state_descriptor.compute_function().unwrap(),
- )
- .unwrap();
+ let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
@@ -634,17 +672,14 @@ pub fn call_index_select(
dim: usize,
input: &Buffer,
ids: &Buffer,
- output: &mut Buffer,
+ output: &Buffer,
) -> Result<(), MetalKernelError> {
let left_size: usize = shape[..dim].iter().product();
let right_size: usize = shape[dim + 1..].iter().product();
let src_dim_size = shape[dim];
let dst_el = ids_size * left_size * right_size;
- let func = kernels.load_function(device, Source::Indexing, name)?;
- let pipeline = device
- .new_compute_pipeline_state_with_function(&func)
- .unwrap();
+ let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
let encoder = command_buffer.new_compute_command_encoder();
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
index c6984474..867877fb 100644
--- a/candle-metal-kernels/src/reduce.metal
+++ b/candle-metal-kernels/src/reduce.metal
@@ -1,6 +1,8 @@
#include <metal_stdlib>
using namespace metal;
+#define MAX(x, y) ((x) > (y) ? (x) : (y))
+
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
@@ -16,18 +18,18 @@ METAL_FUNC uint get_strided_index(
return strided_i;
}
-constant int THREADGROUP_SIZE = 256;
+constant int THREADGROUP_SIZE = 1024;
-# define REDUCE(FN, NAME, TYPENAME) \
+# define REDUCE(FN, NAME, T) \
kernel void NAME( \
constant size_t &src_numel, \
constant size_t &el_to_sum_per_block, \
- device const TYPENAME *src, \
- device TYPENAME *dst, \
+ device const T *src, \
+ device T *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
- uint blockDim [[ threads_per_threadgroup ]] \
+ uint block_dim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup float shared_memory[THREADGROUP_SIZE]; \
@@ -45,10 +47,10 @@ kernel void NAME( \
// TODO: Fast version for the contiguous case. \
// size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
*/ \
- TYPENAME x = shared_memory[tid]; \
- TYPENAME y = src[idx]; \
+ T x = shared_memory[tid]; \
+ T y = src[idx]; \
shared_memory[tid] = FN; \
- idx += blockDim; \
+ idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
@@ -56,10 +58,10 @@ kernel void NAME( \
/* \
// reduction in shared memory \
*/ \
- for (uint s = blockDim / 2; s > 0; s >>= 1) { \
+ for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
- TYPENAME x = shared_memory[tid]; \
- TYPENAME y = shared_memory[tid + s]; \
+ T x = shared_memory[tid]; \
+ T y = shared_memory[tid + s]; \
shared_memory[tid] = FN; \
} \
threadgroup_barrier(mem_flags::mem_none); \
@@ -68,72 +70,74 @@ kernel void NAME( \
dst[dst_id] = shared_memory[0]; \
} \
-kernel void softmax_float(
- constant size_t &src_numel,
- constant size_t &el_to_sum_per_block,
- device const float *src,
- device float *dst,
- uint id [[ thread_position_in_grid ]],
- uint tid [[ thread_index_in_threadgroup ]],
- uint dst_id [[ threadgroup_position_in_grid ]],
- uint blockDim [[ threads_per_threadgroup ]]
-) {
-
- threadgroup float shared_memory[THREADGROUP_SIZE];
-
- shared_memory[tid] = -INFINITY;
- // Elements summed in this block range from dst_id * el_to_sum_per_block
- // to (dst_id + 1) * el_to_sum_per_block.
- size_t start_idx = dst_id * el_to_sum_per_block;
- size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
- size_t idx = start_idx + tid;
-
- while (idx < stop_idx) {
- // TODO: Fast version for the contiguous case.
- shared_memory[tid] = max(shared_memory[tid], src[idx]);
- idx += blockDim;
- }
-
- threadgroup_barrier(mem_flags::mem_none);
-
- // reduction in shared memory
- for (uint s = blockDim / 2; s > 0; s >>= 1) {
- if (tid < s) {
- shared_memory[tid] = max(shared_memory[tid], shared_memory[tid + s]);
- }
- threadgroup_barrier(mem_flags::mem_none);
- }
-
- float max = shared_memory[0];
-
- shared_memory[tid] = 0;
-
- // Restart
- idx = start_idx + tid;
- while (idx < stop_idx) {
- // TODO: Fast version for the contiguous case.
- const float val = exp(src[idx] - max);
- dst[idx] = val;
- shared_memory[tid] += val;
- idx += blockDim;
- }
- // reduction in shared memory
- for (uint s = blockDim / 2; s > 0; s >>= 1) {
- if (tid < s) {
- shared_memory[tid] += shared_memory[tid + s];
- }
- threadgroup_barrier(mem_flags::mem_none);
- }
-
- const float inv_acc = 1/shared_memory[0];
- idx = start_idx + tid;
- while (idx < stop_idx) {
- dst[idx] *= inv_acc;
- idx += blockDim;
- }
-}
-
REDUCE(x + y, fast_sum_float, float)
REDUCE(x * y, fast_mul_float, float)
REDUCE(max(x, y), fast_max_float, float)
+
+#define SOFTMAX(NAME, T) \
+kernel void NAME( \
+ constant size_t &src_numel, \
+ constant size_t &el_to_sum_per_block, \
+ device const T *src, \
+ device T *dst, \
+ \
+ uint id [[ thread_position_in_grid ]], \
+ uint tid [[ thread_index_in_threadgroup ]], \
+ uint dst_id [[ threadgroup_position_in_grid ]], \
+ uint block_dim [[ threads_per_threadgroup ]] \
+) { \
+ threadgroup float shared_memory[THREADGROUP_SIZE]; \
+ shared_memory[tid] = -INFINITY; \
+ size_t start_idx = dst_id * el_to_sum_per_block; \
+ size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
+ size_t idx = start_idx + tid; \
+ \
+ threadgroup_barrier(mem_flags::mem_threadgroup); \
+ \
+ while (idx < stop_idx) { \
+ shared_memory[tid] = MAX(shared_memory[tid], src[idx]); \
+ idx += block_dim; \
+ } \
+ \
+ threadgroup_barrier(mem_flags::mem_threadgroup); \
+ \
+    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
+        if (tid < s) { \
+            shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
+        } \
+        /* barrier on every step so later iterations see prior combines */ \
+        threadgroup_barrier(mem_flags::mem_threadgroup); \
+    } \
+ \
+    float _max = shared_memory[0]; \
+    \
+    /* make sure every thread has read _max before slot 0 is reset */ \
+    threadgroup_barrier(mem_flags::mem_threadgroup); \
+    shared_memory[tid] = 0; \
+ \
+ idx = start_idx + tid; \
+ while (idx < stop_idx) { \
+ const T val = T(exp(src[idx] - _max)); \
+ dst[idx] = val; \
+ shared_memory[tid] += val; \
+ idx += block_dim; \
+    } \
+    \
+    /* all partial sums must be visible before the tree reduction */ \
+    threadgroup_barrier(mem_flags::mem_threadgroup); \
+    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
+ if (tid < s) { \
+ shared_memory[tid] += shared_memory[tid + s]; \
+ } \
+ threadgroup_barrier(mem_flags::mem_threadgroup); \
+ } \
+ \
+ const T inv_acc = T(1/shared_memory[0]); \
+ idx = start_idx + tid; \
+ while (idx < stop_idx) { \
+ dst[idx] *= inv_acc; \
+ idx += block_dim; \
+ } \
+}
+
+SOFTMAX(softmax_float, float)
+SOFTMAX(softmax_half, half)
+#if __METAL_VERSION__ >= 310
+SOFTMAX(softmax_bfloat, bfloat)
+#endif
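
For reference, the SOFTMAX kernel above computes a numerically stable softmax in three passes: a threadgroup max-reduction, an exp-and-accumulate pass, and a normalization pass, accumulating in f32 even for half/bfloat inputs. A scalar Rust sketch of the same computation (softmax_last_dim_ref is a hypothetical CPU-side helper, not part of this commit):

    // Scalar reference for the three-pass softmax above (hypothetical helper).
    // Accumulation stays in f32, mirroring the kernel's float shared memory.
    fn softmax_last_dim_ref(src: &[f32], last_dim: usize) -> Vec<f32> {
        let mut dst = vec![0.0f32; src.len()];
        for (s, d) in src.chunks(last_dim).zip(dst.chunks_mut(last_dim)) {
            // Pass 1: row max (the kernel does this as a tree reduction).
            let max = s.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            // Pass 2: exponentiate relative to the max, accumulate the sum.
            let mut acc = 0.0f32;
            for (x, y) in s.iter().zip(d.iter_mut()) {
                *y = (x - max).exp();
                acc += *y;
            }
            // Pass 3: scale by the reciprocal of the sum.
            let inv_acc = 1.0 / acc;
            for y in d.iter_mut() {
                *y *= inv_acc;
            }
        }
        dst
    }

    fn main() {
        let out = softmax_last_dim_ref(&[1.0, 2.0, 3.0, 1.0, 2.0, 3.0], 3);
        // Rounded to 4 digits this matches the [0.0900, 0.2447, 0.6652] rows
        // asserted in the tests further down.
        assert!((out[0] - 0.0900).abs() < 1e-4 && (out[2] - 0.6652).abs() < 1e-4);
    }

Subtracting the row max before exponentiating keeps exp() from overflowing on large logits, which is why the extra reduction pass is worth its cost.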
diff --git a/candle-metal-kernels/src/ternary.metal b/candle-metal-kernels/src/ternary.metal
index 0945b355..1f9cb38a 100644
--- a/candle-metal-kernels/src/ternary.metal
+++ b/candle-metal-kernels/src/ternary.metal
@@ -32,6 +32,9 @@ kernel void FN_NAME( \
device TYPENAME *out ,\
uint i [[ thread_position_in_grid ]] \
) { \
+ if (i >= numel){ \
+ return; \
+ } \
uint strided_i = get_strided_index(i, num_dims, dims, strides); \
uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \
uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \
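
The guard added to the ternary kernel matters because thread grids are dispatched as whole threadgroups, so the last group can contain threads whose id is past numel. A rough host-side sketch of that rounding (dispatch_counts is a hypothetical name):

    // Hypothetical host-side sketch: grids are dispatched as whole
    // threadgroups, so up to (threads_per_group - 1) trailing threads run
    // with i >= numel and must hit the early return added above.
    fn dispatch_counts(numel: usize, threads_per_group: usize) -> (usize, usize) {
        let groups = (numel + threads_per_group - 1) / threads_per_group; // ceil div
        (groups, threads_per_group)
    }

    fn main() {
        let (groups, width) = dispatch_counts(10_000, 256);
        assert_eq!(groups, 40); // 40 * 256 = 10_240 threads launched
        assert_eq!(groups * width - 10_000, 240); // 240 threads take the guard
    }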
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 2330d48d..66dc8d01 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -1,5 +1,5 @@
use super::*;
-use half::f16;
+use half::{bf16, f16};
use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger};
fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
@@ -23,13 +23,18 @@ fn approx_f16(v: Vec<f16>, digits: i32) -> Vec<f32> {
v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
}
+fn approx_bf16(v: Vec<bf16>, digits: i32) -> Vec<f32> {
+ let b = 10f32.powi(digits);
+ v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
+}
+
fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
- let mut output = new_buffer(&device, v);
+ let output = new_buffer(&device, v);
call_unary_contiguous(
&device,
command_buffer,
@@ -37,7 +42,7 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
name,
v.len(),
&input,
- &mut output,
+ &output,
)
.unwrap();
command_buffer.commit();
@@ -53,7 +58,7 @@ fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> V
let options = MTLResourceOptions::StorageModeManaged;
let left = new_buffer(&device, x);
let right = new_buffer(&device, y);
- let mut output = device.new_buffer(std::mem::size_of_val(x) as u64, options);
+ let output = device.new_buffer(std::mem::size_of_val(x) as u64, options);
call_binary_contiguous(
&device,
command_buffer,
@@ -62,7 +67,7 @@ fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> V
x.len(),
&left,
&right,
- &mut output,
+ &output,
)
.unwrap();
command_buffer.commit();
@@ -81,7 +86,7 @@ fn run_strided<T: Clone>(
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
- let mut output = new_buffer(&device, v);
+ let output = new_buffer(&device, v);
let kernels = Kernels::new();
call_unary_strided(
&device,
@@ -92,7 +97,7 @@ fn run_strided<T: Clone>(
&input,
strides,
offset,
- &mut output,
+ &output,
0,
)
.unwrap();
@@ -220,7 +225,9 @@ fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
- let mut output = new_buffer(&device, v);
+ let options = MTLResourceOptions::StorageModeManaged;
+ let size = (v.len() * std::mem::size_of::<U>()) as u64;
+ let output = device.new_buffer(size, options);
call_cast_contiguous(
&device,
@@ -229,7 +236,8 @@ fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
name,
v.len(),
&input,
- &mut output,
+ 0,
+ &output,
)
.unwrap();
command_buffer.commit();
@@ -245,11 +253,17 @@ fn cast_u32_f32() {
assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]);
assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]);
+ let v = vec![1.0f32, 2.0, 3.0];
+ let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
+ let results: Vec<f32> = cast(&input, "cast_f16_f32");
+ assert_eq!(results, vec![1.0f32, 2.0, 3.0]);
+
let v = vec![1.0f32; 10_000];
- let results = run(&v, unary::contiguous::cos::FLOAT);
- let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
- assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
- assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
+ let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
+ let results: Vec<f32> = cast(&input, "cast_f16_f32");
+ assert_eq!(results.len(), 10_000);
+ assert_eq!(&results[..10], vec![1.0f32; 10]);
+ assert_eq!(results, vec![1.0f32; 10_000]);
}
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
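
The cast helper now sizes its output buffer from the destination element type instead of reusing the input buffer's size. A small sketch of why that matters (cast_output_bytes is a hypothetical helper):

    // Hypothetical sketch: output bytes must come from the destination type.
    // Reusing the f16 input's size would under-allocate the f32 output by half.
    fn cast_output_bytes<T, U>(len: usize) -> (usize, usize) {
        (len * std::mem::size_of::<T>(), len * std::mem::size_of::<U>())
    }

    fn main() {
        let (src, dst) = cast_output_bytes::<half::f16, f32>(10_000);
        assert_eq!(src, 20_000); // f16 input: 2 bytes per element
        assert_eq!(dst, 40_000); // f32 output: 4 bytes per element
    }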
@@ -259,7 +273,7 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
- let mut output = new_buffer(&device, v);
+ let output = new_buffer(&device, v);
let size = v.len();
@@ -267,9 +281,45 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
&device,
command_buffer,
&kernels,
+ "affine_float",
size,
&input,
- &mut output,
+ &output,
+ mul as f32,
+ add as f32,
+ )
+ .unwrap();
+ command_buffer.commit();
+ command_buffer.wait_until_completed();
+
+ output.read_to_vec::<T>(v.len())
+}
+
+fn _run_affine_strided<T: Clone>(
+ v: &[T],
+ shape: &[usize],
+ strides: &[usize],
+ mul: f64,
+ add: f64,
+) -> Vec<T> {
+ let device = device();
+ let kernels = Kernels::new();
+ let command_queue = device.new_command_queue();
+ let command_buffer = command_queue.new_command_buffer();
+
+ let input = new_buffer(&device, v);
+ let output = new_buffer(&device, v);
+
+ call_affine_strided(
+ &device,
+ command_buffer,
+ &kernels,
+ "affine_float",
+ shape,
+ &input,
+ strides,
+ 0,
+ &output,
mul as f32,
add as f32,
)
@@ -295,6 +345,16 @@ fn affine() {
assert_eq!(result, vec![2.6; 40_000]);
}
+// #[test]
+// fn affine_strided() {
+//     let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
+//     let mul = 1.5;
+//     let add = 1.1;
+//     let result = _run_affine_strided(&input, &[4, 2], &[2, 1], mul, add);
+//     assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]);
+// }
+
#[test]
fn index_select() {
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
@@ -313,7 +373,26 @@ fn index_select() {
result,
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
);
+}
+
+#[test]
+fn index_select_f16() {
+ let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
+ .into_iter()
+ .map(|x| f16::from_f32(x))
+ .collect();
+ let shape = [5, 2];
+ let ids = [0u32, 4, 2];
+ let dim = 0;
+ let result = run_index_select(&embedding, &shape, &ids, dim);
+ assert_eq!(
+ approx_f16(result, 4),
+ vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
+ );
+}
+#[test]
+fn index_select_dim1() {
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [5, 2];
let ids = [0u32, 1, 0];
@@ -321,7 +400,7 @@ fn index_select() {
let result = run_index_select(&embedding, &shape, &ids, dim);
assert_eq!(
result,
- vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
+ vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0]
);
}
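
The expected vectors in the index_select tests follow from the usual left/right decomposition of the indexed dimension, the same split run_index_select uses below. A scalar Rust reference under that assumption (index_select_ref is a hypothetical helper, contiguous layout only):

    // Hypothetical scalar reference for index_select on a contiguous tensor.
    fn index_select_ref(src: &[f32], shape: &[usize], ids: &[u32], dim: usize) -> Vec<f32> {
        let left_size: usize = shape[..dim].iter().product();
        let right_size: usize = shape[dim + 1..].iter().product();
        let src_dim = shape[dim];
        let mut dst = Vec::with_capacity(left_size * ids.len() * right_size);
        for l in 0..left_size {
            for &id in ids {
                let start = (l * src_dim + id as usize) * right_size;
                dst.extend_from_slice(&src[start..start + right_size]);
            }
        }
        dst
    }

    fn main() {
        // shape [5, 2], ids [0, 1, 0], dim 1 reproduces the expected
        // vec![1.0, 2.0, 1.0, 3.0, 4.0, 3.0, ...] asserted above.
        let src = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let out = index_select_ref(&src, &[5, 2], &[0, 1, 0], 1);
        assert_eq!(&out[..6], &[1.0, 2.0, 1.0, 3.0, 4.0, 3.0]);
    }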
@@ -341,20 +420,26 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
let left_size: usize = shape[..dim].iter().product();
let right_size: usize = shape[dim + 1..].iter().product();
let dst_el = ids.len() * left_size * right_size;
- let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
+ let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
+
+ let name = match core::mem::size_of::<T>() {
+ 4 => "is_u32_f32",
+ 2 => "is_u32_f16",
+ _ => unimplemented!(),
+ };
let kernels = Kernels::new();
call_index_select(
&device,
&command_buffer,
&kernels,
- "is_u32_f32",
+ name,
shape,
ids.len(),
dim,
&embeddings_buffer,
&ids_buffer,
- &mut dst_buffer,
+ &dst_buffer,
)
.unwrap();
@@ -451,7 +536,7 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
let input = new_buffer(&device, v);
let options = MTLResourceOptions::StorageModeManaged;
- let mut output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
+ let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
call_reduce_contiguous(
&device,
command_buffer,
@@ -460,7 +545,8 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
v.len(),
out_length,
&input,
- &mut output,
+ 0,
+ &output,
)
.unwrap();
command_buffer.commit();
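
The fast_sum/fast_mul/fast_max kernels exercised by run_reduce combine values with a shared-memory tree reduction, halving the active thread count each step. A sequential Rust model of one threadgroup (block_reduce is a hypothetical name; a power-of-two width is assumed, as with block_dim):

    // Hypothetical scalar model of one threadgroup's tree reduction, as used
    // by the REDUCE kernels tested here: halve the active width each step and
    // combine element tid with element tid + s.
    fn block_reduce(mut shared: Vec<f32>, op: fn(f32, f32) -> f32) -> f32 {
        let mut s = shared.len() / 2;
        while s > 0 {
            for tid in 0..s {
                shared[tid] = op(shared[tid], shared[tid + s]);
            }
            // On the GPU a threadgroup_barrier is required between steps.
            s >>= 1;
        }
        shared[0]
    }

    fn main() {
        let v: Vec<f32> = (1..=8).map(|x| x as f32).collect();
        assert_eq!(block_reduce(v.clone(), |x, y| x + y), 36.0);
        assert_eq!(block_reduce(v, f32::max), 8.0);
    }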
@@ -475,7 +561,7 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'sta
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
- let mut output = new_buffer(&device, v);
+ let output = new_buffer(&device, v);
call_last_softmax(
&device,
command_buffer,
@@ -484,7 +570,7 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'sta
v.len(),
last_dim,
&input,
- &mut output,
+ &output,
)
.unwrap();
command_buffer.commit();
@@ -536,6 +622,28 @@ fn softmax() {
approx(results, 4),
vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
);
+
+ let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
+ .iter()
+ .map(|v| f16::from_f32(*v))
+ .collect::<Vec<_>>();
+ let last_dim = 6;
+ let results = run_softmax(&v, last_dim, "softmax_half");
+ assert_eq!(
+ approx_f16(results, 4),
+ vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338]
+ );
+
+ let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
+ .iter()
+ .map(|v| bf16::from_f32(*v))
+ .collect::<Vec<_>>();
+ let last_dim = 6;
+ let results = run_softmax(&v, last_dim, "softmax_bfloat");
+ assert_eq!(
+ approx_bf16(results, 4),
+ vec![0.0043, 0.0116, 0.0315, 0.0859, 0.2324, 0.6328]
+ );
}
fn run_where_cond<I: Clone, T: Clone>(
@@ -571,7 +679,7 @@ fn run_where_cond<I: Clone, T: Clone>(
options,
);
- let mut output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
+ let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
call_where_cond_strided(
&device,
command_buffer,
@@ -584,7 +692,7 @@ fn run_where_cond<I: Clone, T: Clone>(
(&left_stride, left_offset),
&right,
(&cond_stride, cond_offset),
- &mut output,
+ &output,
)
.unwrap();
command_buffer.commit();
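
run_where_cond drives the ternary kernel patched earlier; semantically it is an elementwise select between two tensors under a mask. A minimal Rust sketch of the contiguous case (where_cond_ref is a hypothetical helper):

    // Hypothetical contiguous reference for where_cond: pick from `t` where
    // the condition is non-zero, from `f` otherwise, elementwise.
    fn where_cond_ref(cond: &[u8], t: &[f32], f: &[f32]) -> Vec<f32> {
        cond.iter()
            .zip(t.iter().zip(f.iter()))
            .map(|(&c, (&tv, &fv))| if c != 0 { tv } else { fv })
            .collect()
    }

    fn main() {
        let cond = [1u8, 0, 1, 0];
        let t = [1.0f32, 2.0, 3.0, 4.0];
        let f = [10.0f32, 20.0, 30.0, 40.0];
        assert_eq!(where_cond_ref(&cond, &t, &f), vec![1.0, 20.0, 3.0, 40.0]);
    }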
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index eb6424e8..88139af9 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -1,4 +1,7 @@
#include <metal_stdlib>
+#include <metal_math>
+using namespace metal;
METAL_FUNC uint get_strided_index(
uint idx,
@@ -17,10 +20,39 @@ METAL_FUNC uint get_strided_index(
template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
template <typename T> METAL_FUNC T neg(T in){ return -in; }
+template <typename T> METAL_FUNC T erf(T in){
+ float x = (float) in;
+ // constants
+ float a1 = 0.254829592;
+ float a2 = -0.284496736;
+ float a3 = 1.421413741;
+ float a4 = -1.453152027;
+ float a5 = 1.061405429;
+ float p = 0.3275911;
+
+ // Save the sign of x
+ int sign = 1;
+ if (x < 0)
+ sign = -1;
+ x = fabs(x);
+
+ // A&S formula 7.1.26
+ float t = 1.0/(1.0 + p*x);
+ float y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x);
+
+ return T(sign*y);
+}
template <typename T> METAL_FUNC T id(T in){ return in; }
+template <typename T> METAL_FUNC T gelu_erf(T x){ return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2); }
+template <typename T> METAL_FUNC T gelu(T x){
+ T x_sq = x * x;
+ T x_cube = x_sq * x;
+ T alpha = x + static_cast<T>(0.044715) * x_cube;
+ T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
+ return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
+}
-using namespace metal;
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
@@ -64,8 +96,16 @@ UNARY_OP(sqrt)
UNARY_OP(neg)
UNARY_OP(exp)
UNARY_OP(log)
+UNARY_OP(gelu)
+UNARY_OP(ceil)
+UNARY_OP(floor)
+UNARY_OP(round)
+UNARY_OP(gelu_erf)
+UNARY_OP(erf)
UNARY(id, float, copy_float, copy_float_strided)
UNARY(id, half, copy_half, copy_half_strided)
+UNARY(id, uint8_t, copy_u8, copy_u8_strided)
+UNARY(id, uint32_t, copy_u32, copy_u32_strided)
#if __METAL_VERSION__ >= 310
BFLOAT_UNARY_OP(cos)
@@ -75,6 +115,12 @@ BFLOAT_UNARY_OP(sqrt)
BFLOAT_UNARY_OP(neg)
BFLOAT_UNARY_OP(exp)
BFLOAT_UNARY_OP(log)
+BFLOAT_UNARY_OP(gelu)
+BFLOAT_UNARY_OP(ceil)
+BFLOAT_UNARY_OP(floor)
+BFLOAT_UNARY_OP(round)
+BFLOAT_UNARY_OP(gelu_erf)
+BFLOAT_UNARY_OP(erf)
UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
#endif
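
The erf template added above transcribes the Abramowitz & Stegun 7.1.26 polynomial approximation (absolute error below about 1.5e-7), and gelu uses the common tanh approximation; note that M_2_SQRTPI_F * M_SQRT1_2_F equals sqrt(2/pi). A CPU transcription in Rust for spot-checking values (erf_approx and gelu_approx are hypothetical helpers):

    // Hypothetical CPU transcription of the Metal erf/gelu templates above.
    // erf follows Abramowitz & Stegun formula 7.1.26 (|error| < 1.5e-7).
    fn erf_approx(x: f32) -> f32 {
        let (a1, a2, a3, a4, a5) =
            (0.254829592f32, -0.284496736, 1.421413741, -1.453152027, 1.061405429);
        let p = 0.3275911f32;
        let sign = if x < 0.0 { -1.0f32 } else { 1.0 };
        let x = x.abs();
        let t = 1.0 / (1.0 + p * x);
        let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
        sign * y
    }

    // gelu via the tanh approximation used by the kernel:
    // 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
    fn gelu_approx(x: f32) -> f32 {
        let alpha = x + 0.044715 * x * x * x;
        0.5 * x * (1.0 + ((2.0 / std::f32::consts::PI).sqrt() * alpha).tanh())
    }

    fn main() {
        assert!((erf_approx(0.5) - 0.5205).abs() < 1e-4);
        assert!((gelu_approx(1.0) - 0.8412).abs() < 1e-3);
    }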
diff --git a/candle-metal-kernels/examples/affine.rs b/candle-metal-kernels/tmp/affine.rs
index b8005dc0..cd019056 100644
--- a/candle-metal-kernels/examples/affine.rs
+++ b/candle-metal-kernels/tmp/affine.rs
@@ -50,6 +50,7 @@ fn run_affine_bench<T: Clone>(device: &Device, kernels: &Kernels, v: &[T]) {
&device,
command_buffer,
&kernels,
+ "affine_float",
v.len(),
&input,
&mut output,
diff --git a/candle-metal-kernels/examples/binary.rs b/candle-metal-kernels/tmp/binary.rs
index af5a8bdc..af5a8bdc 100644
--- a/candle-metal-kernels/examples/binary.rs
+++ b/candle-metal-kernels/tmp/binary.rs
diff --git a/candle-metal-kernels/examples/cast.rs b/candle-metal-kernels/tmp/cast.rs
index 090f510d..090f510d 100644
--- a/candle-metal-kernels/examples/cast.rs
+++ b/candle-metal-kernels/tmp/cast.rs
diff --git a/candle-metal-kernels/examples/unary.rs b/candle-metal-kernels/tmp/unary.rs
index 7039c098..66cf25c0 100644
--- a/candle-metal-kernels/examples/unary.rs
+++ b/candle-metal-kernels/tmp/unary.rs
@@ -147,7 +147,7 @@ fn run_unary_bench<T: Clone>(
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
- kernel_name.to_string(),
+ kernel_name.0,
v.len(),
iterations,
total_time,
@@ -159,7 +159,7 @@ fn run_unary_bench<T: Clone>(
let shape = vec![2, 5_000];
let strides = vec![2, 1];
let offset = 0;
- for kernel_name in strided {
+ for kernel_name in &strided {
let total_time = autoreleasepool(|| {
let command_buffer = command_queue.new_command_buffer();
let start = Instant::now();
@@ -187,7 +187,7 @@ fn run_unary_bench<T: Clone>(
println!(
"{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
type_name::<T>().split("::").last().unwrap(),
- kernel_name.to_string(),
+ kernel_name.0,
v.len(),
iterations,
total_time,
diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml
index d3f43c73..45298907 100644
--- a/candle-nn/Cargo.toml
+++ b/candle-nn/Cargo.toml
@@ -19,6 +19,7 @@ num-traits = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
+candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
[dev-dependencies]
anyhow = { workspace = true }
@@ -29,3 +30,4 @@ default = []
accelerate = ["dep:accelerate-src", "candle/accelerate"]
cuda = ["candle/cuda"]
mkl = ["dep:intel-mkl-src", "candle/mkl"]
+metal = ["candle/metal", "dep:candle-metal-kernels"]
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index a0269e59..350bc663 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -201,6 +201,46 @@ impl candle::CustomOp1 for SoftmaxLastDim {
};
Ok((dst, layout.shape().clone()))
}
+
+ #[cfg(feature = "metal")]
+ fn metal_fwd(
+ &self,
+ storage: &candle::MetalStorage,
+ layout: &Layout,
+ ) -> Result<(candle::MetalStorage, Shape)> {
+ use candle::{backend::BackendStorage, DType};
+ let device = storage.device();
+ let command_buffer = device.command_buffer();
+ let kernels = device.kernels();
+ let name = match storage.dtype() {
+ DType::F32 => "softmax_float",
+ DType::F16 => "softmax_half",
+ DType::BF16 => "softmax_bfloat",
+ dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"),
+ };
+
+ let n = layout.stride().len();
+ if !(layout.stride()[n - 1] == 1 && layout.start_offset() == 0) {
+ candle::bail!("Non contiguous softmax-last-dim is not implemented");
+ }
+
+ let last_dim = layout.dims()[layout.shape().rank() - 1];
+ let elem_count = layout.shape().elem_count();
+        let output = device.new_buffer(elem_count, storage.dtype());
+ candle_metal_kernels::call_last_softmax(
+ device.metal_device(),
+ &command_buffer,
+ &kernels,
+ name,
+ elem_count,
+ last_dim,
+ storage.buffer(),
+            &output,
+ )
+ .unwrap();
+ let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
+ Ok((newstorage, layout.shape().clone()))
+ }
}
pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
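
With the metal feature enabled, softmax_last_dim dispatches to these kernels through the metal_fwd implementation above. A hedged usage sketch, assuming Device::new_metal is available in this build of candle:

    // Hedged usage sketch: run softmax_last_dim on the Metal backend.
    // Assumes the `metal` feature is enabled and that Device::new_metal(0)
    // exists and succeeds on this machine.
    use candle::{Device, Tensor};

    fn main() -> candle::Result<()> {
        let device = Device::new_metal(0)?;
        let xs = Tensor::new(&[[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]], &device)?;
        let ys = candle_nn::ops::softmax_last_dim(&xs)?;
        println!("{ys}");
        Ok(())
    }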