diff options
Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/custom_op.rs | 48 | ||||
-rw-r--r-- | candle-core/src/device.rs | 8 | ||||
-rw-r--r-- | candle-core/src/metal_backend/device.rs | 22 |
3 files changed, 74 insertions, 4 deletions
diff --git a/candle-core/src/custom_op.rs b/candle-core/src/custom_op.rs index 276e3658..c0d97d67 100644 --- a/candle-core/src/custom_op.rs +++ b/candle-core/src/custom_op.rs @@ -380,6 +380,8 @@ pub struct UgIOp1 { name: &'static str, #[cfg(feature = "cuda")] func: cudarc::driver::CudaFunction, + #[cfg(feature = "metal")] + func: metal::ComputePipelineState, } impl UgIOp1 { @@ -395,7 +397,13 @@ impl UgIOp1 { let func = device.compile(name, kernel)?; Ok(Self { name, func }) } - #[cfg(not(feature = "cuda"))] + #[cfg(feature = "metal")] + { + let device = device.as_metal_device()?; + let func = device.compile(name, kernel)?; + Ok(Self { name, func }) + } + #[cfg(not(any(feature = "cuda", feature = "metal")))] { Ok(Self { name }) } @@ -408,11 +416,43 @@ impl InplaceOp1 for UgIOp1 { } fn cpu_fwd(&self, _: &mut CpuStorage, _: &Layout) -> Result<()> { - crate::bail!("ug ops are only supported on cuda at the moment") + crate::bail!("ug ops are only supported on metal/cuda at the moment") } - fn metal_fwd(&self, _: &mut MetalStorage, _: &Layout) -> Result<()> { - crate::bail!("ug ops are only supported on cuda at the moment") + #[cfg(feature = "metal")] + fn metal_fwd(&self, sto: &mut MetalStorage, layout: &Layout) -> Result<()> { + use crate::backend::BackendStorage; + use candle_metal_kernels::utils::EncoderProvider; + + let elem_count = layout.shape().elem_count(); + if sto.dtype() != crate::DType::F32 { + // TODO: support more dtypes. + crate::bail!("input is not a f32 tensor") + } + let device = sto.device(); + println!("here"); + let command_buffer = device.command_buffer()?; + let command_buffer = &command_buffer; + let encoder = command_buffer.encoder(); + let encoder = encoder.as_ref(); + encoder.set_compute_pipeline_state(&self.func); + let (g, b) = if elem_count % 32 == 0 { + (elem_count / 32, 32) + } else { + (elem_count, 1) + }; + let grid_dims = metal::MTLSize { + width: g as u64, + height: 1, + depth: 1, + }; + let group_dims = candle_metal_kernels::utils::get_block_dims(b as u64, 1, 1); + candle_metal_kernels::utils::set_param(encoder, 0, (sto.buffer(), 0usize)); + + encoder.use_resource(sto.buffer(), metal::MTLResourceUsage::Write); + encoder.dispatch_threads(grid_dims, group_dims); + + Ok(()) } #[cfg(feature = "cuda")] diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index 91925b57..18aa61af 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -138,6 +138,14 @@ impl Device { } } + pub fn as_metal_device(&self) -> Result<&crate::MetalDevice> { + match self { + Self::Cuda(_) => crate::bail!("expected a metal device, got cuda"), + Self::Cpu => crate::bail!("expected a metal device, got cpu"), + Self::Metal(d) => Ok(d), + } + } + pub fn new_cuda_with_stream(ordinal: usize) -> Result<Self> { Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?)) } diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs index 29b8995b..46be6ce4 100644 --- a/candle-core/src/metal_backend/device.rs +++ b/candle-core/src/metal_backend/device.rs @@ -144,6 +144,28 @@ impl MetalDevice { self.use_mlx_mm = use_mlx_mm } + pub fn compile( + &self, + func_name: &'static str, + kernel: ug::lang::ssa::Kernel, + ) -> Result<metal::ComputePipelineState> { + let mut buf = vec![]; + ug_metal::code_gen::gen(&mut buf, func_name, &kernel)?; + let metal_code = String::from_utf8(buf)?; + let lib = self + .device + .new_library_with_source(&metal_code, &metal::CompileOptions::new()) + .map_err(MetalError::from)?; + let func = lib + .get_function(func_name, None) + .map_err(MetalError::from)?; + let pl = self + .device + .new_compute_pipeline_state_with_function(&func) + .map_err(MetalError::from)?; + Ok(pl) + } + pub fn id(&self) -> DeviceId { self.id } |