diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-10-27 15:20:37 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-10-27 15:20:37 +0100 |
commit | 0e2c8c17fba0c1ba720e3bb50d2d4dec19cef07c (patch) | |
tree | daa4ae5e627a95c6c54cd8eabec32e15d71309be /candle-core/src/custom_op.rs | |
parent | 594d984f9cf79207f3beb6114ddf73cbc8427b56 (diff) | |
download | candle-0e2c8c17fba0c1ba720e3bb50d2d4dec19cef07c.tar.gz candle-0e2c8c17fba0c1ba720e3bb50d2d4dec19cef07c.tar.bz2 candle-0e2c8c17fba0c1ba720e3bb50d2d4dec19cef07c.zip |
UG metal integration. (#2580)
Diffstat (limited to 'candle-core/src/custom_op.rs')
-rw-r--r-- | candle-core/src/custom_op.rs | 48 |
1 file changed, 44 insertions, 4 deletions
diff --git a/candle-core/src/custom_op.rs b/candle-core/src/custom_op.rs index 276e3658..c0d97d67 100644 --- a/candle-core/src/custom_op.rs +++ b/candle-core/src/custom_op.rs @@ -380,6 +380,8 @@ pub struct UgIOp1 { name: &'static str, #[cfg(feature = "cuda")] func: cudarc::driver::CudaFunction, + #[cfg(feature = "metal")] + func: metal::ComputePipelineState, } impl UgIOp1 { @@ -395,7 +397,13 @@ impl UgIOp1 { let func = device.compile(name, kernel)?; Ok(Self { name, func }) } - #[cfg(not(feature = "cuda"))] + #[cfg(feature = "metal")] + { + let device = device.as_metal_device()?; + let func = device.compile(name, kernel)?; + Ok(Self { name, func }) + } + #[cfg(not(any(feature = "cuda", feature = "metal")))] { Ok(Self { name }) } @@ -408,11 +416,43 @@ impl InplaceOp1 for UgIOp1 { } fn cpu_fwd(&self, _: &mut CpuStorage, _: &Layout) -> Result<()> { - crate::bail!("ug ops are only supported on cuda at the moment") + crate::bail!("ug ops are only supported on metal/cuda at the moment") } - fn metal_fwd(&self, _: &mut MetalStorage, _: &Layout) -> Result<()> { - crate::bail!("ug ops are only supported on cuda at the moment") + #[cfg(feature = "metal")] + fn metal_fwd(&self, sto: &mut MetalStorage, layout: &Layout) -> Result<()> { + use crate::backend::BackendStorage; + use candle_metal_kernels::utils::EncoderProvider; + + let elem_count = layout.shape().elem_count(); + if sto.dtype() != crate::DType::F32 { + // TODO: support more dtypes. 
+ crate::bail!("input is not a f32 tensor") + } + let device = sto.device(); + println!("here"); + let command_buffer = device.command_buffer()?; + let command_buffer = &command_buffer; + let encoder = command_buffer.encoder(); + let encoder = encoder.as_ref(); + encoder.set_compute_pipeline_state(&self.func); + let (g, b) = if elem_count % 32 == 0 { + (elem_count / 32, 32) + } else { + (elem_count, 1) + }; + let grid_dims = metal::MTLSize { + width: g as u64, + height: 1, + depth: 1, + }; + let group_dims = candle_metal_kernels::utils::get_block_dims(b as u64, 1, 1); + candle_metal_kernels::utils::set_param(encoder, 0, (sto.buffer(), 0usize)); + + encoder.use_resource(sto.buffer(), metal::MTLResourceUsage::Write); + encoder.dispatch_threads(grid_dims, group_dims); + + Ok(()) } #[cfg(feature = "cuda")] |