diff options
Diffstat (limited to 'candle-core/src/device.rs')
-rw-r--r-- | candle-core/src/device.rs | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index 9c39d27a..846c62ce 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -289,6 +289,23 @@ impl Device { } } + pub(crate) unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Storage> { + match self { + Device::Cpu => { + let storage = CpuDevice.alloc_uninit(shape, dtype)?; + Ok(Storage::Cpu(storage)) + } + Device::Cuda(device) => { + let storage = device.alloc_uninit(shape, dtype)?; + Ok(Storage::Cuda(storage)) + } + Device::Metal(device) => { + let storage = device.alloc_uninit(shape, dtype)?; + Ok(Storage::Metal(storage)) + } + } + } + pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> { match self { Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())), |