summaryrefslogtreecommitdiff
path: root/candle-core/src/device.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/device.rs')
-rw-r--r--candle-core/src/device.rs17
1 files changed, 17 insertions, 0 deletions
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 9c39d27a..846c62ce 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -289,6 +289,23 @@ impl Device {
}
}
+ pub(crate) unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
+ match self {
+ Device::Cpu => {
+ let storage = CpuDevice.alloc_uninit(shape, dtype)?;
+ Ok(Storage::Cpu(storage))
+ }
+ Device::Cuda(device) => {
+ let storage = device.alloc_uninit(shape, dtype)?;
+ Ok(Storage::Cuda(storage))
+ }
+ Device::Metal(device) => {
+ let storage = device.alloc_uninit(shape, dtype)?;
+ Ok(Storage::Metal(storage))
+ }
+ }
+ }
+
pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
match self {
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),