summary refs log tree commit diff
path: root/candle-core/src
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src')
-rw-r--r-- candle-core/src/custom_op.rs 48
-rw-r--r-- candle-core/src/device.rs 8
-rw-r--r-- candle-core/src/metal_backend/device.rs 22
3 files changed, 74 insertions, 4 deletions
diff --git a/candle-core/src/custom_op.rs b/candle-core/src/custom_op.rs
index 276e3658..c0d97d67 100644
--- a/candle-core/src/custom_op.rs
+++ b/candle-core/src/custom_op.rs
@@ -380,6 +380,8 @@ pub struct UgIOp1 {
name: &'static str,
#[cfg(feature = "cuda")]
func: cudarc::driver::CudaFunction,
+ #[cfg(feature = "metal")]
+ func: metal::ComputePipelineState,
}
impl UgIOp1 {
@@ -395,7 +397,13 @@ impl UgIOp1 {
let func = device.compile(name, kernel)?;
Ok(Self { name, func })
}
- #[cfg(not(feature = "cuda"))]
+ #[cfg(feature = "metal")]
+ {
+ let device = device.as_metal_device()?;
+ let func = device.compile(name, kernel)?;
+ Ok(Self { name, func })
+ }
+ #[cfg(not(any(feature = "cuda", feature = "metal")))]
{
Ok(Self { name })
}
@@ -408,11 +416,43 @@ impl InplaceOp1 for UgIOp1 {
}
fn cpu_fwd(&self, _: &mut CpuStorage, _: &Layout) -> Result<()> {
- crate::bail!("ug ops are only supported on cuda at the moment")
+ crate::bail!("ug ops are only supported on metal/cuda at the moment")
}
- fn metal_fwd(&self, _: &mut MetalStorage, _: &Layout) -> Result<()> {
- crate::bail!("ug ops are only supported on cuda at the moment")
+    /// Runs the compiled ug kernel in place on a Metal-backed storage.
+    ///
+    /// The pipeline state held in `self.func` is encoded on a command buffer obtained from
+    /// the storage's device, with the storage buffer bound as parameter 0.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error when the storage dtype is not `f32`, or when the device cannot
+    /// provide a command buffer.
+    #[cfg(feature = "metal")]
+    fn metal_fwd(&self, sto: &mut MetalStorage, layout: &Layout) -> Result<()> {
+        use crate::backend::BackendStorage;
+        use candle_metal_kernels::utils::EncoderProvider;
+
+        // NOTE(review): only the element count of `layout` is used; strides and any start
+        // offset are not consulted, so this presumably expects a contiguous, zero-offset
+        // layout - confirm against callers.
+        let elem_count = layout.shape().elem_count();
+        if sto.dtype() != crate::DType::F32 {
+            // TODO: support more dtypes.
+            crate::bail!("input is not a f32 tensor")
+        }
+        let device = sto.device();
+        let command_buffer = device.command_buffer()?;
+        let command_buffer = &command_buffer;
+        let encoder = command_buffer.encoder();
+        let encoder = encoder.as_ref();
+        encoder.set_compute_pipeline_state(&self.func);
+        // Use groups of 32 when the element count divides evenly, otherwise fall back to
+        // groups of a single thread.
+        let (g, b) = if elem_count % 32 == 0 {
+            (elem_count / 32, 32)
+        } else {
+            (elem_count, 1)
+        };
+        let grid_dims = metal::MTLSize {
+            width: g as u64,
+            height: 1,
+            depth: 1,
+        };
+        let group_dims = candle_metal_kernels::utils::get_block_dims(b as u64, 1, 1);
+        candle_metal_kernels::utils::set_param(encoder, 0, (sto.buffer(), 0usize));
+
+        encoder.use_resource(sto.buffer(), metal::MTLResourceUsage::Write);
+        // NOTE(review): Metal's `dispatchThreads` takes the total number of threads in the
+        // grid, while `g` looks like a thread-group count when `b == 32` - confirm this
+        // matches the indexing used by the generated kernel.
+        encoder.dispatch_threads(grid_dims, group_dims);
+
+        Ok(())
     }
#[cfg(feature = "cuda")]
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 91925b57..18aa61af 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -138,6 +138,14 @@ impl Device {
}
}
+    /// Borrows the inner [`crate::MetalDevice`] when this device is the Metal variant.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error naming the actual backend when called on a CPU or CUDA device.
+    pub fn as_metal_device(&self) -> Result<&crate::MetalDevice> {
+        match self {
+            Self::Metal(metal_device) => Ok(metal_device),
+            Self::Cpu => crate::bail!("expected a metal device, got cpu"),
+            Self::Cuda(_) => crate::bail!("expected a metal device, got cuda"),
+        }
+    }
+
pub fn new_cuda_with_stream(ordinal: usize) -> Result<Self> {
Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?))
}
diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs
index 29b8995b..46be6ce4 100644
--- a/candle-core/src/metal_backend/device.rs
+++ b/candle-core/src/metal_backend/device.rs
@@ -144,6 +144,28 @@ impl MetalDevice {
self.use_mlx_mm = use_mlx_mm
}
+    /// Compiles a ug kernel into a Metal compute pipeline state.
+    ///
+    /// Metal source is generated for `kernel` with `ug_metal::code_gen::gen`, compiled into
+    /// a library on this device, and the function named `func_name` is looked up in that
+    /// library to build the pipeline.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error when code generation fails, when the generated source is not valid
+    /// UTF-8, or when Metal rejects the library, the function lookup, or the pipeline
+    /// creation.
+    pub fn compile(
+        &self,
+        func_name: &'static str,
+        kernel: ug::lang::ssa::Kernel,
+    ) -> Result<metal::ComputePipelineState> {
+        // Generate the kernel source into an in-memory byte buffer.
+        let mut source_bytes = Vec::new();
+        ug_metal::code_gen::gen(&mut source_bytes, func_name, &kernel)?;
+        let source = String::from_utf8(source_bytes)?;
+
+        let options = metal::CompileOptions::new();
+        let library = self
+            .device
+            .new_library_with_source(&source, &options)
+            .map_err(MetalError::from)?;
+        let function = library
+            .get_function(func_name, None)
+            .map_err(MetalError::from)?;
+        let pipeline = self
+            .device
+            .new_compute_pipeline_state_with_function(&function)
+            .map_err(MetalError::from)?;
+        Ok(pipeline)
+    }
+
pub fn id(&self) -> DeviceId {
self.id
}