diff options
Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/metal_backend/mod.rs | 82 |
1 files changed, 43 insertions, 39 deletions
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 0e058b45..4adcda05 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -2,8 +2,7 @@ use crate::backend::{BackendDevice, BackendStorage}; use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Layout, Result, Shape}; -use candle_metal_kernels::CallConvTranspose2dCfg; -use candle_metal_kernels::Kernels; +use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels}; use metal::{Buffer, MTLResourceOptions, NSUInteger}; use std::collections::HashMap; use std::ffi::c_void; @@ -12,6 +11,12 @@ use std::sync::{Arc, Mutex, RwLock, TryLockError}; mod device; pub use device::{DeviceId, MetalDevice}; +fn buffer_o<'a>(buffer: &'a Buffer, l: &Layout, dtype: DType) -> BufferOffset<'a> { + BufferOffset { + buffer, + offset_in_bytes: l.start_offset() * dtype.size_in_bytes(), + } +} /// Simple way to catch lock error without /// depending on T #[derive(thiserror::Error, Debug)] @@ -102,7 +107,8 @@ impl BackendStorage for MetalStorage { let buffer = device.new_buffer(el, self.dtype, "affine")?; let command_buffer = self.device.command_buffer()?; - if layout.is_contiguous() && layout.start_offset() == 0 { + let src = buffer_o(&self.buffer, layout, dtype); + if layout.is_contiguous() { let name = match self.dtype { DType::F32 => "affine_f32", DType::F16 => "affine_f16", @@ -115,7 +121,7 @@ impl BackendStorage for MetalStorage { &device.kernels, name, el, - &self.buffer, + src, &buffer, mul as f32, add as f32, @@ -134,9 +140,8 @@ impl BackendStorage for MetalStorage { &device.kernels, name, layout.dims(), - &self.buffer, + src, layout.stride(), - layout.start_offset() * dtype.size_in_bytes(), &buffer, mul as f32, add as f32, @@ -155,7 +160,8 @@ impl BackendStorage for MetalStorage { let buffer = device.new_buffer(el, self.dtype, "powf")?; let command_buffer = self.device.command_buffer()?; - if layout.is_contiguous() && layout.start_offset() == 0 { + let src = buffer_o(&self.buffer, layout, dtype); + if layout.is_contiguous() { let name = match self.dtype { DType::F32 => "powf_f32", DType::F16 => "powf_f16", @@ -168,7 +174,7 @@ impl BackendStorage for MetalStorage { &device.kernels, name, el, - &self.buffer, + src, &buffer, pow as f32, ) @@ -186,9 +192,8 @@ impl BackendStorage for MetalStorage { &device.kernels, name, layout.dims(), - &self.buffer, + src, layout.stride(), - layout.start_offset() * dtype.size_in_bytes(), &buffer, pow as f32, ) @@ -206,7 +211,8 @@ impl BackendStorage for MetalStorage { let buffer = device.new_buffer(el, self.dtype, "elu")?; let command_buffer = self.device.command_buffer()?; - if layout.is_contiguous() && layout.start_offset() == 0 { + let src = buffer_o(&self.buffer, layout, self.dtype); + if layout.is_contiguous() { let name = match self.dtype { DType::F32 => "elu_f32", DType::F16 => "elu_f16", @@ -219,7 +225,7 @@ impl BackendStorage for MetalStorage { &device.kernels, name, el, - &self.buffer, + src, &buffer, alpha as f32, ) @@ -237,9 +243,8 @@ impl BackendStorage for MetalStorage { &device.kernels, name, layout.dims(), - &self.buffer, + src, layout.stride(), - layout.start_offset() * dtype.size_in_bytes(), &buffer, alpha as f32, ) @@ -344,7 +349,8 @@ impl BackendStorage for MetalStorage { let el_count = shape.elem_count(); let buffer = device.new_buffer(el_count, dtype, "todtype")?; let command_buffer = device.command_buffer()?; - if layout.is_contiguous() && layout.start_offset() == 0 { + let src = buffer_o(&self.buffer, layout, self.dtype); + if layout.is_contiguous() { let kernel_name = match (self.dtype, dtype) { (DType::U32, DType::BF16) => "cast_u32_bf16", (DType::U32, DType::F16) => "cast_u32_f16", @@ -392,8 +398,7 @@ impl BackendStorage for MetalStorage { &device.kernels, kernel_name, el_count, - &self.buffer, - layout.start_offset() * self.dtype.size_in_bytes(), + src, &buffer, ) .map_err(MetalError::from)?; @@ -420,9 +425,8 @@ impl BackendStorage for MetalStorage { &device.kernels, kernel_name, layout.dims(), - &self.buffer, + src, layout.stride(), - layout.start_offset() * self.dtype.size_in_bytes(), &buffer, ) .map_err(MetalError::from)?; @@ -439,7 +443,8 @@ impl BackendStorage for MetalStorage { let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?; let command_buffer = device.command_buffer()?; command_buffer.set_label(B::KERNEL); - if layout.is_contiguous() && layout.start_offset() == 0 { + let src = buffer_o(&self.buffer, layout, self.dtype); + if layout.is_contiguous() { use candle_metal_kernels::unary::contiguous; let kernel_name = match (B::KERNEL, dtype) { @@ -511,7 +516,7 @@ impl BackendStorage for MetalStorage { &device.kernels, kernel_name, el_count, - &self.buffer, + src, &buffer, ) .map_err(MetalError::from)?; @@ -556,17 +561,16 @@ impl BackendStorage for MetalStorage { crate::bail!("Metal strided unary {name} {dtype:?} not implemented") } }; + let dst = BufferOffset::zero_offset(&buffer); candle_metal_kernels::call_unary_strided( &device.device, &command_buffer, &device.kernels, kernel_name, layout.dims(), - &self.buffer, + src, layout.stride(), - layout.start_offset() * self.dtype.size_in_bytes(), - &buffer, - 0, + dst, ) .map_err(MetalError::from)?; } @@ -1358,17 +1362,20 @@ impl BackendStorage for MetalStorage { DType::U8 => candle_metal_kernels::unary::strided::copy::U8, dtype => crate::bail!("Metal copy_strided {dtype:?} not implemented"), }; + let src = buffer_o(&self.buffer, src_l, self.dtype); + let dst = BufferOffset { + buffer: &dst.buffer, + offset_in_bytes: dst_offset * dst.dtype.size_in_bytes(), + }; candle_metal_kernels::call_unary_strided( &self.device.device, &command_buffer, &self.device.kernels, kernel_name, src_l.dims(), - &self.buffer, + src, src_l.stride(), - src_l.start_offset() * self.dtype.size_in_bytes(), - &dst.buffer, - dst_offset * dst.dtype.size_in_bytes(), + dst, ) .map_err(MetalError::from)?; command_buffer.set_label("copy_strided"); @@ -1402,10 +1409,9 @@ impl MetalStorage { let shape = lhs_l.shape(); let el_count = shape.elem_count(); let command_buffer = device.command_buffer()?; - let (buffer, dtype) = if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0) - && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0) - && &op[..1] != "b" - { + let lhs = buffer_o(&self.buffer, lhs_l, self.dtype); + let rhs = buffer_o(&rhs.buffer, rhs_l, rhs.dtype); + let (buffer, dtype) = if lhs_l.is_contiguous() && rhs_l.is_contiguous() && &op[..1] != "b" { use candle_metal_kernels::binary::contiguous; let (kernel_name, dtype) = match (op, self.dtype) { @@ -1486,8 +1492,8 @@ impl MetalStorage { &device.kernels, kernel_name, el_count, - &self.buffer, - &rhs.buffer, + lhs, + rhs, &buffer, ) .map_err(MetalError::from)?; @@ -1585,12 +1591,10 @@ impl MetalStorage { &device.kernels, kernel_name, lhs_l.dims(), - &self.buffer, + lhs, lhs_l.stride(), - lhs_l.start_offset() * self.dtype.size_in_bytes(), - &rhs.buffer, + rhs, rhs_l.stride(), - rhs_l.start_offset() * rhs.dtype.size_in_bytes(), &buffer, ) .map_err(MetalError::from)?; |