summaryrefslogtreecommitdiff
path: root/candle-core/src/custom_op.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-10-27 15:20:37 +0100
committerGitHub <noreply@github.com>2024-10-27 15:20:37 +0100
commit0e2c8c17fba0c1ba720e3bb50d2d4dec19cef07c (patch)
treedaa4ae5e627a95c6c54cd8eabec32e15d71309be /candle-core/src/custom_op.rs
parent594d984f9cf79207f3beb6114ddf73cbc8427b56 (diff)
downloadcandle-0e2c8c17fba0c1ba720e3bb50d2d4dec19cef07c.tar.gz
candle-0e2c8c17fba0c1ba720e3bb50d2d4dec19cef07c.tar.bz2
candle-0e2c8c17fba0c1ba720e3bb50d2d4dec19cef07c.zip
UG metal integration. (#2580)
Diffstat (limited to 'candle-core/src/custom_op.rs')
-rw-r--r--candle-core/src/custom_op.rs48
1 files changed, 44 insertions, 4 deletions
diff --git a/candle-core/src/custom_op.rs b/candle-core/src/custom_op.rs
index 276e3658..c0d97d67 100644
--- a/candle-core/src/custom_op.rs
+++ b/candle-core/src/custom_op.rs
@@ -380,6 +380,8 @@ pub struct UgIOp1 {
name: &'static str,
#[cfg(feature = "cuda")]
func: cudarc::driver::CudaFunction,
+ #[cfg(feature = "metal")]
+ func: metal::ComputePipelineState,
}
impl UgIOp1 {
@@ -395,7 +397,13 @@ impl UgIOp1 {
let func = device.compile(name, kernel)?;
Ok(Self { name, func })
}
- #[cfg(not(feature = "cuda"))]
+ #[cfg(feature = "metal")]
+ {
+ let device = device.as_metal_device()?;
+ let func = device.compile(name, kernel)?;
+ Ok(Self { name, func })
+ }
+ #[cfg(not(any(feature = "cuda", feature = "metal")))]
{
Ok(Self { name })
}
@@ -408,11 +416,43 @@ impl InplaceOp1 for UgIOp1 {
}
fn cpu_fwd(&self, _: &mut CpuStorage, _: &Layout) -> Result<()> {
- crate::bail!("ug ops are only supported on cuda at the moment")
+ crate::bail!("ug ops are only supported on metal/cuda at the moment")
}
- fn metal_fwd(&self, _: &mut MetalStorage, _: &Layout) -> Result<()> {
- crate::bail!("ug ops are only supported on cuda at the moment")
+ #[cfg(feature = "metal")]
+ fn metal_fwd(&self, sto: &mut MetalStorage, layout: &Layout) -> Result<()> {
+ use crate::backend::BackendStorage;
+ use candle_metal_kernels::utils::EncoderProvider;
+
+ let elem_count = layout.shape().elem_count();
+ if sto.dtype() != crate::DType::F32 {
+ // TODO: support more dtypes.
+ crate::bail!("input is not a f32 tensor")
+ }
+ let device = sto.device();
+ println!("here");
+ let command_buffer = device.command_buffer()?;
+ let command_buffer = &command_buffer;
+ let encoder = command_buffer.encoder();
+ let encoder = encoder.as_ref();
+ encoder.set_compute_pipeline_state(&self.func);
+ let (g, b) = if elem_count % 32 == 0 {
+ (elem_count / 32, 32)
+ } else {
+ (elem_count, 1)
+ };
+ let grid_dims = metal::MTLSize {
+ width: g as u64,
+ height: 1,
+ depth: 1,
+ };
+ let group_dims = candle_metal_kernels::utils::get_block_dims(b as u64, 1, 1);
+ candle_metal_kernels::utils::set_param(encoder, 0, (sto.buffer(), 0usize));
+
+ encoder.use_resource(sto.buffer(), metal::MTLResourceUsage::Write);
+ encoder.dispatch_threads(grid_dims, group_dims);
+
+ Ok(())
}
#[cfg(feature = "cuda")]