summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicolas Patry <patry.nicolas@protonmail.com>2023-11-10 08:35:48 +0100
committerGitHub <noreply@github.com>2023-11-10 08:35:48 +0100
commit26c4e5bf1d10532c9b681f07a7b08b2c84844bee (patch)
tree39b8114a0f2a2975cf9bcb1684129fa5f2e9ef69
parent18d30005c577c029dec611a0bdd0260946468cdb (diff)
downloadcandle-26c4e5bf1d10532c9b681f07a7b08b2c84844bee.tar.gz
candle-26c4e5bf1d10532c9b681f07a7b08b2c84844bee.tar.bz2
candle-26c4e5bf1d10532c9b681f07a7b08b2c84844bee.zip
Metal part 1 - Scaffolding for metal. (#1308)
* Metal part 1 - Scaffolding for metal. * Remove tracing.
-rw-r--r--Cargo.toml1
-rw-r--r--candle-core/Cargo.toml2
-rw-r--r--candle-core/src/device.rs51
-rw-r--r--candle-core/src/display.rs2
-rw-r--r--candle-core/src/dummy_metal_backend.rs223
-rw-r--r--candle-core/src/error.rs8
-rw-r--r--candle-core/src/lib.rs7
-rw-r--r--candle-core/src/op.rs44
-rw-r--r--candle-core/src/storage.rs108
-rw-r--r--candle-core/src/tensor.rs9
-rw-r--r--candle-core/src/utils.rs4
-rw-r--r--candle-examples/src/lib.rs17
-rw-r--r--candle-pyo3/src/lib.rs13
13 files changed, 473 insertions, 16 deletions
diff --git a/Cargo.toml b/Cargo.toml
index a1981993..a0d597e7 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -60,6 +60,7 @@ tracing-subscriber = "0.3.7"
wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false }
+metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
[profile.release-with-debug]
inherits = "release"
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index 8e57127a..c5521c92 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -13,6 +13,7 @@ readme = "README.md"
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
+metal = { workspace = true, optional = true}
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
half = { workspace = true }
@@ -39,3 +40,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
+metal = ["dep:metal"]
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 9dfcd7d5..de57c03a 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -8,12 +8,14 @@ use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
pub enum DeviceLocation {
Cpu,
Cuda { gpu_id: usize },
+ Metal,
}
#[derive(Debug, Clone)]
pub enum Device {
Cpu,
Cuda(crate::CudaDevice),
+ Metal(crate::MetalDevice),
}
pub trait NdArray {
@@ -128,10 +130,15 @@ impl Device {
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
}
+ pub fn new_metal(ordinal: usize) -> Result<Self> {
+ Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
+ }
+
pub fn set_seed(&self, seed: u64) -> Result<()> {
match self {
- Self::Cpu => crate::cpu_backend::CpuDevice.set_seed(seed),
+ Self::Cpu => CpuDevice.set_seed(seed),
Self::Cuda(c) => c.set_seed(seed),
+ Self::Metal(m) => m.set_seed(seed),
}
}
@@ -147,21 +154,20 @@ impl Device {
match self {
Self::Cpu => DeviceLocation::Cpu,
Self::Cuda(device) => device.location(),
+ Device::Metal(device) => device.location(),
}
}
pub fn is_cpu(&self) -> bool {
- match self {
- Self::Cpu => true,
- Self::Cuda(_) => false,
- }
+ matches!(self, Self::Cpu)
}
pub fn is_cuda(&self) -> bool {
- match self {
- Self::Cpu => false,
- Self::Cuda(_) => true,
- }
+ matches!(self, Self::Cuda(_))
+ }
+
+ pub fn is_metal(&self) -> bool {
+ matches!(self, Self::Metal(_))
}
pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
@@ -194,6 +200,11 @@ impl Device {
Ok(Storage::Cuda(storage))
}
}
+ Device::Metal(_device) => {
+ // let storage = device.rand_uniform(shape, dtype, lo, up)?;
+ // Ok(Storage::Metal(storage))
+ crate::bail!("Metal rand_uniform not implemented")
+ }
}
}
@@ -228,6 +239,10 @@ impl Device {
Ok(Storage::Cuda(storage))
}
}
+ Device::Metal(device) => {
+ let storage = device.rand_normal(shape, dtype, mean, std)?;
+ Ok(Storage::Metal(storage))
+ }
}
}
@@ -250,6 +265,10 @@ impl Device {
let storage = device.ones_impl(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
+ Device::Metal(device) => {
+ let storage = device.ones_impl(shape, dtype)?;
+ Ok(Storage::Metal(storage))
+ }
}
}
@@ -263,6 +282,10 @@ impl Device {
let storage = device.zeros_impl(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
+ Device::Metal(device) => {
+ let storage = device.zeros_impl(shape, dtype)?;
+ Ok(Storage::Metal(storage))
+ }
}
}
@@ -274,6 +297,11 @@ impl Device {
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Cuda(storage))
}
+ Device::Metal(device) => {
+ let storage = array.to_cpu_storage();
+ let storage = device.storage_from_cpu_storage(&storage)?;
+ Ok(Storage::Metal(storage))
+ }
}
}
@@ -285,6 +313,11 @@ impl Device {
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Cuda(storage))
}
+ Device::Metal(device) => {
+ let storage = S::to_cpu_storage_owned(data);
+ let storage = device.storage_from_cpu_storage(&storage)?;
+ Ok(Storage::Metal(storage))
+ }
}
}
}
diff --git a/candle-core/src/display.rs b/candle-core/src/display.rs
index b497699b..215c28f6 100644
--- a/candle-core/src/display.rs
+++ b/candle-core/src/display.rs
@@ -14,6 +14,7 @@ impl Tensor {
crate::DeviceLocation::Cuda { gpu_id } => {
format!(", cuda:{}", gpu_id)
}
+ _ => todo!(),
};
write!(f, "Tensor[")?;
@@ -476,6 +477,7 @@ impl std::fmt::Display for Tensor {
crate::DeviceLocation::Cuda { gpu_id } => {
format!(", cuda:{}", gpu_id)
}
+ crate::DeviceLocation::Metal => todo!(),
};
write!(
diff --git a/candle-core/src/dummy_metal_backend.rs b/candle-core/src/dummy_metal_backend.rs
new file mode 100644
index 00000000..e9d92331
--- /dev/null
+++ b/candle-core/src/dummy_metal_backend.rs
@@ -0,0 +1,223 @@
+#![allow(dead_code)]
+use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
+use crate::{CpuStorage, DType, Error, Layout, Result, Shape};
+
+#[derive(Debug, Clone)]
+pub struct MetalDevice;
+
+#[derive(Debug)]
+pub struct MetalStorage;
+
+#[derive(thiserror::Error, Debug)]
+pub enum MetalError {
+ #[error("{0}")]
+ Message(String),
+}
+
+impl From<String> for MetalError {
+ fn from(e: String) -> Self {
+ MetalError::Message(e)
+ }
+}
+
+macro_rules! fail {
+ () => {
+ unimplemented!("metal support has not been enabled, add `metal` feature to enable.")
+ };
+}
+
+impl crate::backend::BackendStorage for MetalStorage {
+ type Device = MetalDevice;
+
+ fn try_clone(&self, _: &Layout) -> Result<Self> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+
+ fn dtype(&self) -> DType {
+ fail!()
+ }
+
+ fn device(&self) -> &Self::Device {
+ fail!()
+ }
+
+ fn to_cpu_storage(&self) -> Result<CpuStorage> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+
+ fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+
+ fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+
+ fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+
+ fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+
+ fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+
+ fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+
+ fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+
+ fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+
+ fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+
+ fn conv1d(
+ &self,
+ _: &Layout,
+ _: &Self,
+ _: &Layout,
+ _: &crate::conv::ParamsConv1D,
+ ) -> Result<Self> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+
+ fn conv_transpose1d(
+ &self,
+ _l: &Layout,
+ _kernel: &Self,
+ _kernel_l: &Layout,
+ _params: &crate::conv::ParamsConvTranspose1D,
+ ) -> Result<Self> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+
+ fn conv2d(
+ &self,
+ _: &Layout,
+ _: &Self,
+ _: &Layout,
+ _: &crate::conv::ParamsConv2D,
+ ) -> Result<Self> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+
+ fn conv_transpose2d(
+ &self,
+ _l: &Layout,
+ _kernel: &Self,
+ _kernel_l: &Layout,
+ _params: &crate::conv::ParamsConvTranspose2D,
+ ) -> Result<Self> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+
+ fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+ fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+
+ fn scatter_add(
+ &self,
+ _: &Layout,
+ _: &Self,
+ _: &Layout,
+ _: &Self,
+ _: &Layout,
+ _: usize,
+ ) -> Result<Self> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+
+ fn index_add(
+ &self,
+ _: &Layout,
+ _: &Self,
+ _: &Layout,
+ _: &Self,
+ _: &Layout,
+ _: usize,
+ ) -> Result<Self> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+
+ fn matmul(
+ &self,
+ _: &Self,
+ _: (usize, usize, usize, usize),
+ _: &Layout,
+ _: &Layout,
+ ) -> Result<Self> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+
+ fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+
+ fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+
+ fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+
+ fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+
+ fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+}
+
+impl crate::backend::BackendDevice for MetalDevice {
+ type Storage = MetalStorage;
+ fn new(_: usize) -> Result<Self> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+
+ fn set_seed(&self, _: u64) -> Result<()> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+
+ fn location(&self) -> crate::DeviceLocation {
+ fail!()
+ }
+
+ fn same_device(&self, _: &Self) -> bool {
+ fail!()
+ }
+
+ fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+
+ fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+
+ fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+
+ fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+
+ fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+}
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 96a2b809..60ddea11 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -1,4 +1,4 @@
-use crate::{DType, DeviceLocation, Layout, Shape};
+use crate::{DType, DeviceLocation, Layout, MetalError, Shape};
#[derive(Debug, Clone)]
pub struct MatMulUnexpectedStriding {
@@ -152,6 +152,9 @@ pub enum Error {
#[error("the candle crate has not been built with cuda support")]
NotCompiledWithCudaSupport,
+ #[error("the candle crate has not been built with metal support")]
+ NotCompiledWithMetalSupport,
+
#[error("cannot find tensor {path}")]
CannotFindTensor { path: String },
@@ -159,6 +162,9 @@ pub enum Error {
#[error(transparent)]
Cuda(Box<dyn std::error::Error + Send + Sync>),
+ #[error("Metal error {0}")]
+ Metal(#[from] MetalError),
+
#[error(transparent)]
TryFromIntError(#[from] core::num::TryFromIntError),
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 73830229..da61bdb5 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -49,6 +49,7 @@ mod device;
pub mod display;
mod dtype;
mod dummy_cuda_backend;
+mod dummy_metal_backend;
pub mod error;
mod indexer;
pub mod layout;
@@ -87,6 +88,12 @@ pub use cuda_backend::{CudaDevice, CudaStorage};
#[cfg(not(feature = "cuda"))]
pub use dummy_cuda_backend::{CudaDevice, CudaStorage};
+#[cfg(feature = "metal")]
+pub use metal_backend::{MetalDevice, MetalError, MetalStorage};
+
+#[cfg(not(feature = "metal"))]
+pub use dummy_metal_backend::{MetalDevice, MetalError, MetalStorage};
+
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 1345078c..fbb20f6c 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -1,5 +1,5 @@
#![allow(clippy::redundant_closure_call)]
-use crate::{CpuStorage, CudaStorage, Layout, Result, Shape, Tensor};
+use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
use half::{bf16, f16};
use num_traits::float::Float;
@@ -184,6 +184,18 @@ pub trait CustomOp1 {
))
}
+ /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
+ /// offsets etc so the associated layout should be used to access it.
+ fn metal_fwd(
+ &self,
+ _storage: &MetalStorage,
+ _layout: &Layout,
+ ) -> Result<(MetalStorage, Shape)> {
+ Err(crate::Error::Metal(
+ format!("no metal implementation for {}", self.name()).into(),
+ ))
+ }
+
/// This function takes as argument the argument `arg` used in the forward pass, the result
/// produced by the forward operation `res` and the gradient of the result `grad_res`.
/// The function should return the gradient of the argument.
@@ -219,6 +231,20 @@ pub trait CustomOp2 {
))
}
+ /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
+ /// offsets etc so the associated layout should be used to access it.
+ fn metal_fwd(
+ &self,
+ _: &MetalStorage,
+ _: &Layout,
+ _: &MetalStorage,
+ _: &Layout,
+ ) -> Result<(MetalStorage, Shape)> {
+ Err(crate::Error::Metal(
+ format!("no metal implementation for {}", self.name()).into(),
+ ))
+ }
+
fn bwd(
&self,
_arg1: &Tensor,
@@ -261,6 +287,22 @@ pub trait CustomOp3 {
))
}
+ /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
+ /// offsets etc so the associated layout should be used to access it.
+ fn metal_fwd(
+ &self,
+ _: &MetalStorage,
+ _: &Layout,
+ _: &MetalStorage,
+ _: &Layout,
+ _: &MetalStorage,
+ _: &Layout,
+ ) -> Result<(MetalStorage, Shape)> {
+ Err(crate::Error::Metal(
+ format!("no metal implementation for {}", self.name()).into(),
+ ))
+ }
+
fn bwd(
&self,
_arg1: &Tensor,
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index dc75c02c..9e1a2c1d 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -1,6 +1,6 @@
use crate::backend::BackendStorage;
use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
-use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
+use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
// We do not want to implement Clone on Storage as cloning may fail because of
// out of memory. Instead try_clone should be used.
@@ -8,6 +8,7 @@ use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape
pub enum Storage {
Cpu(CpuStorage),
Cuda(CudaStorage),
+ Metal(MetalStorage),
}
impl Storage {
@@ -18,6 +19,10 @@ impl Storage {
let storage = storage.try_clone(layout)?;
Ok(Self::Cuda(storage))
}
+ Self::Metal(storage) => {
+ let storage = storage.try_clone(layout)?;
+ Ok(Self::Metal(storage))
+ }
}
}
@@ -25,6 +30,7 @@ impl Storage {
match self {
Self::Cpu(_) => Device::Cpu,
Self::Cuda(storage) => Device::Cuda(storage.device().clone()),
+ Self::Metal(storage) => Device::Metal(storage.device().clone()),
}
}
@@ -32,6 +38,7 @@ impl Storage {
match self {
Self::Cpu(storage) => storage.dtype(),
Self::Cuda(storage) => storage.dtype(),
+ Self::Metal(storage) => storage.dtype(),
}
}
@@ -65,6 +72,10 @@ impl Storage {
let storage = storage.affine(layout, mul, add)?;
Ok(Self::Cuda(storage))
}
+ Self::Metal(storage) => {
+ let storage = storage.affine(layout, mul, add)?;
+ Ok(Self::Metal(storage))
+ }
}
}
@@ -78,6 +89,10 @@ impl Storage {
let storage = storage.powf(layout, alpha)?;
Ok(Self::Cuda(storage))
}
+ Self::Metal(storage) => {
+ let storage = storage.powf(layout, alpha)?;
+ Ok(Self::Metal(storage))
+ }
}
}
@@ -91,6 +106,10 @@ impl Storage {
let storage = storage.elu(layout, alpha)?;
Ok(Self::Cuda(storage))
}
+ Self::Metal(storage) => {
+ let storage = storage.elu(layout, alpha)?;
+ Ok(Self::Metal(storage))
+ }
}
}
@@ -112,6 +131,10 @@ impl Storage {
let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
Ok(Self::Cuda(storage))
}
+ (Self::Metal(lhs), Self::Metal(rhs)) => {
+ let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
+ Ok(Self::Metal(storage))
+ }
(lhs, rhs) => {
// Should not happen because of the same device check above but we're defensive
// anyway.
@@ -135,6 +158,10 @@ impl Storage {
let storage = storage.reduce_op(op, layout, s)?;
Ok(Self::Cuda(storage))
}
+ Self::Metal(storage) => {
+ let storage = storage.reduce_op(op, layout, s)?;
+ Ok(Self::Metal(storage))
+ }
}
}
@@ -148,6 +175,10 @@ impl Storage {
let storage = storage.to_dtype(layout, dtype)?;
Ok(Self::Cuda(storage))
}
+ Self::Metal(storage) => {
+ let storage = storage.to_dtype(layout, dtype)?;
+ Ok(Self::Metal(storage))
+ }
}
}
@@ -161,6 +192,10 @@ impl Storage {
let (storage, shape) = c.cuda_fwd(storage, l)?;
Ok((Self::Cuda(storage), shape))
}
+ Self::Metal(storage) => {
+ let (storage, shape) = c.metal_fwd(storage, l)?;
+ Ok((Self::Metal(storage), shape))
+ }
}
}
@@ -181,6 +216,10 @@ impl Storage {
let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?;
Ok((Self::Cuda(s), shape))
}
+ (Self::Metal(s1), Self::Metal(s2)) => {
+ let (s, shape) = c.metal_fwd(s1, l1, s2, l2)?;
+ Ok((Self::Metal(s), shape))
+ }
_ => unreachable!(),
}
}
@@ -205,6 +244,10 @@ impl Storage {
let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?;
Ok((Self::Cuda(s), shape))
}
+ (Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
+ let (s, shape) = c.metal_fwd(s1, l1, s2, l2, s3, l3)?;
+ Ok((Self::Metal(s), shape))
+ }
_ => unreachable!(),
}
}
@@ -219,6 +262,10 @@ impl Storage {
let storage = storage.unary_impl::<B>(layout)?;
Ok(Self::Cuda(storage))
}
+ Self::Metal(storage) => {
+ let storage = storage.unary_impl::<B>(layout)?;
+ Ok(Self::Metal(storage))
+ }
}
}
@@ -239,6 +286,10 @@ impl Storage {
let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
Ok(Self::Cuda(storage))
}
+ (Self::Metal(lhs), Self::Metal(rhs)) => {
+ let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
+ Ok(Self::Metal(storage))
+ }
(lhs, rhs) => {
// Should not happen because of the same device check above but we're defensive
// anyway.
@@ -270,6 +321,10 @@ impl Storage {
let s = inp.conv1d(l, kernel, kernel_l, params)?;
Ok(Self::Cuda(s))
}
+ (Storage::Metal(inp), Storage::Metal(kernel)) => {
+ let s = inp.conv1d(l, kernel, kernel_l, params)?;
+ Ok(Self::Metal(s))
+ }
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
@@ -324,6 +379,10 @@ impl Storage {
let s = inp.conv2d(l, kernel, kernel_l, params)?;
Ok(Self::Cuda(s))
}
+ (Storage::Metal(inp), Storage::Metal(kernel)) => {
+ let s = inp.conv2d(l, kernel, kernel_l, params)?;
+ Ok(Self::Metal(s))
+ }
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
@@ -351,6 +410,10 @@ impl Storage {
let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
Ok(Self::Cuda(s))
}
+ (Storage::Metal(inp), Storage::Metal(kernel)) => {
+ let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
+ Ok(Self::Metal(s))
+ }
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
@@ -375,6 +438,10 @@ impl Storage {
let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
Ok(Self::Cuda(storage))
}
+ Self::Metal(storage) => {
+ let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
+ Ok(Self::Metal(storage))
+ }
}
}
@@ -393,6 +460,10 @@ impl Storage {
let storage = storage.max_pool2d(layout, kernel_size, stride)?;
Ok(Self::Cuda(storage))
}
+ Self::Metal(storage) => {
+ let storage = storage.max_pool2d(layout, kernel_size, stride)?;
+ Ok(Self::Metal(storage))
+ }
}
}
@@ -406,6 +477,10 @@ impl Storage {
let storage = storage.upsample_nearest1d(layout, sz)?;
Ok(Self::Cuda(storage))
}
+ Self::Metal(storage) => {
+ let storage = storage.upsample_nearest1d(layout, sz)?;
+ Ok(Self::Metal(storage))
+ }
}
}
@@ -419,6 +494,10 @@ impl Storage {
let storage = storage.upsample_nearest2d(layout, h, w)?;
Ok(Self::Cuda(storage))
}
+ Self::Metal(storage) => {
+ let storage = storage.upsample_nearest2d(layout, h, w)?;
+ Ok(Self::Metal(storage))
+ }
}
}
@@ -442,6 +521,10 @@ impl Storage {
let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
Ok(Self::Cuda(storage))
}
+ (Self::Metal(cond), Self::Metal(t), Self::Metal(f)) => {
+ let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
+ Ok(Self::Metal(storage))
+ }
(_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
@@ -468,6 +551,10 @@ impl Storage {
let storage = s.gather(l, indexes, indexes_l, d)?;
Ok(Self::Cuda(storage))
}
+ (Self::Metal(s), Self::Metal(indexes)) => {
+ let storage = s.gather(l, indexes, indexes_l, d)?;
+ Ok(Self::Metal(storage))
+ }
_ => unreachable!(),
}
}
@@ -492,6 +579,10 @@ impl Storage {
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cuda(storage))
}
+ (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
+ let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
+ Ok(Self::Metal(storage))
+ }
_ => unreachable!(),
}
}
@@ -516,6 +607,10 @@ impl Storage {
let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cuda(storage))
}
+ (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
+ let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
+ Ok(Self::Metal(storage))
+ }
_ => unreachable!(),
}
}
@@ -537,6 +632,10 @@ impl Storage {
let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
Ok(Self::Cuda(storage))
}
+ (Self::Metal(lhs), Self::Metal(rhs)) => {
+ let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
+ Ok(Self::Metal(storage))
+ }
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
@@ -564,6 +663,10 @@ impl Storage {
let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
Ok(Self::Cuda(storage))
}
+ (Self::Metal(lhs), Self::Metal(rhs)) => {
+ let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
+ Ok(Self::Metal(storage))
+ }
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
@@ -583,6 +686,9 @@ impl Storage {
match (self, dst) {
(Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
(Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
+ (Self::Metal(src), Self::Metal(dst)) => {
+ Ok(src.copy_strided_src(dst, dst_offset, src_l)?)
+ }
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 133b2782..f032a896 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -6,7 +6,7 @@ use crate::op::{
};
use crate::scalar::TensorOrScalar;
use crate::shape::{Dim, Dims};
-use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
+use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock};
/// Unique identifier for tensors.
@@ -529,6 +529,7 @@ impl Tensor {
match &*self.storage() {
Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
+ Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
}
}
@@ -1454,6 +1455,7 @@ impl Tensor {
match &*self.storage() {
Storage::Cpu(storage) => from_cpu_storage(storage),
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
+ Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
}
}
@@ -1484,6 +1486,7 @@ impl Tensor {
match &*self.storage() {
Storage::Cpu(storage) => from_cpu_storage(storage),
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
+ Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
}
}
@@ -1524,6 +1527,7 @@ impl Tensor {
match &*self.storage() {
Storage::Cpu(storage) => from_cpu_storage(storage),
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
+ Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
}
}
@@ -1849,6 +1853,9 @@ impl Tensor {
Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?)
}
(Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
+ _ => {
+ bail!("not implemented yet")
+ }
};
let op = BackpropOp::new1(self, Op::ToDevice);
let tensor_ = Tensor_ {
diff --git a/candle-core/src/utils.rs b/candle-core/src/utils.rs
index a9c2df0b..78c45a9a 100644
--- a/candle-core/src/utils.rs
+++ b/candle-core/src/utils.rs
@@ -23,6 +23,10 @@ pub fn cuda_is_available() -> bool {
cfg!(feature = "cuda")
}
+pub fn metal_is_available() -> bool {
+ cfg!(feature = "metal")
+}
+
pub fn with_avx() -> bool {
cfg!(target_feature = "avx")
}
diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs
index 4ef97f88..dff31b85 100644
--- a/candle-examples/src/lib.rs
+++ b/candle-examples/src/lib.rs
@@ -2,17 +2,28 @@ pub mod coco_classes;
pub mod imagenet;
pub mod token_output_stream;
+use candle::utils::{cuda_is_available, metal_is_available};
use candle::{Device, Result, Tensor};
pub fn device(cpu: bool) -> Result<Device> {
if cpu {
Ok(Device::Cpu)
+ } else if cuda_is_available() {
+ Ok(Device::new_cuda(0)?)
+ } else if metal_is_available() {
+ Ok(Device::new_metal(0)?)
} else {
- let device = Device::cuda_if_available(0)?;
- if !device.is_cuda() {
+ #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
+ {
+ println!(
+ "Running on CPU, to run on GPU(metal), build this example with `--features metal`"
+ );
+ }
+ #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
+ {
println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
}
- Ok(device)
+ Ok(Device::Cpu)
}
}
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 05a786ef..b0c623d3 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -71,11 +71,13 @@ impl PyDType {
}
static CUDA_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
+static METAL_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum PyDevice {
Cpu,
Cuda,
+ Metal,
}
impl PyDevice {
@@ -83,6 +85,7 @@ impl PyDevice {
match device {
Device::Cpu => Self::Cpu,
Device::Cuda(_) => Self::Cuda,
+ Device::Metal(_) => Self::Metal,
}
}
@@ -98,6 +101,15 @@ impl PyDevice {
*device = Some(d.clone());
Ok(d)
}
+ Self::Metal => {
+ let mut device = METAL_DEVICE.lock().unwrap();
+ if let Some(device) = device.as_ref() {
+ return Ok(device.clone());
+ };
+ let d = Device::new_metal(0).map_err(wrap_err)?;
+ *device = Some(d.clone());
+ Ok(d)
+ }
}
}
}
@@ -119,6 +131,7 @@ impl ToPyObject for PyDevice {
let str = match self {
PyDevice::Cpu => "cpu",
PyDevice::Cuda => "cuda",
+ PyDevice::Metal => "metal",
};
str.to_object(py)
}