summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/metal_backend.rs147
-rw-r--r--candle-metal-kernels/src/conv.metal153
-rw-r--r--candle-metal-kernels/src/lib.rs104
-rw-r--r--candle-metal-kernels/src/tests.rs69
-rw-r--r--candle-metal-kernels/src/unary.metal12
5 files changed, 399 insertions, 86 deletions
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 27b2824f..1813f276 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -782,12 +782,72 @@ impl BackendStorage for MetalStorage {
fn conv1d(
&self,
- _l: &Layout,
- _kernel: &Self,
- _kernel_l: &Layout,
- _params: &ParamsConv1D,
+ layout: &Layout,
+ kernel: &Self,
+ kernel_l: &Layout,
+ params: &ParamsConv1D,
) -> Result<Self> {
- crate::bail!("conv1d metal")
+ let device = self.device().clone();
+ let shape = layout.shape();
+ let dims = shape.dims();
+ let strides = layout.stride();
+
+ let stride = params.stride;
+ let dilation = params.dilation;
+ let padding = params.padding;
+ let k_size = params.k_size;
+ let l_out = (dims[2] + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1;
+ let dst_el = dims[0] * l_out * dims[1] * k_size;
+ let dst = self
+ .device
+ .new_buffer(dst_el, self.dtype, "conv1d_im2col")?;
+ let command_buffer = self.device.command_buffer()?;
+ let name = match self.dtype {
+ DType::F32 => "im2col1d_f32",
+ dtype => crate::bail!("conv1d metal {dtype:?} not implemented"),
+ };
+ candle_metal_kernels::call_im2col1d_strided(
+ &self.device.device,
+ &command_buffer,
+ &self.device.kernels,
+ name,
+ layout.shape().dims(),
+ strides,
+ (k_size, stride, padding, dilation),
+ &self.buffer,
+ layout.start_offset() * self.dtype.size_in_bytes(),
+ &dst,
+ )
+ .map_err(MetalError::from)?;
+ let col = Self {
+ buffer: dst,
+ device,
+ dtype: self.dtype,
+ };
+ let l_out = params.l_out();
+ let b = params.b_size;
+ let n = params.c_out;
+ let k = params.k_size * params.c_in;
+ let m = l_out;
+ let col_l = Layout::contiguous((b, m, k));
+ let res = if kernel_l.is_contiguous() {
+ let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+ .transpose(1, 2)?
+ .broadcast_as((b, k, n))?;
+ col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+ } else {
+ // Make the kernel contiguous if not already the case.
+ let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+ kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
+ let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+ .transpose(1, 2)?
+ .broadcast_as((b, k, n))?;
+ col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+ };
+ let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
+ let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+ res.copy_strided_src(&mut res_t, 0, &res_l)?;
+ Ok(res_t)
}
fn conv_transpose1d(
@@ -802,12 +862,79 @@ impl BackendStorage for MetalStorage {
fn conv2d(
&self,
- _l: &Layout,
- _kernel: &Self,
- _kernel_l: &Layout,
- _params: &ParamsConv2D,
+ layout: &Layout,
+ kernel: &Self,
+ kernel_l: &Layout,
+ params: &ParamsConv2D,
) -> Result<Self> {
- crate::bail!("conv2d metal")
+ let device = self.device().clone();
+ let shape = layout.shape();
+ let dims = shape.dims();
+
+ let stride = params.stride;
+ let dilation = params.dilation;
+ let padding = params.padding;
+ let h_k = params.k_h;
+ let w_k = params.k_w;
+ let h = dims[2];
+ let w = dims[3];
+ let h_out = (h + 2 * padding - dilation * (h_k - 1) - 1) / stride + 1;
+ let w_out = (w + 2 * padding - dilation * (w_k - 1) - 1) / stride + 1;
+ let dst_el = dims[0] * h_out * w_out * dims[1] * h_k * w_k;
+
+ let dst = self
+ .device
+ .new_buffer(dst_el, self.dtype, "conv2d_im2col")?;
+ let command_buffer = self.device.command_buffer()?;
+ let name = match self.dtype {
+ DType::F32 => "im2col_f32",
+ dtype => crate::bail!("conv1d metal {dtype:?} not implemented"),
+ };
+ candle_metal_kernels::call_im2col_strided(
+ &self.device.device,
+ &command_buffer,
+ &self.device.kernels,
+ name,
+ layout.shape().dims(),
+ layout.stride(),
+ (h_k, w_k, stride, padding, dilation),
+ &self.buffer,
+ layout.start_offset() * self.dtype.size_in_bytes(),
+ &dst,
+ )
+ .map_err(MetalError::from)?;
+ let col = Self {
+ buffer: dst,
+ device,
+ dtype: self.dtype,
+ };
+ let h_out = params.out_h();
+ let w_out = params.out_w();
+ let b = params.b_size;
+ let n = params.c_out;
+ let k = params.k_h * params.k_w * params.c_in;
+ let m = h_out * w_out;
+ let col_l = Layout::contiguous((b, m, k));
+ let res = if kernel_l.is_contiguous() {
+ let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+ .transpose(1, 2)?
+ .broadcast_as((b, k, n))?;
+ col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+ } else {
+ // Make the kernel contiguous if not already the case.
+ let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+ kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
+ let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+ .transpose(1, 2)?
+ .broadcast_as((b, k, n))?;
+ col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+ };
+ let res_l = Layout::contiguous((b, h_out, w_out, n))
+ .transpose(1, 2)?
+ .transpose(1, 3)?;
+ let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+ res.copy_strided_src(&mut res_t, 0, &res_l)?;
+ Ok(res_t)
}
fn conv_transpose2d(
diff --git a/candle-metal-kernels/src/conv.metal b/candle-metal-kernels/src/conv.metal
new file mode 100644
index 00000000..49141771
--- /dev/null
+++ b/candle-metal-kernels/src/conv.metal
@@ -0,0 +1,153 @@
+template <typename T>
+METAL_FUNC void im2col(
+ constant size_t &dst_numel,
+ constant size_t &h_out,
+ constant size_t &w_out,
+ constant size_t &h_k,
+ constant size_t &w_k,
+ constant size_t &stride,
+ constant size_t &padding,
+ constant size_t &dilation,
+ constant size_t *src_dims,
+ constant size_t *src_strides,
+ device const T *src,
+ device T *dst,
+ uint tid [[ thread_position_in_grid ]]
+) {
+ // dst: (b_size, h_out, w_out, c_in, h_k, w_k)
+ // src: (b_size, c_in, h_in, w_in)
+ if (tid >= dst_numel) {
+ return;
+ }
+ const size_t b_in = src_dims[0];
+ const size_t c_in = src_dims[1];
+ const size_t h_in = src_dims[2];
+ const size_t w_in = src_dims[3];
+
+ const size_t dst_s4 = w_k;
+ const size_t dst_s3 = h_k * dst_s4;
+ const size_t dst_s2 = c_in * dst_s3;
+ const size_t dst_s1 = w_out * dst_s2;
+ const size_t dst_s0 = h_out * dst_s1;
+
+ size_t tmp_tid = tid;
+ const size_t b_idx = tmp_tid / dst_s0;
+ tmp_tid -= b_idx * dst_s0;
+ const size_t h_idx = tmp_tid / dst_s1;
+ tmp_tid -= h_idx * dst_s1;
+ const size_t w_idx = tmp_tid / dst_s2;
+ tmp_tid -= w_idx * dst_s2;
+ const size_t c_idx = tmp_tid / dst_s3;
+ tmp_tid -= c_idx * dst_s3;
+ const size_t h_k_idx = tmp_tid / dst_s4;
+ tmp_tid -= h_k_idx * dst_s4;
+ const size_t w_k_idx = tmp_tid;
+ size_t src_h_idx = h_idx * stride + h_k_idx * dilation;
+ size_t src_w_idx = w_idx * stride + w_k_idx * dilation;
+ if (src_h_idx < padding || src_h_idx >= h_in + padding) {
+ dst[tid] = static_cast<T>(0);
+ }
+ else if (src_w_idx < padding || src_w_idx >= w_in + padding) {
+ dst[tid] = static_cast<T>(0);
+ }
+ else {
+ src_h_idx -= padding;
+ src_w_idx -= padding;
+ const size_t src_i =
+ b_idx * src_strides[0]
+ + c_idx * src_strides[1]
+ + src_h_idx * src_strides[2]
+ + src_w_idx * src_strides[3];
+ dst[tid] = src[src_i];
+ }
+}
+
+template <typename T>
+METAL_FUNC void im2col1d(
+ constant size_t &dst_numel,
+ constant size_t &l_out,
+ constant size_t &l_k,
+ constant size_t &stride,
+ constant size_t &padding,
+ constant size_t &dilation,
+ constant size_t *src_dims,
+ constant size_t *src_strides,
+ device const T *src,
+ device T *dst,
+ uint tid [[ thread_position_in_grid ]]
+) {
+ // dst: (b_size, l_out, c_in, l_k)
+ // src: (b_size, c_in, l_in)
+ if (tid >= dst_numel) {
+ return;
+ }
+ const size_t b_in = src_dims[0];
+ const size_t c_in = src_dims[1];
+ const size_t l_in = src_dims[2];
+
+ const size_t dst_s2 = l_k;
+ const size_t dst_s1 = c_in * dst_s2;
+ const size_t dst_s0 = l_out * dst_s1;
+
+ size_t tmp_dst_i = tid;
+ const size_t b_idx = tmp_dst_i / dst_s0;
+ tmp_dst_i -= b_idx * dst_s0;
+ const size_t l_idx = tmp_dst_i / dst_s1;
+ tmp_dst_i -= l_idx * dst_s1;
+ const size_t c_idx = tmp_dst_i / dst_s2;
+ tmp_dst_i -= c_idx * dst_s2;
+ const size_t l_k_idx = tmp_dst_i;
+ size_t src_l_idx = l_idx * stride + l_k_idx * dilation;
+ if (src_l_idx < padding || src_l_idx >= l_in + padding) {
+ dst[tid] = static_cast<T>(0);
+ }
+ else {
+ src_l_idx -= padding;
+ const size_t src_i = b_idx * src_strides[0] + c_idx * src_strides[1] + src_l_idx * src_strides[2];
+ dst[tid] = src[src_i];
+ }
+}
+
+#define IM2COL_OP(T, FN_NAME) \
+kernel void FN_NAME( \
+ constant size_t &dst_numel, \
+ constant size_t &h_out, \
+ constant size_t &w_out, \
+ constant size_t &h_k, \
+ constant size_t &w_k, \
+ constant size_t &stride, \
+ constant size_t &padding, \
+ constant size_t &dilation, \
+ constant size_t *src_dims, \
+ constant size_t *src_strides, \
+ device const T *src, \
+ device T *dst, \
+ uint tid [[ thread_position_in_grid ]] \
+) { \
+ im2col<T>(dst_numel, h_out, w_out, h_k, w_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \
+} \
+
+#define IM2COL1D_OP(T, FN_NAME) \
+kernel void FN_NAME( \
+ constant size_t &dst_numel, \
+ constant size_t &l_out, \
+ constant size_t &l_k, \
+ constant size_t &stride, \
+ constant size_t &padding, \
+ constant size_t &dilation, \
+ constant size_t *src_dims, \
+ constant size_t *src_strides, \
+ device const T *src, \
+ device T *dst, \
+ uint tid [[ thread_position_in_grid ]] \
+) { \
+ im2col1d<T>(dst_numel, l_out, l_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \
+} \
+
+IM2COL_OP(float, im2col_f32)
+IM2COL_OP(uint8_t, im2col_u8)
+IM2COL_OP(uint32_t, im2col_u32)
+
+IM2COL1D_OP(float, im2col1d_f32)
+IM2COL1D_OP(uint8_t, im2col1d_u8)
+IM2COL1D_OP(uint32_t, im2col1d_u32)
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 0418c96c..d126aa42 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -13,6 +13,7 @@ const BINARY: &str = include_str!("binary.metal");
const TERNARY: &str = include_str!("ternary.metal");
const CAST: &str = include_str!("cast.metal");
const REDUCE: &str = include_str!("reduce.metal");
+const CONV: &str = include_str!("conv.metal");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
/// Most kernels apply similarly across the tensors
@@ -115,6 +116,7 @@ pub enum Source {
Cast,
Reduce,
Mfa,
+ Conv,
}
macro_rules! ops{
@@ -225,6 +227,7 @@ impl Kernels {
Source::Indexing => INDEXING,
Source::Cast => CAST,
Source::Reduce => REDUCE,
+ Source::Conv => CONV,
Source::Mfa => panic!("Invalid lib"),
}
}
@@ -1298,7 +1301,7 @@ pub fn call_gemm(
let fused_activation = false;
let fused_bias = false;
let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
- let m_simd = 16;
+ let m_simd = 8;
let n_simd = 8;
let k_simd = 64;
let m_splits = 1;
@@ -1307,7 +1310,7 @@ pub fn call_gemm(
} else {
let m_simd = 40;
let n_simd = 40;
- let k_simd = 8;
+ let k_simd = 32;
let m_splits = 1;
let n_splits = 1;
(m_simd, n_simd, k_simd, m_splits, n_splits)
@@ -1418,6 +1421,103 @@ pub fn call_gemm(
Ok(())
}
+#[allow(clippy::too_many_arguments)]
+pub fn call_im2col1d_strided(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: &'static str,
+ shape: &[usize],
+ strides: &[usize],
+ (k_size, stride, padding, dilation): (usize, usize, usize, usize),
+ input: &Buffer,
+ input_offset: usize,
+ output: &Buffer,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
+ let l_out = (shape[2] + 2 * padding - dilation * (k_size - 1) - 1) / stride + 1;
+ let dst_el = shape[0] * l_out * shape[1] * k_size;
+
+ let encoder = command_buffer.new_compute_command_encoder();
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
+ encoder.wait_for_fence(&kernels.fence);
+ encoder.set_compute_pipeline_state(&pipeline);
+ set_params!(
+ encoder,
+ (
+ dst_el,
+ l_out,
+ k_size,
+ stride,
+ padding,
+ dilation,
+ shape,
+ strides,
+ (input, input_offset),
+ output
+ )
+ );
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
+ encoder.end_encoding();
+
+ Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
+pub fn call_im2col_strided(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: &'static str,
+ shape: &[usize],
+ strides: &[usize],
+ (h_k, w_k, stride, padding, dilation): (usize, usize, usize, usize, usize),
+ input: &Buffer,
+ input_offset: usize,
+ output: &Buffer,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
+
+ let h = shape[2];
+ let w = shape[3];
+ let h_out = (h + 2 * padding - dilation * (h_k - 1) - 1) / stride + 1;
+ let w_out = (w + 2 * padding - dilation * (w_k - 1) - 1) / stride + 1;
+
+ let dst_el = shape[0] * h_out * w_out * shape[1] * h_k * w_k;
+
+ let encoder = command_buffer.new_compute_command_encoder();
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
+ encoder.wait_for_fence(&kernels.fence);
+ encoder.set_compute_pipeline_state(&pipeline);
+ set_params!(
+ encoder,
+ (
+ dst_el,
+ h_out,
+ w_out,
+ h_k,
+ w_k,
+ stride,
+ padding,
+ dilation,
+ shape,
+ strides,
+ (input, input_offset),
+ output
+ )
+ );
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
+ encoder.end_encoding();
+
+ Ok(())
+}
+
fn divide(m: usize, b: usize) -> NSUInteger {
((m + b - 1) / b) as NSUInteger
}
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 1b3153b1..c955abca 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -1,6 +1,6 @@
use super::*;
use half::{bf16, f16};
-use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger};
+use metal::{Device, MTLResourceOptions};
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
let ptr = buffer.contents() as *const T;
@@ -486,73 +486,6 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
}
#[test]
-fn index_add() {
- let device = Device::system_default().expect("no device found");
-
- let options = CompileOptions::new();
- let library = device.new_library_with_source(INDEXING, &options).unwrap();
-
- let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
- let right = [1.0f32; 15];
- let index = [0u32, 4, 2];
- let ids_dim_size = index.len() as u32;
- let dst_dim_size: u32 = 15;
- let left_size: u32 = 3;
- let right_size: u32 = 3;
-
- let function = library.get_function("ia_u32_f32", None).unwrap();
- let pipeline = device
- .new_compute_pipeline_state_with_function(&function)
- .unwrap();
-
- let command_queue = device.new_command_queue();
- let command_buffer = command_queue.new_command_buffer();
- let encoder = command_buffer.new_compute_command_encoder();
-
- encoder.set_compute_pipeline_state(&pipeline);
-
- let index_buffer = new_buffer(&device, &index);
- let inputs_buffer = new_buffer(&device, &left);
- let outputs_buffer = new_buffer(&device, &right);
-
- set_params!(
- encoder,
- (
- &index_buffer,
- &inputs_buffer,
- &outputs_buffer,
- ids_dim_size,
- left_size,
- dst_dim_size,
- right_size
- )
- );
-
- let grid_size = MTLSize {
- width: right.len() as NSUInteger,
- height: 1,
- depth: 1,
- };
-
- let thread_group_size = MTLSize {
- width: pipeline.max_total_threads_per_threadgroup(),
- height: 1,
- depth: 1,
- };
-
- encoder.dispatch_thread_groups(grid_size, thread_group_size);
- encoder.end_encoding();
- command_buffer.commit();
- command_buffer.wait_until_completed();
-
- let expected = vec![
- 2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
- ];
- let result: Vec<f32> = read_to_vec(&outputs_buffer, right.len());
- assert_eq!(result, expected);
-}
-
-#[test]
fn cos_f16() {
let v: Vec<f16> = [1.0f32, 2.0, 3.0]
.iter()
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index 553bc506..04fa37a9 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -64,12 +64,12 @@ kernel void FN_NAME( \
constant size_t &dim, \
device const TYPENAME *input, \
device TYPENAME *output, \
- uint thread_position_in_grid [[ thread_position_in_grid ]] \
+ uint tid [[ thread_position_in_grid ]] \
) { \
- if (thread_position_in_grid >= dim) { \
+ if (tid >= dim) { \
return; \
} \
- output[thread_position_in_grid] = TYPENAME(FN(float(input[thread_position_in_grid]))); \
+ output[tid] = TYPENAME(FN(float(input[tid]))); \
}\
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
@@ -78,12 +78,12 @@ kernel void FN_NAME_STRIDED( \
constant size_t *strides, \
device const TYPENAME *input, \
device TYPENAME *output, \
- uint thread_position_in_grid [[ thread_position_in_grid ]] \
+ uint tid [[ thread_position_in_grid ]] \
) { \
- if (thread_position_in_grid >= dim) { \
+ if (tid >= dim) { \
return; \
} \
- output[thread_position_in_grid] = TYPENAME(FN(float(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)]))); \
+ output[tid] = TYPENAME(FN(float(input[get_strided_index(tid, num_dims, dims, strides)]))); \
}
#define UNARY_OP(NAME) \