summaryrefslogtreecommitdiff
path: root/candle-core/src/device.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/device.rs')
-rw-r--r--candle-core/src/device.rs8
1 files changed, 8 insertions, 0 deletions
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 91925b57..18aa61af 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -138,6 +138,14 @@ impl Device {
}
}
+ pub fn as_metal_device(&self) -> Result<&crate::MetalDevice> {
+ match self {
+ Self::Cuda(_) => crate::bail!("expected a metal device, got cuda"),
+ Self::Cpu => crate::bail!("expected a metal device, got cpu"),
+ Self::Metal(d) => Ok(d),
+ }
+ }
+
pub fn new_cuda_with_stream(ordinal: usize) -> Result<Self> {
Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?))
}