summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-05-25 11:03:23 +0200
committerGitHub <noreply@github.com>2024-05-25 11:03:23 +0200
commit0814dfd148474f436bb43314fc41639a7b429aab (patch)
tree7301c78b9e635bc09a268f3293a1f8f8b2997e6f /candle-metal-kernels
parent3ceca9901a5ebc4ded3ac2cd793d0125f7a12562 (diff)
downloadcandle-0814dfd148474f436bb43314fc41639a7b429aab.tar.gz
candle-0814dfd148474f436bb43314fc41639a7b429aab.tar.bz2
candle-0814dfd148474f436bb43314fc41639a7b429aab.zip
Add a metal kernel for col2im1d. (#2214)
* Add a metal kernel for col2im1d. * Enable the col2im variant. * Bugfix. * Revert the quantized tweak.
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--candle-metal-kernels/src/conv.metal65
-rw-r--r--candle-metal-kernels/src/lib.rs33
2 files changed, 97 insertions, 1 deletions
diff --git a/candle-metal-kernels/src/conv.metal b/candle-metal-kernels/src/conv.metal
index 8fdd0e5f..5348a0f0 100644
--- a/candle-metal-kernels/src/conv.metal
+++ b/candle-metal-kernels/src/conv.metal
@@ -69,6 +69,50 @@ METAL_FUNC void im2col(
}
template <typename T>
+METAL_FUNC void col2im1d(
+ constant size_t &dst_el,
+ constant size_t &l_out,
+ constant size_t &l_in,
+ constant size_t &c_out,
+ constant size_t &k_size,
+ constant size_t &stride,
+ device const T *src,
+ device T *dst,
+ uint dst_i [[ thread_position_in_grid ]]
+) {
+ // src: (b_size, l_in, c_out, l_k)
+ // dst: (b_size, c_out, l_out)
+ if (dst_i >= dst_el) {
+ return;
+ }
+
+ const size_t dst_s0 = c_out * l_out;
+ const size_t dst_s1 = l_out;
+ const size_t src_s0 = c_out * k_size * l_in;
+ const size_t src_s1 = c_out * k_size;
+ const size_t src_s2 = k_size;
+
+ size_t tmp_dst_i = dst_i;
+ const size_t b_idx = tmp_dst_i / dst_s0;
+ tmp_dst_i -= b_idx * dst_s0;
+ const size_t c_idx = tmp_dst_i / dst_s1;
+ tmp_dst_i -= c_idx * dst_s1;
+ const int l_out_idx = tmp_dst_i;
+
+ dst[dst_i] = static_cast<T>(0);
+
+ int l_in_idx = l_out_idx / stride;
+ int k0 = l_out_idx - l_in_idx * stride;
+ // l_out_idx = l_in_idx * stride + k0
+ for (; k0 < k_size && l_in_idx >= 0; k0 += stride, --l_in_idx) {
+ if (l_in_idx < l_in) {
+ const size_t src_i = b_idx * src_s0 + l_in_idx * src_s1 + c_idx * src_s2 + k0;
+ dst[dst_i] += src[src_i];
+ }
+ }
+}
+
+template <typename T>
METAL_FUNC void im2col1d(
constant size_t &dst_numel,
constant size_t &l_out,
@@ -190,6 +234,21 @@ kernel void FN_NAME( \
) { \
im2col1d<T>(dst_numel, l_out, l_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \
} \
+
+#define COL2IM1D_OP(T, FN_NAME) \
+kernel void FN_NAME( \
+ constant size_t &dst_el, \
+ constant size_t &l_out, \
+ constant size_t &l_in, \
+ constant size_t &c_out, \
+ constant size_t &k_size, \
+ constant size_t &stride, \
+ device const T *src, \
+ device T *dst, \
+ uint tid [[ thread_position_in_grid ]] \
+) { \
+ col2im1d<T>(dst_el, l_out, l_in, c_out, k_size, stride, src, dst, tid); \
+} \
#define UPSAMPLE_NEAREST2D_OP(TYPENAME, FN_NAME) \
kernel void FN_NAME( \
@@ -493,6 +552,10 @@ IM2COL_OP(uint32_t, im2col_u32)
IM2COL_OP(bfloat, im2col_bf16)
#endif
+COL2IM1D_OP(float, col2im1d_f32)
+COL2IM1D_OP(uint8_t, col2im1d_u8)
+COL2IM1D_OP(uint32_t, col2im1d_u32)
+
IM2COL1D_OP(float, im2col1d_f32)
IM2COL1D_OP(uint8_t, im2col1d_u8)
IM2COL1D_OP(uint32_t, im2col1d_u32)
@@ -533,4 +596,4 @@ CONVT2D_OP(float, float, conv_transpose2d_f32)
CONVT2D_OP(half, float, conv_transpose2d_f16)
#if defined(__HAVE_BFLOAT__)
CONVT1D_OP(bfloat, float, conv_transpose2d_bf16)
-#endif \ No newline at end of file
+#endif
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index aa157a2f..1815dd32 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1652,6 +1652,39 @@ pub fn call_im2col1d_strided(
}
#[allow(clippy::too_many_arguments)]
+pub fn call_col2im1d(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: &'static str,
+ shape: &[usize],
+ k_size: usize,
+ stride: usize,
+ input: BufferOffset,
+ output: &Buffer,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
+ let l_in = shape[1];
+ let c_out = shape[2];
+ let l_out = (l_in - 1) * stride + k_size;
+ let dst_el = shape[0] * c_out * l_out;
+
+ let encoder = command_buffer.new_compute_command_encoder();
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
+ encoder.set_compute_pipeline_state(&pipeline);
+ set_params!(
+ encoder,
+ (dst_el, l_out, l_in, c_out, k_size, stride, &input, output)
+ );
+ encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.end_encoding();
+
+ Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
pub fn call_im2col_strided(
device: &Device,
command_buffer: &CommandBufferRef,