diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-03-21 13:09:42 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-21 13:09:42 +0100 |
commit | ec97c98e81707c8f66db6be22d2df7c8791c55b8 (patch) | |
tree | 70c7e70f0333b387ef8d1bd6de7209641ba53549 /candle-core | |
parent | bb3ee48039ed040da48def94f57a6cf1eb0e7911 (diff) | |
download | candle-ec97c98e81707c8f66db6be22d2df7c8791c55b8.tar.gz candle-ec97c98e81707c8f66db6be22d2df7c8791c55b8.tar.bz2 candle-ec97c98e81707c8f66db6be22d2df7c8791c55b8.zip |
Async tensor copying. (#1900)
Diffstat (limited to 'candle-core')
-rw-r--r-- | candle-core/src/backend.rs | 2 | ||||
-rw-r--r-- | candle-core/src/cpu_backend.rs | 4 | ||||
-rw-r--r-- | candle-core/src/cuda_backend.rs | 37 | ||||
-rw-r--r-- | candle-core/src/device.rs | 8 | ||||
-rw-r--r-- | candle-core/src/dummy_cuda_backend.rs | 4 | ||||
-rw-r--r-- | candle-core/src/dummy_metal_backend.rs | 4 | ||||
-rw-r--r-- | candle-core/src/metal_backend.rs | 4 |
7 files changed, 59 insertions, 4 deletions
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs index ea1ac1a9..c63aad54 100644 --- a/candle-core/src/backend.rs +++ b/candle-core/src/backend.rs @@ -129,6 +129,8 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone { fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>; + fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage>; + fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>; fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>; diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 1504d5b8..fa48577c 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -2814,6 +2814,10 @@ impl BackendDevice for CpuDevice { Ok(s.clone()) } + fn storage_from_cpu_storage_owned(&self, s: CpuStorage) -> Result<Self::Storage> { + Ok(s) + } + fn new(_: usize) -> Result<Self> { Ok(Self) } diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 8954fc33..fec37c39 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -420,6 +420,43 @@ impl BackendDevice for CudaDevice { device: self.clone(), }) } + + fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> { + let slice = match storage { + CpuStorage::U8(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::U8(data) + } + CpuStorage::U32(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::U32(data) + } + CpuStorage::I64(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::I64(data) + } + CpuStorage::BF16(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::BF16(data) + } + CpuStorage::F16(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::F16(data) + } + CpuStorage::F32(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::F32(data) + } + CpuStorage::F64(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::F64(data) + } + }; + Ok(CudaStorage { + slice, + device: self.clone(), + }) + } } #[derive(Debug)] diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index 1e33021b..9c39d27a 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -294,12 +294,12 @@ impl Device { Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())), Device::Cuda(device) => { let storage = array.to_cpu_storage(); - let storage = device.storage_from_cpu_storage(&storage)?; + let storage = device.storage_from_cpu_storage_owned(storage)?; Ok(Storage::Cuda(storage)) } Device::Metal(device) => { let storage = array.to_cpu_storage(); - let storage = device.storage_from_cpu_storage(&storage)?; + let storage = device.storage_from_cpu_storage_owned(storage)?; Ok(Storage::Metal(storage)) } } @@ -310,12 +310,12 @@ impl Device { Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))), Device::Cuda(device) => { let storage = S::to_cpu_storage_owned(data); - let storage = device.storage_from_cpu_storage(&storage)?; + let storage = device.storage_from_cpu_storage_owned(storage)?; Ok(Storage::Cuda(storage)) } Device::Metal(device) => { let storage = S::to_cpu_storage_owned(data); - let storage = device.storage_from_cpu_storage(&storage)?; + let storage = device.storage_from_cpu_storage_owned(storage)?; Ok(Storage::Metal(storage)) } } diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index 43d55fa4..d4887f19 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -214,6 +214,10 @@ impl crate::backend::BackendDevice for CudaDevice { Err(Error::NotCompiledWithCudaSupport) } + fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage> { + Err(Error::NotCompiledWithCudaSupport) + } + fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> { Err(Error::NotCompiledWithCudaSupport) } diff --git a/candle-core/src/dummy_metal_backend.rs b/candle-core/src/dummy_metal_backend.rs index 791ec153..33c6c9fe 100644 --- a/candle-core/src/dummy_metal_backend.rs +++ b/candle-core/src/dummy_metal_backend.rs @@ -226,6 +226,10 @@ impl crate::backend::BackendDevice for MetalDevice { Err(Error::NotCompiledWithMetalSupport) } + fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage> { + Err(Error::NotCompiledWithMetalSupport) + } + fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> { Err(Error::NotCompiledWithMetalSupport) } diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index acc6c445..c4245652 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -1867,6 +1867,10 @@ impl BackendDevice for MetalDevice { )) } + fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<Self::Storage> { + self.storage_from_cpu_storage(&storage) + } + fn rand_uniform( &self, shape: &Shape, |