diff options
Diffstat (limited to 'candle-core/src/cuda_backend.rs')
-rw-r--r-- | candle-core/src/cuda_backend.rs | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 8954fc33..fec37c39 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -420,6 +420,43 @@ impl BackendDevice for CudaDevice { device: self.clone(), }) } + + fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> { + let slice = match storage { + CpuStorage::U8(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::U8(data) + } + CpuStorage::U32(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::U32(data) + } + CpuStorage::I64(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::I64(data) + } + CpuStorage::BF16(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::BF16(data) + } + CpuStorage::F16(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::F16(data) + } + CpuStorage::F32(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::F32(data) + } + CpuStorage::F64(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::F64(data) + } + }; + Ok(CudaStorage { + slice, + device: self.clone(), + }) + } } #[derive(Debug)] |