summaryrefslogtreecommitdiff
path: root/candle-metal-kernels/src/conv.metal
diff options
context:
space:
mode:
Diffstat (limited to 'candle-metal-kernels/src/conv.metal')
-rw-r--r--candle-metal-kernels/src/conv.metal78
1 files changed, 78 insertions, 0 deletions
diff --git a/candle-metal-kernels/src/conv.metal b/candle-metal-kernels/src/conv.metal
index 7f7a75cf..a258ae58 100644
--- a/candle-metal-kernels/src/conv.metal
+++ b/candle-metal-kernels/src/conv.metal
@@ -335,6 +335,76 @@ kernel void FN_NAME( \
max_pool2d<TYPENAME>(w_k, h_k, w_s, h_s, src_dims, src_s, src, dst, tid); \
} \
+
+// Naive implementation of conv_transpose1d.
+template <typename T, typename A>
+METAL_FUNC void conv_transpose1d(
+ constant size_t &l_out,
+ constant size_t &stride,
+ constant size_t &padding,
+ constant size_t &out_padding,
+ constant size_t &dilation,
+ constant size_t *src_dims,
+ constant size_t *src_strides,
+ constant size_t *k_dims,
+ constant size_t *k_strides,
+ device const T *src,
+ device const T *k,
+ device T *dst,
+ uint tid [[ thread_position_in_grid ]]
+) {
+ // src: (b_size, c_in, l_in)
+ // kernel: (c_in, c_out, l_k)
+ const size_t l_k = k_dims[2];
+ const size_t c_out = k_dims[1];
+ const size_t c_in = src_dims[1];
+ const size_t l_in = src_dims[2];
+ if (tid >= src_dims[0] * c_out * l_out) {
+ return;
+ }
+
+ const size_t b_idx = tid / (l_out * c_out);
+ const size_t dst_c_idx = (tid / l_out) % c_out;
+ const size_t out_x = tid % l_out;
+
+ const size_t src_idx0 = b_idx * src_strides[0];
+ A d = 0;
+ for (int k_x = 0; k_x < (int)l_k; ++k_x) {
+ // let out_x = inp_x * p.stride + k_x * p.dilation - p.padding;
+ int inp_x_stride = (int)(out_x + padding) - k_x * dilation;
+ if (inp_x_stride < 0 || inp_x_stride % stride) {
+ continue;
+ }
+ int inp_x = inp_x_stride / stride;
+ if (inp_x >= l_in) continue;
+ for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
+ const size_t src_idx = src_idx0 + src_c_idx * src_strides[1] + inp_x * src_strides[2];
+ const size_t k_idx = src_c_idx * k_strides[0] + dst_c_idx * k_strides[1] + k_x * k_strides[2];
+ d += static_cast<A>(src[src_idx]) * static_cast<A>(k[k_idx]);
+ }
+ }
+ dst[tid] = static_cast<T>(d);
+}
+
+#define CONVT1D_OP(TYPENAME, TYPEACC, FN_NAME) \
+kernel void FN_NAME( \
+ constant size_t &l_out, \
+ constant size_t &stride, \
+ constant size_t &padding, \
+ constant size_t &out_padding, \
+ constant size_t &dilation, \
+ constant size_t *src_dims, \
+ constant size_t *src_strides, \
+ constant size_t *k_dims, \
+ constant size_t *k_strides, \
+ device const TYPENAME *src, \
+ device const TYPENAME *k, \
+ device TYPENAME *dst, \
+ uint tid [[ thread_position_in_grid ]] \
+) { \
+ conv_transpose1d<TYPENAME, TYPEACC>(l_out, stride, padding, out_padding, dilation, src_dims, src_strides, k_dims, k_strides, src, k, dst, tid); \
+} \
+
IM2COL_OP(float, im2col_f32)
IM2COL_OP(uint8_t, im2col_u8)
IM2COL_OP(uint32_t, im2col_u32)
@@ -361,4 +431,12 @@ AVGPOOL2D_OP(uint32_t, uint32_t, avg_pool2d_u32)
AVGPOOL2D_OP(uint8_t, uint8_t, avg_pool2d_u8)
#if defined(__HAVE_BFLOAT__)
AVGPOOL2D_OP(bfloat, float, avg_pool2d_bf16)
+#endif
+
+CONVT1D_OP(float, float, conv_transpose1d_f32)
+CONVT1D_OP(half, float, conv_transpose1d_f16)
+CONVT1D_OP(uint8_t, uint8_t, conv_transpose1d_u8)
+CONVT1D_OP(uint32_t, uint32_t, conv_transpose1d_u32)
+#if defined(__HAVE_BFLOAT__)
+CONVT1D_OP(bfloat, float, conv_transpose1d_bf16)
#endif \ No newline at end of file