summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-10-27 15:20:37 +0100
committerGitHub <noreply@github.com>2024-10-27 15:20:37 +0100
commit0e2c8c17fba0c1ba720e3bb50d2d4dec19cef07c (patch)
treedaa4ae5e627a95c6c54cd8eabec32e15d71309be
parent594d984f9cf79207f3beb6114ddf73cbc8427b56 (diff)
downloadcandle-0e2c8c17fba0c1ba720e3bb50d2d4dec19cef07c.tar.gz
candle-0e2c8c17fba0c1ba720e3bb50d2d4dec19cef07c.tar.bz2
candle-0e2c8c17fba0c1ba720e3bb50d2d4dec19cef07c.zip
UG metal integration. (#2580)
-rw-r--r--Cargo.toml1
-rw-r--r--candle-core/Cargo.toml3
-rw-r--r--candle-core/src/custom_op.rs48
-rw-r--r--candle-core/src/device.rs8
-rw-r--r--candle-core/src/metal_backend/device.rs22
-rw-r--r--candle-core/tests/custom_op_tests.rs16
-rw-r--r--candle-metal-kernels/src/lib.rs2
-rw-r--r--candle-metal-kernels/src/utils.rs10
8 files changed, 92 insertions, 18 deletions
diff --git a/Cargo.toml b/Cargo.toml
index 64e1460e..f27ec933 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -72,6 +72,7 @@ tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
ug = "0.0.2"
ug-cuda = "0.0.2"
+ug-metal = "0.0.2"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "1.1.1", default-features = false }
metal = { version = "0.27.0", features = ["mps"]}
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index 8ea2b08c..4ffc869f 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -30,6 +30,7 @@ safetensors = { workspace = true }
thiserror = { workspace = true }
ug = { workspace = true }
ug-cuda = { workspace = true, optional = true }
+ug-metal = { workspace = true, optional = true }
yoke = { workspace = true }
zip = { workspace = true }
@@ -45,7 +46,7 @@ cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
-metal = ["dep:metal", "dep:candle-metal-kernels"]
+metal = ["dep:metal", "dep:candle-metal-kernels", "dep:ug-metal"]
[[bench]]
name = "bench_main"
diff --git a/candle-core/src/custom_op.rs b/candle-core/src/custom_op.rs
index 276e3658..c0d97d67 100644
--- a/candle-core/src/custom_op.rs
+++ b/candle-core/src/custom_op.rs
@@ -380,6 +380,8 @@ pub struct UgIOp1 {
name: &'static str,
#[cfg(feature = "cuda")]
func: cudarc::driver::CudaFunction,
+ #[cfg(feature = "metal")]
+ func: metal::ComputePipelineState,
}
impl UgIOp1 {
@@ -395,7 +397,13 @@ impl UgIOp1 {
let func = device.compile(name, kernel)?;
Ok(Self { name, func })
}
- #[cfg(not(feature = "cuda"))]
+ #[cfg(feature = "metal")]
+ {
+ let device = device.as_metal_device()?;
+ let func = device.compile(name, kernel)?;
+ Ok(Self { name, func })
+ }
+ #[cfg(not(any(feature = "cuda", feature = "metal")))]
{
Ok(Self { name })
}
@@ -408,11 +416,43 @@ impl InplaceOp1 for UgIOp1 {
}
fn cpu_fwd(&self, _: &mut CpuStorage, _: &Layout) -> Result<()> {
- crate::bail!("ug ops are only supported on cuda at the moment")
+ crate::bail!("ug ops are only supported on metal/cuda at the moment")
}
- fn metal_fwd(&self, _: &mut MetalStorage, _: &Layout) -> Result<()> {
- crate::bail!("ug ops are only supported on cuda at the moment")
+ #[cfg(feature = "metal")]
+ fn metal_fwd(&self, sto: &mut MetalStorage, layout: &Layout) -> Result<()> {
+ use crate::backend::BackendStorage;
+ use candle_metal_kernels::utils::EncoderProvider;
+
+ let elem_count = layout.shape().elem_count();
+ if sto.dtype() != crate::DType::F32 {
+ // TODO: support more dtypes.
+ crate::bail!("input is not a f32 tensor")
+ }
+ let device = sto.device();
+ println!("here");
+ let command_buffer = device.command_buffer()?;
+ let command_buffer = &command_buffer;
+ let encoder = command_buffer.encoder();
+ let encoder = encoder.as_ref();
+ encoder.set_compute_pipeline_state(&self.func);
+ let (g, b) = if elem_count % 32 == 0 {
+ (elem_count / 32, 32)
+ } else {
+ (elem_count, 1)
+ };
+ let grid_dims = metal::MTLSize {
+ width: g as u64,
+ height: 1,
+ depth: 1,
+ };
+ let group_dims = candle_metal_kernels::utils::get_block_dims(b as u64, 1, 1);
+ candle_metal_kernels::utils::set_param(encoder, 0, (sto.buffer(), 0usize));
+
+ encoder.use_resource(sto.buffer(), metal::MTLResourceUsage::Write);
+ encoder.dispatch_threads(grid_dims, group_dims);
+
+ Ok(())
}
#[cfg(feature = "cuda")]
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 91925b57..18aa61af 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -138,6 +138,14 @@ impl Device {
}
}
+ pub fn as_metal_device(&self) -> Result<&crate::MetalDevice> {
+ match self {
+ Self::Cuda(_) => crate::bail!("expected a metal device, got cuda"),
+ Self::Cpu => crate::bail!("expected a metal device, got cpu"),
+ Self::Metal(d) => Ok(d),
+ }
+ }
+
pub fn new_cuda_with_stream(ordinal: usize) -> Result<Self> {
Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?))
}
diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs
index 29b8995b..46be6ce4 100644
--- a/candle-core/src/metal_backend/device.rs
+++ b/candle-core/src/metal_backend/device.rs
@@ -144,6 +144,28 @@ impl MetalDevice {
self.use_mlx_mm = use_mlx_mm
}
+ pub fn compile(
+ &self,
+ func_name: &'static str,
+ kernel: ug::lang::ssa::Kernel,
+ ) -> Result<metal::ComputePipelineState> {
+ let mut buf = vec![];
+ ug_metal::code_gen::gen(&mut buf, func_name, &kernel)?;
+ let metal_code = String::from_utf8(buf)?;
+ let lib = self
+ .device
+ .new_library_with_source(&metal_code, &metal::CompileOptions::new())
+ .map_err(MetalError::from)?;
+ let func = lib
+ .get_function(func_name, None)
+ .map_err(MetalError::from)?;
+ let pl = self
+ .device
+ .new_compute_pipeline_state_with_function(&func)
+ .map_err(MetalError::from)?;
+ Ok(pl)
+ }
+
pub fn id(&self) -> DeviceId {
self.id
}
diff --git a/candle-core/tests/custom_op_tests.rs b/candle-core/tests/custom_op_tests.rs
index f2c01aca..3572a4c9 100644
--- a/candle-core/tests/custom_op_tests.rs
+++ b/candle-core/tests/custom_op_tests.rs
@@ -144,7 +144,7 @@ fn inplace_op1() -> Result<()> {
Ok(())
}
-#[cfg(feature = "cuda")]
+#[cfg(any(feature = "cuda", feature = "metal"))]
#[allow(clippy::approx_constant)]
#[test]
fn ug_op() -> Result<()> {
@@ -160,15 +160,21 @@ fn ug_op() -> Result<()> {
let opts: ug::lower_op::Opts = Default::default();
kernel.lower(&opts.with_global(0, 12))?
};
- let device = Device::new_cuda(0)?;
+ let device = if candle_core::utils::cuda_is_available() {
+ Device::new_cuda(0)?
+ } else if candle_core::utils::metal_is_available() {
+ Device::new_metal(0)?
+ } else {
+ candle_core::bail!("metal/cuda is mandatory for this test")
+ };
let op = candle_core::UgIOp1::new("test", kernel, &device)?;
let t = Tensor::arange(0u32, 12u32, &device)?.to_dtype(DType::F32)?;
t.inplace_op1(&op)?;
assert_eq!(
- to_vec1_round(&t, 4)?,
+ to_vec1_round(&t, 2)?,
&[
- 1.0, 2.7183, 7.3891, 20.0855, 54.5982, 148.4132, 403.4287, 1096.6334, 2980.9578,
- 8103.0806, 22026.469, 59874.133
+ 1.0, 2.72, 7.39, 20.09, 54.6, 148.41, 403.43, 1096.63, 2980.96, 8103.08, 22026.47,
+ 59874.13
]
);
Ok(())
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index be616009..222ae8ad 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -6,7 +6,7 @@ use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::RwLock;
-mod utils;
+pub mod utils;
pub use utils::BufferOffset;
use utils::{get_block_dims, linear_split, EncoderProvider};
diff --git a/candle-metal-kernels/src/utils.rs b/candle-metal-kernels/src/utils.rs
index d2cc09f4..0092ecfa 100644
--- a/candle-metal-kernels/src/utils.rs
+++ b/candle-metal-kernels/src/utils.rs
@@ -24,7 +24,7 @@ pub(crate) fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (M
}
// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96
-pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
+pub fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
let mut pows0 = 0u64;
let mut pows1 = 0u64;
let mut pows2 = 0u64;
@@ -61,18 +61,14 @@ pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
}
}
-pub(crate) fn set_param<P: EncoderParam>(
- encoder: &ComputeCommandEncoderRef,
- position: u64,
- data: P,
-) {
+pub fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
<P as EncoderParam>::set_param(encoder, position, data)
}
/// Helper functions to create the various objects on the compute command encoder
/// on a single line.
/// Prevents getting wrong some arguments number and mixing length and size in bytes.
-pub(crate) trait EncoderParam {
+pub trait EncoderParam {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}
macro_rules! primitive {