diff options
Diffstat (limited to 'candle-core/src/metal_backend')
-rw-r--r-- | candle-core/src/metal_backend/mod.rs | 15 |
1 files changed, 14 insertions, 1 deletions
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 12dba381..1396899b 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -1,7 +1,7 @@ use crate::backend::{BackendDevice, BackendStorage}; use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; -use crate::{CpuStorage, DType, Layout, Result, Shape}; +use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape}; use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels}; use metal::{Buffer, MTLResourceOptions, NSUInteger}; use std::collections::HashMap; @@ -1787,6 +1787,19 @@ impl BackendDevice for MetalDevice { self.storage_from_cpu_storage(&cpu_storage) } + fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> { + let (count, buffer) = match T::cpu_storage_ref(s) { + CpuStorageRef::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorageRef::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorageRef::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorageRef::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorageRef::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorageRef::F32(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorageRef::F64(storage) => (storage.len(), self.new_buffer_with_data(storage)), + }; + Ok(Self::Storage::new(buffer?, self.clone(), count, T::DTYPE)) + } + fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> { let (count, buffer) = match storage { CpuStorage::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)), |