summary refs log tree commit diff
path: root/candle-core/src/metal_backend/mod.rs
diff options
context:
space:
mode:
author Laurent Mazare <laurent.mazare@gmail.com> 2024-04-07 22:37:53 +0200
committer GitHub <noreply@github.com> 2024-04-07 22:37:53 +0200
commit c5fe4a7f8983ae7c9641fa923f26ef60538aef06 (patch)
tree 12ad3e2445577fc77a5f9879686d554aea943a0d /candle-core/src/metal_backend/mod.rs
parent 7f354473cf495db4554e08f84be44ed498f1aa5e (diff)
download candle-c5fe4a7f8983ae7c9641fa923f26ef60538aef06.tar.gz
candle-c5fe4a7f8983ae7c9641fa923f26ef60538aef06.tar.bz2
candle-c5fe4a7f8983ae7c9641fa923f26ef60538aef06.zip
Rework the buffer offset logic for metal kernels (#2028)
* Move the metal kernels utils into a separate module. * Use the BufferOffset for unary ops. * Fix clippy lints. * Use the new BufferOffset. * Adapt the binary ops. * Affine. * More ops (powf, elu, cast).
Diffstat (limited to 'candle-core/src/metal_backend/mod.rs')
-rw-r--r-- candle-core/src/metal_backend/mod.rs 82
1 files changed, 43 insertions, 39 deletions
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs
index 0e058b45..4adcda05 100644
--- a/candle-core/src/metal_backend/mod.rs
+++ b/candle-core/src/metal_backend/mod.rs
@@ -2,8 +2,7 @@ use crate::backend::{BackendDevice, BackendStorage};
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
-use candle_metal_kernels::CallConvTranspose2dCfg;
-use candle_metal_kernels::Kernels;
+use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels};
use metal::{Buffer, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
use std::ffi::c_void;
@@ -12,6 +11,12 @@ use std::sync::{Arc, Mutex, RwLock, TryLockError};
mod device;
pub use device::{DeviceId, MetalDevice};
+fn buffer_o<'a>(buffer: &'a Buffer, l: &Layout, dtype: DType) -> BufferOffset<'a> {
+ BufferOffset {
+ buffer,
+ offset_in_bytes: l.start_offset() * dtype.size_in_bytes(),
+ }
+}
/// Simple way to catch lock error without
/// depending on T
#[derive(thiserror::Error, Debug)]
@@ -102,7 +107,8 @@ impl BackendStorage for MetalStorage {
let buffer = device.new_buffer(el, self.dtype, "affine")?;
let command_buffer = self.device.command_buffer()?;
- if layout.is_contiguous() && layout.start_offset() == 0 {
+ let src = buffer_o(&self.buffer, layout, dtype);
+ if layout.is_contiguous() {
let name = match self.dtype {
DType::F32 => "affine_f32",
DType::F16 => "affine_f16",
@@ -115,7 +121,7 @@ impl BackendStorage for MetalStorage {
&device.kernels,
name,
el,
- &self.buffer,
+ src,
&buffer,
mul as f32,
add as f32,
@@ -134,9 +140,8 @@ impl BackendStorage for MetalStorage {
&device.kernels,
name,
layout.dims(),
- &self.buffer,
+ src,
layout.stride(),
- layout.start_offset() * dtype.size_in_bytes(),
&buffer,
mul as f32,
add as f32,
@@ -155,7 +160,8 @@ impl BackendStorage for MetalStorage {
let buffer = device.new_buffer(el, self.dtype, "powf")?;
let command_buffer = self.device.command_buffer()?;
- if layout.is_contiguous() && layout.start_offset() == 0 {
+ let src = buffer_o(&self.buffer, layout, dtype);
+ if layout.is_contiguous() {
let name = match self.dtype {
DType::F32 => "powf_f32",
DType::F16 => "powf_f16",
@@ -168,7 +174,7 @@ impl BackendStorage for MetalStorage {
&device.kernels,
name,
el,
- &self.buffer,
+ src,
&buffer,
pow as f32,
)
@@ -186,9 +192,8 @@ impl BackendStorage for MetalStorage {
&device.kernels,
name,
layout.dims(),
- &self.buffer,
+ src,
layout.stride(),
- layout.start_offset() * dtype.size_in_bytes(),
&buffer,
pow as f32,
)
@@ -206,7 +211,8 @@ impl BackendStorage for MetalStorage {
let buffer = device.new_buffer(el, self.dtype, "elu")?;
let command_buffer = self.device.command_buffer()?;
- if layout.is_contiguous() && layout.start_offset() == 0 {
+ let src = buffer_o(&self.buffer, layout, self.dtype);
+ if layout.is_contiguous() {
let name = match self.dtype {
DType::F32 => "elu_f32",
DType::F16 => "elu_f16",
@@ -219,7 +225,7 @@ impl BackendStorage for MetalStorage {
&device.kernels,
name,
el,
- &self.buffer,
+ src,
&buffer,
alpha as f32,
)
@@ -237,9 +243,8 @@ impl BackendStorage for MetalStorage {
&device.kernels,
name,
layout.dims(),
- &self.buffer,
+ src,
layout.stride(),
- layout.start_offset() * dtype.size_in_bytes(),
&buffer,
alpha as f32,
)
@@ -344,7 +349,8 @@ impl BackendStorage for MetalStorage {
let el_count = shape.elem_count();
let buffer = device.new_buffer(el_count, dtype, "todtype")?;
let command_buffer = device.command_buffer()?;
- if layout.is_contiguous() && layout.start_offset() == 0 {
+ let src = buffer_o(&self.buffer, layout, self.dtype);
+ if layout.is_contiguous() {
let kernel_name = match (self.dtype, dtype) {
(DType::U32, DType::BF16) => "cast_u32_bf16",
(DType::U32, DType::F16) => "cast_u32_f16",
@@ -392,8 +398,7 @@ impl BackendStorage for MetalStorage {
&device.kernels,
kernel_name,
el_count,
- &self.buffer,
- layout.start_offset() * self.dtype.size_in_bytes(),
+ src,
&buffer,
)
.map_err(MetalError::from)?;
@@ -420,9 +425,8 @@ impl BackendStorage for MetalStorage {
&device.kernels,
kernel_name,
layout.dims(),
- &self.buffer,
+ src,
layout.stride(),
- layout.start_offset() * self.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
@@ -439,7 +443,8 @@ impl BackendStorage for MetalStorage {
let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?;
let command_buffer = device.command_buffer()?;
command_buffer.set_label(B::KERNEL);
- if layout.is_contiguous() && layout.start_offset() == 0 {
+ let src = buffer_o(&self.buffer, layout, self.dtype);
+ if layout.is_contiguous() {
use candle_metal_kernels::unary::contiguous;
let kernel_name = match (B::KERNEL, dtype) {
@@ -511,7 +516,7 @@ impl BackendStorage for MetalStorage {
&device.kernels,
kernel_name,
el_count,
- &self.buffer,
+ src,
&buffer,
)
.map_err(MetalError::from)?;
@@ -556,17 +561,16 @@ impl BackendStorage for MetalStorage {
crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
}
};
+ let dst = BufferOffset::zero_offset(&buffer);
candle_metal_kernels::call_unary_strided(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
layout.dims(),
- &self.buffer,
+ src,
layout.stride(),
- layout.start_offset() * self.dtype.size_in_bytes(),
- &buffer,
- 0,
+ dst,
)
.map_err(MetalError::from)?;
}
@@ -1358,17 +1362,20 @@ impl BackendStorage for MetalStorage {
DType::U8 => candle_metal_kernels::unary::strided::copy::U8,
dtype => crate::bail!("Metal copy_strided {dtype:?} not implemented"),
};
+ let src = buffer_o(&self.buffer, src_l, self.dtype);
+ let dst = BufferOffset {
+ buffer: &dst.buffer,
+ offset_in_bytes: dst_offset * dst.dtype.size_in_bytes(),
+ };
candle_metal_kernels::call_unary_strided(
&self.device.device,
&command_buffer,
&self.device.kernels,
kernel_name,
src_l.dims(),
- &self.buffer,
+ src,
src_l.stride(),
- src_l.start_offset() * self.dtype.size_in_bytes(),
- &dst.buffer,
- dst_offset * dst.dtype.size_in_bytes(),
+ dst,
)
.map_err(MetalError::from)?;
command_buffer.set_label("copy_strided");
@@ -1402,10 +1409,9 @@ impl MetalStorage {
let shape = lhs_l.shape();
let el_count = shape.elem_count();
let command_buffer = device.command_buffer()?;
- let (buffer, dtype) = if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
- && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
- && &op[..1] != "b"
- {
+ let lhs = buffer_o(&self.buffer, lhs_l, self.dtype);
+ let rhs = buffer_o(&rhs.buffer, rhs_l, rhs.dtype);
+ let (buffer, dtype) = if lhs_l.is_contiguous() && rhs_l.is_contiguous() && &op[..1] != "b" {
use candle_metal_kernels::binary::contiguous;
let (kernel_name, dtype) = match (op, self.dtype) {
@@ -1486,8 +1492,8 @@ impl MetalStorage {
&device.kernels,
kernel_name,
el_count,
- &self.buffer,
- &rhs.buffer,
+ lhs,
+ rhs,
&buffer,
)
.map_err(MetalError::from)?;
@@ -1585,12 +1591,10 @@ impl MetalStorage {
&device.kernels,
kernel_name,
lhs_l.dims(),
- &self.buffer,
+ lhs,
lhs_l.stride(),
- lhs_l.start_offset() * self.dtype.size_in_bytes(),
- &rhs.buffer,
+ rhs,
rhs_l.stride(),
- rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;