summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- candle-core/src/metal_backend.rs 35
-rw-r--r-- candle-metal-kernels/src/conv.metal 60
-rw-r--r-- candle-metal-kernels/src/lib.rs 44
3 files changed, 137 insertions, 2 deletions
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 1813f276..6d8afab1 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -959,8 +959,39 @@ impl BackendStorage for MetalStorage {
crate::bail!("upsample_nearest1d metal")
}
- fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
- crate::bail!("upsample_nearest2d metal")
+    /// Upsamples a 4D input to the requested spatial size with the Metal
+    /// `upsample_nearest2d_*` kernels.
+    ///
+    /// Input dimensions 2 and 3 are resized to `out_w` and `out_h` respectively;
+    /// dimensions 0 and 1 are carried over, so the freshly allocated output
+    /// holds `dims[0] * dims[1] * out_w * out_h` elements.
+    ///
+    /// # Errors
+    /// Fails when the layout is not 4D, when the dtype has no matching kernel,
+    /// or when buffer allocation, command-buffer retrieval or kernel encoding
+    /// reports an error.
+    fn upsample_nearest2d(&self, inp_l: &Layout, out_w: usize, out_h: usize) -> Result<Self> {
+        let shape = inp_l.shape();
+        let dims = shape.dims();
+        let strides = inp_l.stride();
+        if dims.len() != 4 {
+            crate::bail!("unexpected input shape for upsample {dims:?}")
+        }
+        // One arm per kernel instantiated by UPSAMPLE_NEAREST2D_OP in conv.metal.
+        let name = match self.dtype {
+            DType::F32 => "upsample_nearest2d_f32",
+            DType::U8 => "upsample_nearest2d_u8",
+            DType::U32 => "upsample_nearest2d_u32",
+            dtype => crate::bail!("Not implemented {dtype:?} for upsample_nearest2d, metal"),
+        };
+
+        let dst_el = out_w * out_h * dims[0] * dims[1];
+        let buffer = self
+            .device
+            .new_buffer(dst_el, self.dtype, "upsample_nearest2d")?;
+        let command_buffer = self.device.command_buffer()?;
+        candle_metal_kernels::call_upsample_nearest_2d(
+            &self.device.device,
+            &command_buffer,
+            &self.device.kernels,
+            name,
+            dims,
+            strides,
+            out_w,
+            out_h,
+            &self.buffer,
+            // The layout offset is counted in elements; the kernel call takes bytes.
+            inp_l.start_offset() * self.dtype.size_in_bytes(),
+            &buffer,
+        )
+        .map_err(MetalError::from)?;
+        Ok(Self::new(buffer, self.device.clone(), self.dtype))
     }
fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
diff --git a/candle-metal-kernels/src/conv.metal b/candle-metal-kernels/src/conv.metal
index 49141771..dca53161 100644
--- a/candle-metal-kernels/src/conv.metal
+++ b/candle-metal-kernels/src/conv.metal
@@ -108,6 +108,47 @@ METAL_FUNC void im2col1d(
}
}
+// Nearest-neighbour 2D upsampling, one thread per destination element.
+//
+// `src_dims` / `src_s` hold the four source dimensions and their strides
+// (in elements of T). `dst` is written densely at index `tid`, laid out as
+// (src_dims[0], src_dims[1], w_out, h_out). `w_scale` / `h_scale` map a
+// destination coordinate back to a source coordinate.
+template <typename T>
+METAL_FUNC void upsample_nearest2d(
+    constant size_t &w_out,
+    constant size_t &h_out,
+    constant float &w_scale,
+    constant float &h_scale,
+    constant size_t *src_dims,
+    constant size_t *src_s,
+    device const T *src,
+    device T *dst,
+    uint tid [[ thread_position_in_grid ]]
+) {
+    // Source layout: (b_size, c_in, w_in, h_in).
+    const size_t n_chan = src_dims[1];
+    const size_t in_w = src_dims[2];
+    const size_t in_h = src_dims[3];
+
+    // Threads past the last destination element have nothing to write.
+    if (tid >= src_dims[0] * n_chan * w_out * h_out) {
+        return;
+    }
+
+    // TODO: Improve this.
+    // Unflatten tid into (batch, channel, w, h) destination coordinates.
+    const size_t out_y = tid % h_out;
+    const size_t out_x = (tid / h_out) % w_out;
+    const size_t chan = (tid / (w_out * h_out)) % n_chan;
+    const size_t batch = tid / (w_out * h_out * n_chan);
+
+    // Scale back into the source, then clamp to the last valid index.
+    const size_t raw_x = static_cast<size_t>(out_x * w_scale);
+    const size_t raw_y = static_cast<size_t>(out_y * h_scale);
+    const size_t in_x = raw_x >= in_w ? in_w - 1 : raw_x;
+    const size_t in_y = raw_y >= in_h ? in_h - 1 : raw_y;
+
+    const size_t src_idx = batch * src_s[0] + chan * src_s[1] + in_x * src_s[2] + in_y * src_s[3];
+    dst[tid] = src[src_idx];
+}
+
#define IM2COL_OP(T, FN_NAME) \
kernel void FN_NAME( \
constant size_t &dst_numel, \
@@ -143,6 +184,21 @@ kernel void FN_NAME( \
) { \
im2col1d<T>(dst_numel, l_out, l_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \
} \
+
+// Defines a `kernel` entry point named FN_NAME that forwards its arguments
+// unchanged to upsample_nearest2d<TYPENAME>.
+// The closing brace carries no trailing backslash on purpose: a continuation
+// there would pull the next source line into the macro body.
+#define UPSAMPLE_NEAREST2D_OP(TYPENAME, FN_NAME) \
+kernel void FN_NAME( \
+    constant size_t &w_out, \
+    constant size_t &h_out, \
+    constant float &w_scale, \
+    constant float &h_scale, \
+    constant size_t *dims, \
+    constant size_t *strides, \
+    device const TYPENAME *src, \
+    device TYPENAME *dst, \
+    uint tid [[ thread_position_in_grid ]] \
+) { \
+    upsample_nearest2d<TYPENAME>(w_out, h_out, w_scale, h_scale, dims, strides, src, dst, tid); \
+}
IM2COL_OP(float, im2col_f32)
IM2COL_OP(uint8_t, im2col_u8)
@@ -151,3 +207,7 @@ IM2COL_OP(uint32_t, im2col_u32)
IM2COL1D_OP(float, im2col1d_f32)
IM2COL1D_OP(uint8_t, im2col1d_u8)
IM2COL1D_OP(uint32_t, im2col1d_u32)
+
+// Nearest-neighbour 2D upsampling entry points, one per instantiated element type.
+UPSAMPLE_NEAREST2D_OP(float, upsample_nearest2d_f32)
+UPSAMPLE_NEAREST2D_OP(uint8_t, upsample_nearest2d_u8)
+UPSAMPLE_NEAREST2D_OP(uint32_t, upsample_nearest2d_u32)
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index d126aa42..dd97a86d 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1518,6 +1518,50 @@ pub fn call_im2col_strided(
Ok(())
}
+/// Encodes a nearest-neighbour 2D upsampling dispatch onto `command_buffer`.
+///
+/// `name` selects the kernel, which is loaded from the `Source::Conv` library.
+/// `shape` and `strides` describe the 4D input; its dimensions 2 and 3 are
+/// resized to `out_w` and `out_h`. `input` is bound together with
+/// `input_offset`, and the dispatch covers `shape[0] * shape[1] * out_w * out_h`
+/// destination elements of `output`.
+///
+/// # Errors
+/// Returns the error reported by `load_pipeline` if the kernel cannot be loaded.
+///
+/// # Panics
+/// Panics if `shape` has fewer than four entries.
+#[allow(clippy::too_many_arguments)]
+pub fn call_upsample_nearest_2d(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    name: &'static str,
+    shape: &[usize],
+    strides: &[usize],
+    out_w: usize,
+    out_h: usize,
+    input: &Buffer,
+    input_offset: usize,
+    output: &Buffer,
+) -> Result<(), MetalKernelError> {
+    let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
+
+    let n_out = out_w * out_h * shape[0] * shape[1];
+    // Source extent divided by destination extent along each resized axis.
+    let w_ratio = shape[2] as f32 / out_w as f32;
+    let h_ratio = shape[3] as f32 / out_h as f32;
+    let (group_count, group_size) = linear_split(&pipeline, n_out);
+
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.wait_for_fence(&kernels.fence);
+    encoder.set_compute_pipeline_state(&pipeline);
+    // Argument order mirrors the kernel signature in conv.metal.
+    set_params!(
+        encoder,
+        (
+            out_w,
+            out_h,
+            w_ratio,
+            h_ratio,
+            shape,
+            strides,
+            (input, input_offset),
+            output
+        )
+    );
+    encoder.use_resource(input, metal::MTLResourceUsage::Read);
+    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.dispatch_thread_groups(group_count, group_size);
+    encoder.update_fence(&kernels.fence);
+    encoder.end_encoding();
+
+    Ok(())
+}
+
fn divide(m: usize, b: usize) -> NSUInteger {
((m + b - 1) / b) as NSUInteger
}