diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-04-23 13:23:27 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-23 13:23:27 +0200 |
commit | 8a05743a21768405217576a1b9557936be74ed90 (patch) | |
tree | 8c6eab8ca1f5496d02b874e7d0a340ee70afe9c5 /candle-core/src/metal_backend | |
parent | b2e816752bb3b81ed5daaf4b623c3b5e6c0f7b67 (diff) | |
download | candle-8a05743a21768405217576a1b9557936be74ed90.tar.gz candle-8a05743a21768405217576a1b9557936be74ed90.tar.bz2 candle-8a05743a21768405217576a1b9557936be74ed90.zip |
Add StorageRef. (#2113)
* Add the storage-ref bits.
* Add the metal implementation.
Diffstat (limited to 'candle-core/src/metal_backend')
-rw-r--r-- | candle-core/src/metal_backend/mod.rs | 15 |
1 file changed, 14 insertions, 1 deletion
```diff
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs
index 12dba381..1396899b 100644
--- a/candle-core/src/metal_backend/mod.rs
+++ b/candle-core/src/metal_backend/mod.rs
@@ -1,7 +1,7 @@
 use crate::backend::{BackendDevice, BackendStorage};
 use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
 use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
-use crate::{CpuStorage, DType, Layout, Result, Shape};
+use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
 use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels};
 use metal::{Buffer, MTLResourceOptions, NSUInteger};
 use std::collections::HashMap;
@@ -1787,6 +1787,19 @@ impl BackendDevice for MetalDevice {
         self.storage_from_cpu_storage(&cpu_storage)
     }
 
+    fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
+        let (count, buffer) = match T::cpu_storage_ref(s) {
+            CpuStorageRef::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),
+            CpuStorageRef::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)),
+            CpuStorageRef::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)),
+            CpuStorageRef::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)),
+            CpuStorageRef::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)),
+            CpuStorageRef::F32(storage) => (storage.len(), self.new_buffer_with_data(storage)),
+            CpuStorageRef::F64(storage) => (storage.len(), self.new_buffer_with_data(storage)),
+        };
+        Ok(Self::Storage::new(buffer?, self.clone(), count, T::DTYPE))
+    }
+
     fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
         let (count, buffer) = match storage {
             CpuStorage::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),
```