summaryrefslogtreecommitdiff
path: root/candle-core/src/metal_backend
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-04-23 13:23:27 +0200
committerGitHub <noreply@github.com>2024-04-23 13:23:27 +0200
commit8a05743a21768405217576a1b9557936be74ed90 (patch)
tree8c6eab8ca1f5496d02b874e7d0a340ee70afe9c5 /candle-core/src/metal_backend
parentb2e816752bb3b81ed5daaf4b623c3b5e6c0f7b67 (diff)
downloadcandle-8a05743a21768405217576a1b9557936be74ed90.tar.gz
candle-8a05743a21768405217576a1b9557936be74ed90.tar.bz2
candle-8a05743a21768405217576a1b9557936be74ed90.zip
Add StorageRef. (#2113)
* Add the storage-ref bits. * Add the metal implementation.
Diffstat (limited to 'candle-core/src/metal_backend')
-rw-r--r--candle-core/src/metal_backend/mod.rs15
1 files changed, 14 insertions, 1 deletions
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs
index 12dba381..1396899b 100644
--- a/candle-core/src/metal_backend/mod.rs
+++ b/candle-core/src/metal_backend/mod.rs
@@ -1,7 +1,7 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
-use crate::{CpuStorage, DType, Layout, Result, Shape};
+use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels};
use metal::{Buffer, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
@@ -1787,6 +1787,19 @@ impl BackendDevice for MetalDevice {
self.storage_from_cpu_storage(&cpu_storage)
}
+ fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
+ let (count, buffer) = match T::cpu_storage_ref(s) {
+ CpuStorageRef::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),
+ CpuStorageRef::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)),
+ CpuStorageRef::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)),
+ CpuStorageRef::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)),
+ CpuStorageRef::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)),
+ CpuStorageRef::F32(storage) => (storage.len(), self.new_buffer_with_data(storage)),
+ CpuStorageRef::F64(storage) => (storage.len(), self.new_buffer_with_data(storage)),
+ };
+ Ok(Self::Storage::new(buffer?, self.clone(), count, T::DTYPE))
+ }
+
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
let (count, buffer) = match storage {
CpuStorage::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),