summaryrefslogtreecommitdiff
path: root/candle-core/src/cuda_backend.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/cuda_backend.rs')
-rw-r--r--candle-core/src/cuda_backend.rs37
1 files changed, 37 insertions, 0 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 8954fc33..fec37c39 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -420,6 +420,43 @@ impl BackendDevice for CudaDevice {
device: self.clone(),
})
}
+
+ fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
+ let slice = match storage {
+ CpuStorage::U8(storage) => {
+ let data = self.htod_copy(storage).w()?;
+ CudaStorageSlice::U8(data)
+ }
+ CpuStorage::U32(storage) => {
+ let data = self.htod_copy(storage).w()?;
+ CudaStorageSlice::U32(data)
+ }
+ CpuStorage::I64(storage) => {
+ let data = self.htod_copy(storage).w()?;
+ CudaStorageSlice::I64(data)
+ }
+ CpuStorage::BF16(storage) => {
+ let data = self.htod_copy(storage).w()?;
+ CudaStorageSlice::BF16(data)
+ }
+ CpuStorage::F16(storage) => {
+ let data = self.htod_copy(storage).w()?;
+ CudaStorageSlice::F16(data)
+ }
+ CpuStorage::F32(storage) => {
+ let data = self.htod_copy(storage).w()?;
+ CudaStorageSlice::F32(data)
+ }
+ CpuStorage::F64(storage) => {
+ let data = self.htod_copy(storage).w()?;
+ CudaStorageSlice::F64(data)
+ }
+ };
+ Ok(CudaStorage {
+ slice,
+ device: self.clone(),
+ })
+ }
}
#[derive(Debug)]