diff options
Diffstat (limited to 'candle-metal-kernels/src/lib.rs')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 44 |
1 files changed, 44 insertions, 0 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index d126aa42..dd97a86d 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1518,6 +1518,50 @@ pub fn call_im2col_strided( Ok(()) } +#[allow(clippy::too_many_arguments)] +pub fn call_upsample_nearest_2d( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + strides: &[usize], + out_w: usize, + out_h: usize, + input: &Buffer, + input_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Conv, name)?; + let dst_el = out_w * out_h * shape[0] * shape[1]; + let scale_w = shape[2] as f32 / out_w as f32; + let scale_h = shape[3] as f32 / out_h as f32; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + let encoder = command_buffer.new_compute_command_encoder(); + encoder.wait_for_fence(&kernels.fence); + encoder.set_compute_pipeline_state(&pipeline); + set_params!( + encoder, + ( + out_w, + out_h, + scale_w, + scale_h, + shape, + strides, + (input, input_offset), + output + ) + ); + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.update_fence(&kernels.fence); + encoder.end_encoding(); + + Ok(()) +} + fn divide(m: usize, b: usize) -> NSUInteger { ((m + b - 1) / b) as NSUInteger } |