summary | refs | log | tree | commit | diff
path: root/candle-metal-kernels/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-metal-kernels/src/lib.rs')
-rw-r--r-- candle-metal-kernels/src/lib.rs | 44
1 files changed, 44 insertions, 0 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index d126aa42..dd97a86d 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1518,6 +1518,50 @@ pub fn call_im2col_strided(
Ok(())
}
+/// Encodes a dispatch of the compute kernel `name` (loaded from `Source::Conv`)
+/// onto `command_buffer`.
+///
+/// The dispatch is sized by `linear_split` from the destination element count
+/// `out_w * out_h * shape[0] * shape[1]`. Two `f32` scale factors are computed
+/// on the host and passed to the kernel: `shape[2] / out_w` and
+/// `shape[3] / out_h`.
+///
+/// # Arguments
+///
+/// * `device`, `kernels`, `name` - used to look up the compute pipeline.
+/// * `command_buffer` - a new compute command encoder is created on it and
+///   ended before this function returns.
+/// * `shape` - must hold at least four entries. NOTE(review): presumably
+///   `(batch, channels, ..)` followed by the two spatial extents; confirm the
+///   axis order against the kernel, since `shape[2]` is paired with `out_w`
+///   and `shape[3]` with `out_h`.
+/// * `strides` - not read here, only forwarded to the kernel.
+/// * `out_w`, `out_h` - output extents. They are not validated; a zero value
+///   makes the corresponding scale non-finite.
+/// * `input`, `input_offset` - bound together as one kernel parameter.
+///   NOTE(review): the unit of `input_offset` is not visible here; confirm.
+/// * `output` - bound as the last kernel parameter.
+///
+/// # Errors
+///
+/// Propagates any error returned by `kernels.load_pipeline`.
+///
+/// # Panics
+///
+/// Panics if `shape` has fewer than four entries (it is indexed up to
+/// `shape[3]`).
+#[allow(clippy::too_many_arguments)]
+pub fn call_upsample_nearest_2d(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: &'static str,
+ shape: &[usize],
+ strides: &[usize],
+ out_w: usize,
+ out_h: usize,
+ input: &Buffer,
+ input_offset: usize,
+ output: &Buffer,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
+ // Total number of output elements; drives the dispatch size below.
+ let dst_el = out_w * out_h * shape[0] * shape[1];
+ // Source extent divided by output extent, handed to the kernel as f32.
+ let scale_w = shape[2] as f32 / out_w as f32;
+ let scale_h = shape[3] as f32 / out_h as f32;
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
+ let encoder = command_buffer.new_compute_command_encoder();
+ // The fence wait here and the fence update before `end_encoding` bracket
+ // this encoder's work. NOTE(review): presumably this orders it against
+ // other encoders sharing `kernels.fence`; confirm.
+ encoder.wait_for_fence(&kernels.fence);
+ encoder.set_compute_pipeline_state(&pipeline);
+ // NOTE(review): the parameter order presumably has to match the kernel's
+ // argument order; verify against the shader source before reordering.
+ set_params!(
+ encoder,
+ (
+ out_w,
+ out_h,
+ scale_w,
+ scale_h,
+ shape,
+ strides,
+ (input, input_offset),
+ output
+ )
+ );
+ // Declare how each buffer is accessed: input is read, output is written.
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
+ encoder.end_encoding();
+
+ Ok(())
+}
+
/// Returns `m / b` rounded up to the next integer, converted to `NSUInteger`.
///
/// The quotient and remainder are combined instead of computing
/// `(m + b - 1) / b`, because that intermediate sum overflows `usize` when
/// `m` is close to `usize::MAX`.
///
/// # Panics
///
/// Panics if `b` is zero (integer division by zero).
fn divide(m: usize, b: usize) -> NSUInteger {
    let quotient = m / b;
    // Add one more unit only when `b` does not divide `m` exactly.
    let round_up = usize::from(m % b != 0);
    (quotient + round_up) as NSUInteger
}