Diffstat (limited to 'candle-core/src/metal_backend.rs')
-rw-r--r--  candle-core/src/metal_backend.rs  702
1 file changed, 475 insertions(+), 227 deletions(-)
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 0b72f080..12f56d50 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -4,11 +4,13 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels;
use candle_metal_kernels::Kernels;
-use core::mem;
-use half::{bf16, f16};
+use half::f16;
use metal;
-use metal::{Buffer, CommandQueue, MTLResourceOptions, NSUInteger};
-use std::sync::Arc;
+use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
+use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
+use std::collections::HashMap;
+use std::path::Path;
+use std::sync::{Arc, RwLock};
/// Metal related errors
#[derive(thiserror::Error, Debug)]
@@ -36,7 +38,9 @@ impl From<String> for MetalError {
pub struct MetalDevice {
device: metal::Device,
command_queue: metal::CommandQueue,
+ command_buffer: Arc<RwLock<metal::CommandBuffer>>,
kernels: Arc<candle_metal_kernels::Kernels>,
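+ /// Simple buffer pool, keyed by (size, resource options): buffers are
+ /// cached and reused once the pool holds their only remaining reference.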
+ buffers: Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>,
}
impl std::fmt::Debug for MetalDevice {
@@ -58,10 +62,48 @@ impl MetalDevice {
self.registry_id()
}
+ pub fn metal_device(&self) -> &metal::Device {
+ &self.device
+ }
+
pub fn command_queue(&self) -> &CommandQueue {
&self.command_queue
}
+ pub fn command_buffer(&self) -> std::sync::RwLockReadGuard<CommandBuffer> {
+ self.command_buffer.try_read().unwrap()
+ }
+
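+ /// Commits the current command buffer if it still accepts work and
+ /// swaps in a fresh one so subsequent ops keep encoding.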
+ pub fn commit(&self) {
+ let mut old = self.command_buffer.try_write().unwrap();
+ match old.status() {
+ metal::MTLCommandBufferStatus::NotEnqueued
+ | metal::MTLCommandBufferStatus::Enqueued => {
+ old.commit();
+ let command_buffer = self.command_queue.new_command_buffer().to_owned();
+ *old = command_buffer;
+ }
+ _ => {}
+ }
+ }
+
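+ /// Commits the current command buffer if needed, blocks until the GPU
+ /// has finished executing it, then installs a fresh command buffer.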
+ pub fn wait_until_completed(&self) {
+ let mut old = self.command_buffer.try_write().unwrap();
+ match old.status() {
+ metal::MTLCommandBufferStatus::NotEnqueued
+ | metal::MTLCommandBufferStatus::Enqueued => {
+ old.commit();
+ old.wait_until_completed();
+ }
+ metal::MTLCommandBufferStatus::Committed | metal::MTLCommandBufferStatus::Scheduled => {
+ old.wait_until_completed();
+ }
+ _ => {}
+ }
+ let command_buffer = self.command_queue.new_command_buffer().to_owned();
+ *old = command_buffer;
+ }
+
pub fn kernels(&self) -> &Kernels {
&self.kernels
}
@@ -70,16 +112,107 @@ impl MetalDevice {
&self.device
}
- pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
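+ /// Allocates, or reuses from the pool, a GPU-private buffer large
+ /// enough to hold `element_count` elements of `dtype`.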
+ pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Arc<Buffer> {
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
- self.device
- .new_buffer(size, MTLResourceOptions::StorageModeManaged)
+ self._new_buffer(size, MTLResourceOptions::StorageModePrivate)
+ }
+
+ fn _new_buffer(&self, size: NSUInteger, option: MTLResourceOptions) -> Arc<Buffer> {
+ let mut buffers = self.buffers.try_write().unwrap();
+ let subbuffers = buffers.entry((size, option)).or_default();
+
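+ // Reuse a pooled buffer only when this map holds its sole reference,
+ // i.e. no live storage is still using it.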
+ for sub in &mut *subbuffers {
+ if Arc::strong_count(sub) == 1 {
+ return sub.clone();
+ }
+ }
+ let new_buffer = self.device.new_buffer(size, option);
+ let new_buffer = Arc::new(new_buffer);
+ subbuffers.push(new_buffer.clone());
+ new_buffer
+ }
+
+ pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> {
+ self._new_buffer(size, MTLResourceOptions::StorageModeManaged)
+ }
+
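+ /// Creates a GPU-private buffer initialized with `data` by staging the
+ /// host bytes in a managed buffer and blitting them across.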
+ pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> {
+ let size = core::mem::size_of_val(data) as NSUInteger;
+ let tmp = self.device.new_buffer_with_data(
+ data.as_ptr() as *const core::ffi::c_void,
+ size,
+ metal::MTLResourceOptions::StorageModeManaged,
+ );
+ let real = self._new_buffer(size, metal::MTLResourceOptions::StorageModePrivate);
+ {
+ let command = self.command_buffer();
+ let blit = command.new_blit_command_encoder();
+ blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
+ blit.end_encoding();
+ }
+ // This wait is necessary for mmapped safetensors: because of the
+ // unsafe slice cast we're doing, the slice might not live long
+ // enough for Metal to actually fill the GPU buffer. Waiting here
+ // forces the GPU buffer to be filled with the real data so the CPU
+ // storage can be deallocated safely.
+ self.wait_until_completed();
+ real
+ }
+
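+ /// Allocates the output buffer for a (b, m, n) batched matmul and
+ /// wraps it in an MPS matrix with the matching descriptor.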
+ pub fn new_matrix(
+ &self,
+ (b, m, n): (NSUInteger, NSUInteger, NSUInteger),
+ size: NSUInteger,
+ type_id: u32,
+ dtype: DType,
+ ) -> Result<(Matrix, Arc<Buffer>)> {
+ let elem_count = (b * m * n) as usize;
+ let out_buffer = self.new_buffer(elem_count, dtype);
+
+ let result_descriptor =
+ MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id);
+ let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, 0, &result_descriptor)
+ .ok_or_else(|| {
+ MetalError::from("Failed to create matrix multiplication kernel".to_string())
+ })?;
+ Ok((result_matrix, out_buffer))
+ }
+
+ pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> {
+ let capture = metal::CaptureManager::shared();
+ let descriptor = metal::CaptureDescriptor::new();
+ descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
+ descriptor.set_capture_device(&self);
+ descriptor.set_output_url(path);
+
+ capture
+ .start_capture(&descriptor)
+ .map_err(MetalError::from)?;
+ Ok(())
}
}
#[derive(Debug, Clone)]
pub struct MetalStorage {
- buffer: metal::Buffer,
+ buffer: Arc<metal::Buffer>,
+ matrices: Arc<
+ RwLock<
+ HashMap<
+ (
+ NSUInteger,
+ NSUInteger,
+ NSUInteger,
+ bool,
+ NSUInteger,
+ NSUInteger,
+ u32,
+ ),
+ Matrix,
+ >,
+ >,
+ >,
device: MetalDevice,
dtype: DType,
}
@@ -108,14 +241,23 @@ impl BackendStorage for MetalStorage {
self.dtype
);
}
+
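+ // Private buffers are not CPU-visible: blit into a managed staging
+ // buffer and wait for the copy before reading.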
+ let buffer = self.device.new_buffer_managed(self.buffer.length());
+ let command_buffer = self.device.command_buffer();
+ let blit = command_buffer.new_blit_command_encoder();
+ blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
+ blit.end_encoding();
+ drop(command_buffer);
+ self.device.wait_until_completed();
+
match self.dtype {
- DType::U8 => Ok(CpuStorage::U8(self.buffer.read_to_vec(length / size))),
- DType::U32 => Ok(CpuStorage::U32(self.buffer.read_to_vec(length / size))),
- DType::I64 => Ok(CpuStorage::I64(self.buffer.read_to_vec(length / size))),
- DType::F16 => Ok(CpuStorage::F16(self.buffer.read_to_vec(length / size))),
- DType::BF16 => Ok(CpuStorage::BF16(self.buffer.read_to_vec(length / size))),
- DType::F32 => Ok(CpuStorage::F32(self.buffer.read_to_vec(length / size))),
- DType::F64 => Ok(CpuStorage::F64(self.buffer.read_to_vec(length / size))),
+ DType::U8 => Ok(CpuStorage::U8(buffer.read_to_vec(length / size))),
+ DType::U32 => Ok(CpuStorage::U32(buffer.read_to_vec(length / size))),
+ DType::I64 => Ok(CpuStorage::I64(buffer.read_to_vec(length / size))),
+ DType::F16 => Ok(CpuStorage::F16(buffer.read_to_vec(length / size))),
+ DType::BF16 => Ok(CpuStorage::BF16(buffer.read_to_vec(length / size))),
+ DType::F32 => Ok(CpuStorage::F32(buffer.read_to_vec(length / size))),
+ DType::F64 => Ok(CpuStorage::F64(buffer.read_to_vec(length / size))),
}
}
@@ -126,30 +268,48 @@ impl BackendStorage for MetalStorage {
let el = shape.elem_count();
let dtype = self.dtype;
- if layout.is_contiguous() || layout.start_offset() != 0 || dtype != DType::F32 {
- crate::bail!("Not contiguous, non-f32 affine is not implemented yet.");
+ let buffer = device.new_buffer(el, self.dtype);
+ let command_buffer = self.device.command_buffer();
+ if layout.is_contiguous() && layout.start_offset() == 0 {
+ let name = match self.dtype {
+ DType::F32 => "affine_float",
+ DType::F16 => "affine_half",
+ dtype => crate::bail!("Affine {dtype:?}"),
+ };
+ candle_metal_kernels::call_affine(
+ &device.device,
+ &command_buffer,
+ &device.kernels,
+ name,
+ el,
+ &self.buffer,
+ &buffer,
+ mul as f32,
+ add as f32,
+ )
+ .map_err(MetalError::from)?;
+ } else {
+ let name = match self.dtype {
+ DType::F32 => "affine_float_strided",
+ DType::F16 => "affine_half_strided",
+ dtype => crate::bail!("Affine {dtype:?}"),
+ };
+ candle_metal_kernels::call_affine_strided(
+ &device.device,
+ &command_buffer,
+ &device.kernels,
+ name,
+ layout.dims(),
+ &self.buffer,
+ layout.stride(),
+ layout.start_offset() * dtype.size_in_bytes(),
+ &buffer,
+ mul as f32,
+ add as f32,
+ )
+ .map_err(MetalError::from)?;
}
-
- let mut buffer = device.new_buffer(el, self.dtype);
- let command_buffer = self.device.command_queue.new_command_buffer();
- candle_metal_kernels::call_affine(
- &device.device,
- &command_buffer,
- &device.kernels,
- el,
- &self.buffer,
- &mut buffer,
- mul as f32,
- add as f32,
- )
- .map_err(MetalError::from)?;
- command_buffer.commit();
- command_buffer.wait_until_completed();
- return Ok(Self {
- buffer,
- device: device.clone(),
- dtype,
- });
+ Ok(Self::new(buffer, device.clone(), dtype))
}
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
@@ -163,11 +323,11 @@ impl BackendStorage for MetalStorage {
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
if !(sum_dims.len() == 1
&& sum_dims[0] == layout.shape().rank() - 1
- && layout.is_contiguous()
- && layout.start_offset() == 0)
+ && layout.stride()[sum_dims[0]] == 1)
{
- crate::bail!("Non contiguous reduce op not supported yet");
+ crate::bail!("Non last dim reduce op not supported yet");
}
+
let device = self.device.clone();
let src_stride = layout.stride();
let src_dims = layout.shape().dims();
@@ -202,8 +362,11 @@ impl BackendStorage for MetalStorage {
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
}
let dtype = if return_index { DType::U32 } else { self.dtype };
- let mut buffer = device.new_buffer(dst_el, dtype);
- let command_buffer = self.device.command_queue.new_command_buffer();
+ if dtype == DType::U32 {
+ crate::bail!("Implement return index reduce op");
+ }
+ let buffer = device.new_buffer(dst_el, dtype);
+ let command_buffer = self.device.command_buffer();
candle_metal_kernels::call_reduce_contiguous(
&device.device,
&command_buffer,
@@ -212,17 +375,12 @@ impl BackendStorage for MetalStorage {
src_el,
dst_el,
&self.buffer,
- &mut buffer,
+ layout.start_offset() * self.dtype.size_in_bytes(),
+ &buffer,
)
.map_err(MetalError::from)?;
- command_buffer.commit();
- command_buffer.wait_until_completed();
- Ok(Self {
- buffer,
- device,
- dtype,
- })
+ Ok(Self::new(buffer, device, dtype))
}
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
@@ -233,11 +391,15 @@ impl BackendStorage for MetalStorage {
let device = self.device();
let shape = layout.shape();
let el_count = shape.elem_count();
- let mut buffer = device.new_buffer(el_count, dtype);
- let command_buffer = device.command_queue.new_command_buffer();
+ let buffer = device.new_buffer(el_count, dtype);
+ let command_buffer = device.command_buffer();
if layout.is_contiguous() {
let kernel_name = match (self.dtype, dtype) {
(DType::U32, DType::F32) => "cast_u32_f32",
+ (DType::U32, DType::U8) => "cast_u32_u8",
+ (DType::U8, DType::U32) => "cast_u8_u32",
+ (DType::F32, DType::F16) => "cast_f32_f16",
+ (DType::F16, DType::F32) => "cast_f16_f32",
(left, right) => crate::bail!("to dtype {left:?} - {right:?}"),
};
candle_metal_kernels::call_cast_contiguous(
@@ -247,24 +409,34 @@ impl BackendStorage for MetalStorage {
kernel_name,
el_count,
&self.buffer,
- &mut buffer,
+ layout.start_offset() * self.dtype.size_in_bytes(),
+ &buffer,
)
.map_err(MetalError::from)?;
} else {
- crate::bail!(
- "TODO Implement the kernel calling cast {:?}-{:?}",
- self.dtype,
- dtype
- );
+ let kernel_name = match (self.dtype, dtype) {
+ (DType::U32, DType::F32) => "cast_u32_f32_strided",
+ (DType::U32, DType::U8) => "cast_u32_u8_strided",
+ (DType::U8, DType::U32) => "cast_u8_u32_strided",
+ (DType::F32, DType::F16) => "cast_f32_f16_strided",
+ (DType::F16, DType::F32) => "cast_f16_f32_strided",
+ (left, right) => crate::bail!("to dtype {left:?} - {right:?}"),
+ };
+ candle_metal_kernels::call_cast_strided(
+ &device.device,
+ &command_buffer,
+ &device.kernels,
+ kernel_name,
+ layout.dims(),
+ &self.buffer,
+ layout.stride(),
+ layout.start_offset() * self.dtype.size_in_bytes(),
+ &buffer,
+ )
+ .map_err(MetalError::from)?;
}
- command_buffer.commit();
- command_buffer.wait_until_completed();
- Ok(Self {
- buffer,
- device: device.clone(),
- dtype,
- })
+ Ok(Self::new(buffer, device.clone(), dtype))
}
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
@@ -272,8 +444,8 @@ impl BackendStorage for MetalStorage {
let dtype = self.dtype;
let shape = layout.shape();
let el_count = shape.elem_count();
- let mut buffer = device.new_buffer(el_count, dtype);
- let command_buffer = device.command_queue.new_command_buffer();
+ let buffer = device.new_buffer(el_count, dtype);
+ let command_buffer = device.command_buffer();
if layout.is_contiguous() && layout.start_offset() == 0 {
use candle_metal_kernels::unary::contiguous;
@@ -285,6 +457,25 @@ impl BackendStorage for MetalStorage {
("uneg", DType::F32) => contiguous::neg::FLOAT,
("uexp", DType::F32) => contiguous::exp::FLOAT,
("ulog", DType::F32) => contiguous::log::FLOAT,
+ ("ugelu", DType::F32) => contiguous::gelu::FLOAT,
+ ("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
+ ("uerf", DType::F32) => contiguous::erf::FLOAT,
+ ("uceil", DType::F32) => contiguous::ceil::FLOAT,
+ ("ufloor", DType::F32) => contiguous::floor::FLOAT,
+ ("uround", DType::F32) => contiguous::round::FLOAT,
+ ("ucos", DType::F16) => contiguous::cos::HALF,
+ ("usin", DType::F16) => contiguous::sin::HALF,
+ ("usqr", DType::F16) => contiguous::sqr::HALF,
+ ("usqrt", DType::F16) => contiguous::sqrt::HALF,
+ ("uneg", DType::F16) => contiguous::neg::HALF,
+ ("uexp", DType::F16) => contiguous::exp::HALF,
+ ("ulog", DType::F16) => contiguous::log::HALF,
+ ("ugelu", DType::F16) => contiguous::gelu::HALF,
+ ("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
+ ("uerf", DType::F16) => contiguous::erf::HALF,
+ ("uceil", DType::F16) => contiguous::ceil::HALF,
+ ("ufloor", DType::F16) => contiguous::floor::HALF,
+ ("uround", DType::F16) => contiguous::round::HALF,
(name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_unary_contiguous(
@@ -294,20 +485,58 @@ impl BackendStorage for MetalStorage {
kernel_name,
el_count,
&self.buffer,
- &mut buffer,
+ &buffer,
)
.map_err(MetalError::from)?;
} else {
- crate::bail!("TODO Implement the kernel calling {}", B::KERNEL);
+ use candle_metal_kernels::unary::strided;
+ let kernel_name = match (B::KERNEL, dtype) {
+ ("ucos", DType::F32) => strided::cos::FLOAT,
+ ("usin", DType::F32) => strided::sin::FLOAT,
+ ("usqr", DType::F32) => strided::sqr::FLOAT,
+ ("usqrt", DType::F32) => strided::sqrt::FLOAT,
+ ("uneg", DType::F32) => strided::neg::FLOAT,
+ ("uexp", DType::F32) => strided::exp::FLOAT,
+ ("ulog", DType::F32) => strided::log::FLOAT,
+ ("ugelu", DType::F32) => strided::gelu::FLOAT,
+ ("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
+ ("uerf", DType::F32) => strided::erf::FLOAT,
+ ("uceil", DType::F32) => strided::ceil::FLOAT,
+ ("ufloor", DType::F32) => strided::floor::FLOAT,
+ ("uround", DType::F32) => strided::round::FLOAT,
+ ("ucos", DType::F16) => strided::cos::HALF,
+ ("usin", DType::F16) => strided::sin::HALF,
+ ("usqr", DType::F16) => strided::sqr::HALF,
+ ("usqrt", DType::F16) => strided::sqrt::HALF,
+ ("uneg", DType::F16) => strided::neg::HALF,
+ ("uexp", DType::F16) => strided::exp::HALF,
+ ("ulog", DType::F16) => strided::log::HALF,
+ ("ugelu", DType::F16) => strided::gelu::HALF,
+ ("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
+ ("uerf", DType::F16) => strided::erf::HALF,
+ ("uceil", DType::F16) => strided::ceil::HALF,
+ ("ufloor", DType::F16) => strided::floor::HALF,
+ ("uround", DType::F16) => strided::round::HALF,
+ (name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
+ };
+ candle_metal_kernels::call_unary_strided(
+ &device.device,
+ &command_buffer,
+ &device.kernels,
+ kernel_name,
+ layout.dims(),
+ &self.buffer,
+ layout.stride(),
+ layout.start_offset() * self.dtype.size_in_bytes(),
+ &buffer,
+ 0,
+ )
+ .map_err(MetalError::from)?;
}
- command_buffer.commit();
- command_buffer.wait_until_completed();
-
- Ok(Self {
- buffer,
- device: device.clone(),
- dtype,
- })
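+ // Commit without blocking: commit() swaps in a fresh command buffer
+ // while the GPU executes the encoded work asynchronously.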
+ command_buffer.set_label("unary");
+ drop(command_buffer);
+ self.device.commit();
+ Ok(Self::new(buffer, device.clone(), dtype))
}
fn binary_impl<B: BinaryOpT>(
@@ -320,8 +549,8 @@ impl BackendStorage for MetalStorage {
let dtype = self.dtype;
let shape = lhs_l.shape();
let el_count = shape.elem_count();
- let mut buffer = device.new_buffer(el_count, dtype);
- let command_buffer = device.command_queue.new_command_buffer();
+ let buffer = device.new_buffer(el_count, dtype);
+ let command_buffer = device.command_buffer();
if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
&& (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
{
@@ -336,6 +565,14 @@ impl BackendStorage for MetalStorage {
("bmul", DType::F32) => contiguous::mul::FLOAT,
("div", DType::F32) => contiguous::div::FLOAT,
("bdiv", DType::F32) => contiguous::div::FLOAT,
+ ("add", DType::F16) => contiguous::add::HALF,
+ ("badd", DType::F16) => contiguous::add::HALF,
+ ("sub", DType::F16) => contiguous::sub::HALF,
+ ("bsub", DType::F16) => contiguous::sub::HALF,
+ ("mul", DType::F16) => contiguous::mul::HALF,
+ ("bmul", DType::F16) => contiguous::mul::HALF,
+ ("div", DType::F16) => contiguous::div::HALF,
+ ("bdiv", DType::F16) => contiguous::div::HALF,
(name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_binary_contiguous(
@@ -346,7 +583,7 @@ impl BackendStorage for MetalStorage {
el_count,
&self.buffer,
&rhs.buffer,
- &mut buffer,
+ &buffer,
)
.map_err(MetalError::from)?;
} else {
@@ -357,6 +594,10 @@ impl BackendStorage for MetalStorage {
("bsub", DType::F32) => strided::sub::FLOAT,
("bmul", DType::F32) => strided::mul::FLOAT,
("bdiv", DType::F32) => strided::div::FLOAT,
+ ("badd", DType::F16) => strided::add::HALF,
+ ("bsub", DType::F16) => strided::sub::HALF,
+ ("bmul", DType::F16) => strided::mul::HALF,
+ ("bdiv", DType::F16) => strided::div::HALF,
(name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_binary_strided(
@@ -366,23 +607,19 @@ impl BackendStorage for MetalStorage {
kernel_name,
lhs_l.dims(),
&self.buffer,
- &lhs_l.stride(),
+ lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&rhs.buffer,
- &rhs_l.stride(),
+ rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
- &mut buffer,
+ &buffer,
)
.map_err(MetalError::from)?;
}
- command_buffer.commit();
- command_buffer.wait_until_completed();
-
- Ok(Self {
- buffer,
- device: device.clone(),
- dtype,
- })
+ command_buffer.set_label("binary");
+ drop(command_buffer);
+ self.device.commit();
+ Ok(Self::new(buffer, device.clone(), dtype))
}
fn where_cond(
@@ -398,14 +635,22 @@ impl BackendStorage for MetalStorage {
let dims = shape.dims();
let el = shape.elem_count();
let dtype = t.dtype;
- let mut buffer = self.device.new_buffer(el, dtype);
- let command_buffer = self.device.command_queue.new_command_buffer();
+ let buffer = self.device.new_buffer(el, dtype);
+ let command_buffer = self.device.command_buffer();
+ if t.dtype() != f.dtype() {
+ crate::bail!("Invalid ternary different dtypes for values");
+ }
+ let name = match (self.dtype, t.dtype()) {
+ (DType::U8, DType::F32) => "where_u8_f32",
+ (DType::U8, DType::F16) => "where_u8_f16",
+ (left, right) => crate::bail!("Ternary {left:?} - {right:?} not implemented"),
+ };
candle_metal_kernels::call_where_cond_strided(
&device.device,
&command_buffer,
&device.kernels,
- "where_u8_f32",
- &dims,
+ name,
+ dims,
&self.buffer,
(
layout.stride(),
@@ -415,16 +660,10 @@ impl BackendStorage for MetalStorage {
(&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
&f.buffer,
(&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
- &mut buffer,
+ &buffer,
)
.map_err(MetalError::from)?;
- command_buffer.commit();
- command_buffer.wait_until_completed();
- Ok(Self {
- buffer,
- device,
- dtype,
- })
+ Ok(Self::new(buffer, device, dtype))
}
fn conv1d(
@@ -513,12 +752,13 @@ impl BackendStorage for MetalStorage {
let dst_el = ids_el * left_size * right_size;
let dtype = self.dtype;
let device = self.device();
- let mut buffer = device.new_buffer(dst_el, dtype);
+ let buffer = device.new_buffer(dst_el, dtype);
let name = match (ids.dtype, self.dtype) {
(DType::U32, DType::F32) => "is_u32_f32",
+ (DType::U32, DType::F16) => "is_u32_f16",
(left, right) => crate::bail!("index select metal {left:?} {right:?}"),
};
- let command_buffer = self.device.command_queue.new_command_buffer();
+ let command_buffer = self.device.command_buffer();
candle_metal_kernels::call_index_select(
&device.device,
&command_buffer,
@@ -529,16 +769,10 @@ impl BackendStorage for MetalStorage {
dim,
&self.buffer,
&ids.buffer,
- &mut buffer,
+ &buffer,
)
.map_err(MetalError::from)?;
- command_buffer.commit();
- command_buffer.wait_until_completed();
- Ok(Self {
- buffer,
- device: device.clone(),
- dtype,
- })
+ Ok(Self::new(buffer, device.clone(), dtype))
}
fn index_add(
@@ -561,11 +795,18 @@ impl BackendStorage for MetalStorage {
rhs_l: &Layout,
) -> Result<Self> {
// Create descriptors
- use metal::mps::matrix::*;
- let type_id = metal::mps::MPS_FLOATBIT_ENCODING | 32;
- let size = core::mem::size_of::<f32>() as NSUInteger;
- let elem_count = b * m * n;
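+ // MPS encodes the matrix element type as a float flag ORed with the
+ // bit width (32 for f32, 16 for f16).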
+ let (type_id, size) = match self.dtype {
+ DType::F32 => (
+ metal::mps::MPS_FLOATBIT_ENCODING | 32,
+ core::mem::size_of::<f32>() as NSUInteger,
+ ),
+ DType::F16 => (
+ metal::mps::MPS_FLOATBIT_ENCODING | 16,
+ core::mem::size_of::<f16>() as NSUInteger,
+ ),
+ dtype => todo!("Dtype for matmul {dtype:?} is not supported"),
+ };
let lhs_stride = lhs_l.stride();
let rhs_stride = rhs_l.stride();
@@ -596,39 +837,30 @@ impl BackendStorage for MetalStorage {
mnk: (m, n, k),
})?
};
-
let b = b as NSUInteger;
let m = m as NSUInteger;
let n = n as NSUInteger;
let k = k as NSUInteger;
- let left_descriptor = if transpose_left {
- MatrixDescriptor::init_single(k, m, m * size, type_id)
- } else {
- MatrixDescriptor::init_single(m, k, k * size, type_id)
- };
- let right_descriptor = if transpose_right {
- MatrixDescriptor::init_single(n, k, k * size, type_id)
- } else {
- MatrixDescriptor::init_single(k, n, n * size, type_id)
- };
- let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
-
- // Create matrix objects
- let left_matrix = Matrix::init_with_buffer_descriptor(&self.buffer, 0, &left_descriptor)
- .ok_or_else(|| {
- MetalError::from("Failed to create matrix multiplication kernel".to_string())
- })?;
- let right_matrix = Matrix::init_with_buffer_descriptor(&rhs.buffer, 0, &right_descriptor)
- .ok_or_else(|| {
- MetalError::from("Failed to create matrix multiplication kernel".to_string())
- })?;
+ let left_matrix = self.matrix(
+ (b, m, k),
+ transpose_left,
+ size,
+ lhs_l.start_offset() as NSUInteger * size,
+ type_id,
+ )?;
+ let right_matrix = rhs.matrix(
+ (b, k, n),
+ transpose_right,
+ size,
+ rhs_l.start_offset() as NSUInteger * size,
+ type_id,
+ )?;
+ let (result_matrix, out_buffer) =
+ self.device
+ .new_matrix((b, m, n), size, type_id, self.dtype)?;
- let out_buffer = self.device.new_buffer(elem_count, self.dtype);
- let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, 0, &result_descriptor)
- .ok_or_else(|| {
- MetalError::from("Failed to create matrix multiplication kernel".to_string())
- })?;
+ let command_buffer = self.device.command_buffer();
let alpha = 1.0f64;
let beta = 0.0f64;
@@ -647,70 +879,112 @@ impl BackendStorage for MetalStorage {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
- matrix_multiplication.set_batch_size(b);
-
// Encode kernel to command buffer
- let command_buffer = self.device.command_queue.new_command_buffer();
matrix_multiplication.encode_to_command_buffer(
- command_buffer,
+ &command_buffer,
&left_matrix,
&right_matrix,
&result_matrix,
);
- command_buffer.commit();
- command_buffer.wait_until_completed();
+ command_buffer.set_label("matmul");
+ drop(command_buffer);
+ self.device.commit();
- Ok(Self {
- buffer: out_buffer,
- device: self.device.clone(),
- dtype: self.dtype(),
- })
+ Ok(Self::new(out_buffer, self.device.clone(), self.dtype()))
}
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
- let src_shape = src_l.shape();
- let el_count = src_shape.elem_count();
- if el_count == 0 {
- return Ok(());
+ let command_buffer = self.device.command_buffer();
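+ // Fast path: a contiguous same-dtype copy is a plain blit; anything
+ // strided goes through the copy kernel below.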
+ if src_l.is_contiguous() && self.dtype == dst.dtype() {
+ command_buffer.set_label("copy_contiguous");
+ let blit = command_buffer.new_blit_command_encoder();
+ let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger;
+ let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger;
+ blit.copy_from_buffer(
+ &self.buffer,
+ src_offset,
+ dst.buffer(),
+ dst_offset,
+ self.buffer.length() - src_offset,
+ );
+ blit.end_encoding();
+ } else {
+ let src_shape = src_l.shape();
+ let el_count = src_shape.elem_count();
+ if el_count == 0 {
+ return Ok(());
+ }
+ let kernel_name = match self.dtype {
+ DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
+ DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
+ DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
+ DType::U32 => candle_metal_kernels::unary::strided::copy::U32,
+ DType::U8 => candle_metal_kernels::unary::strided::copy::U8,
+ dtype => crate::bail!("copy_strided not implemented for {dtype:?}"),
+ };
+ candle_metal_kernels::call_unary_strided(
+ &self.device.device,
+ &command_buffer,
+ &self.device.kernels,
+ kernel_name,
+ src_l.dims(),
+ &self.buffer,
+ src_l.stride(),
+ src_l.start_offset() * self.dtype.size_in_bytes(),
+ &dst.buffer,
+ dst_offset * dst.dtype.size_in_bytes(),
+ )
+ .map_err(MetalError::from)?;
+ command_buffer.set_label("copy_strided");
}
- let command_buffer = self.device.command_queue.new_command_buffer();
- let kernel_name = match self.dtype {
- DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
- DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
- DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
- dtype => crate::bail!("copy_strided not implemented for {dtype:?}"),
- };
- candle_metal_kernels::call_unary_strided(
- &self.device.device,
- &command_buffer,
- &self.device.kernels,
- kernel_name,
- src_l.dims(),
- &self.buffer,
- &src_l.stride(),
- src_l.start_offset() * self.dtype.size_in_bytes(),
- &mut dst.buffer,
- dst_offset,
- )
- .map_err(MetalError::from)?;
- command_buffer.commit();
- command_buffer.wait_until_completed();
+ drop(command_buffer);
+ self.device.commit();
Ok(())
}
}
impl MetalStorage {
- pub fn new(buffer: Buffer, device: MetalDevice, dtype: DType) -> Self {
+ pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: DType) -> Self {
+ let matrices = Arc::new(RwLock::new(HashMap::new()));
Self {
buffer,
device,
dtype,
+ matrices,
}
}
pub fn buffer(&self) -> &Buffer {
&self.buffer
}
+
+ fn matrix(
+ &self,
+ (b, m, n): (NSUInteger, NSUInteger, NSUInteger),
+ transpose: bool,
+ size: NSUInteger,
+ offset: NSUInteger,
+ type_id: u32,
+ ) -> Result<Matrix> {
+ let key = (b, m, n, transpose, size, offset, type_id);
+
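+ // Cache MPS matrix objects per (shape, transpose, size, offset, type)
+ // so repeated matmuls on the same storage reuse them.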
+ let mut matrices = self.matrices.try_write().unwrap();
+ if let Some(matrix) = matrices.get(&key) {
+ Ok(matrix.clone())
+ } else {
+ let descriptor = if transpose {
+ MatrixDescriptor::init_multiple(n, m, b, m * size, m * n * size, type_id)
+ } else {
+ MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id)
+ };
+ let matrix = Matrix::init_with_buffer_descriptor(&self.buffer, offset, &descriptor)
+ .ok_or_else(|| {
+ MetalError::from("Failed to create matrix multiplication kernel".to_string())
+ })?;
+ matrices.insert(key, matrix.clone());
+ Ok(matrix)
+ }
+ }
}
impl BackendDevice for MetalDevice {
@@ -720,10 +994,14 @@ impl BackendDevice for MetalDevice {
let device = metal::Device::all().swap_remove(ordinal);
let command_queue = device.new_command_queue();
+ let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned()));
let kernels = Arc::new(Kernels::new());
+ let buffers = Arc::new(RwLock::new(HashMap::new()));
Ok(Self {
device,
command_queue,
+ command_buffer,
+ buffers,
kernels,
})
}
@@ -743,9 +1021,8 @@ impl BackendDevice for MetalDevice {
}
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
- // TODO Is there a faster way ?
- let cpu_storage = crate::cpu_backend::CpuDevice.zeros_impl(shape, dtype)?;
- self.storage_from_cpu_storage(&cpu_storage)
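+ // Note: this assumes a zero-filled buffer; buffers reused from the
+ // pool may still hold stale data.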
+ let buffer = self.new_buffer(shape.elem_count(), dtype);
+ Ok(MetalStorage::new(buffer, self.clone(), dtype))
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
@@ -755,49 +1032,20 @@ impl BackendDevice for MetalDevice {
}
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
- let option = metal::MTLResourceOptions::StorageModeManaged;
let buffer = match storage {
- CpuStorage::U8(storage) => self.device.new_buffer_with_data(
- storage.as_ptr() as *const core::ffi::c_void,
- (storage.len() * mem::size_of::<u8>()) as NSUInteger,
- option,
- ),
- CpuStorage::U32(storage) => self.device.new_buffer_with_data(
- storage.as_ptr() as *const core::ffi::c_void,
- (storage.len() * mem::size_of::<u32>()) as NSUInteger,
- option,
- ),
- CpuStorage::I64(storage) => self.device.new_buffer_with_data(
- storage.as_ptr() as *const core::ffi::c_void,
- (storage.len() * mem::size_of::<i64>()) as NSUInteger,
- option,
- ),
- CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
- storage.as_ptr() as *const core::ffi::c_void,
- (storage.len() * mem::size_of::<bf16>()) as NSUInteger,
- option,
- ),
- CpuStorage::F16(storage) => self.device.new_buffer_with_data(
- storage.as_ptr() as *const core::ffi::c_void,
- (storage.len() * mem::size_of::<f16>()) as NSUInteger,
- option,
- ),
- CpuStorage::F32(storage) => self.device.new_buffer_with_data(
- storage.as_ptr() as *const core::ffi::c_void,
- (storage.len() * mem::size_of::<f32>()) as NSUInteger,
- option,
- ),
- CpuStorage::F64(storage) => self.device.new_buffer_with_data(
- storage.as_ptr() as *const core::ffi::c_void,
- (storage.len() * mem::size_of::<f64>()) as NSUInteger,
- option,
- ),
+ CpuStorage::U8(storage) => self.new_buffer_with_data(storage),
+ CpuStorage::U32(storage) => self.new_buffer_with_data(storage),
+ CpuStorage::I64(storage) => self.new_buffer_with_data(storage),
+ CpuStorage::BF16(storage) => self.new_buffer_with_data(storage),
+ CpuStorage::F16(storage) => self.new_buffer_with_data(storage),
+ CpuStorage::F32(storage) => self.new_buffer_with_data(storage),
+ CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
};
- Ok(Self::Storage {
- buffer,
- device: self.clone(),
- dtype: storage.dtype(),
- })
+ Ok(Self::Storage::new(
+ buffer.into(),
+ self.clone(),
+ storage.dtype(),
+ ))
}
fn rand_uniform(