summaryrefslogtreecommitdiff
path: root/candle-core
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-03-21 13:09:42 +0100
committerGitHub <noreply@github.com>2024-03-21 13:09:42 +0100
commitec97c98e81707c8f66db6be22d2df7c8791c55b8 (patch)
tree70c7e70f0333b387ef8d1bd6de7209641ba53549 /candle-core
parentbb3ee48039ed040da48def94f57a6cf1eb0e7911 (diff)
downloadcandle-ec97c98e81707c8f66db6be22d2df7c8791c55b8.tar.gz
candle-ec97c98e81707c8f66db6be22d2df7c8791c55b8.tar.bz2
candle-ec97c98e81707c8f66db6be22d2df7c8791c55b8.zip
Async tensor copying. (#1900)
Diffstat (limited to 'candle-core')
-rw-r--r--candle-core/src/backend.rs2
-rw-r--r--candle-core/src/cpu_backend.rs4
-rw-r--r--candle-core/src/cuda_backend.rs37
-rw-r--r--candle-core/src/device.rs8
-rw-r--r--candle-core/src/dummy_cuda_backend.rs4
-rw-r--r--candle-core/src/dummy_metal_backend.rs4
-rw-r--r--candle-core/src/metal_backend.rs4
7 files changed, 59 insertions, 4 deletions
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index ea1ac1a9..c63aad54 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -129,6 +129,8 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>;
+ fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage>;
+
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 1504d5b8..fa48577c 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -2814,6 +2814,10 @@ impl BackendDevice for CpuDevice {
Ok(s.clone())
}
+ fn storage_from_cpu_storage_owned(&self, s: CpuStorage) -> Result<Self::Storage> {
+ Ok(s)
+ }
+
fn new(_: usize) -> Result<Self> {
Ok(Self)
}
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 8954fc33..fec37c39 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -420,6 +420,43 @@ impl BackendDevice for CudaDevice {
device: self.clone(),
})
}
+
+ fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
+ let slice = match storage {
+ CpuStorage::U8(storage) => {
+ let data = self.htod_copy(storage).w()?;
+ CudaStorageSlice::U8(data)
+ }
+ CpuStorage::U32(storage) => {
+ let data = self.htod_copy(storage).w()?;
+ CudaStorageSlice::U32(data)
+ }
+ CpuStorage::I64(storage) => {
+ let data = self.htod_copy(storage).w()?;
+ CudaStorageSlice::I64(data)
+ }
+ CpuStorage::BF16(storage) => {
+ let data = self.htod_copy(storage).w()?;
+ CudaStorageSlice::BF16(data)
+ }
+ CpuStorage::F16(storage) => {
+ let data = self.htod_copy(storage).w()?;
+ CudaStorageSlice::F16(data)
+ }
+ CpuStorage::F32(storage) => {
+ let data = self.htod_copy(storage).w()?;
+ CudaStorageSlice::F32(data)
+ }
+ CpuStorage::F64(storage) => {
+ let data = self.htod_copy(storage).w()?;
+ CudaStorageSlice::F64(data)
+ }
+ };
+ Ok(CudaStorage {
+ slice,
+ device: self.clone(),
+ })
+ }
}
#[derive(Debug)]
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 1e33021b..9c39d27a 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -294,12 +294,12 @@ impl Device {
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
Device::Cuda(device) => {
let storage = array.to_cpu_storage();
- let storage = device.storage_from_cpu_storage(&storage)?;
+ let storage = device.storage_from_cpu_storage_owned(storage)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = array.to_cpu_storage();
- let storage = device.storage_from_cpu_storage(&storage)?;
+ let storage = device.storage_from_cpu_storage_owned(storage)?;
Ok(Storage::Metal(storage))
}
}
@@ -310,12 +310,12 @@ impl Device {
Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))),
Device::Cuda(device) => {
let storage = S::to_cpu_storage_owned(data);
- let storage = device.storage_from_cpu_storage(&storage)?;
+ let storage = device.storage_from_cpu_storage_owned(storage)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = S::to_cpu_storage_owned(data);
- let storage = device.storage_from_cpu_storage(&storage)?;
+ let storage = device.storage_from_cpu_storage_owned(storage)?;
Ok(Storage::Metal(storage))
}
}
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index 43d55fa4..d4887f19 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -214,6 +214,10 @@ impl crate::backend::BackendDevice for CudaDevice {
Err(Error::NotCompiledWithCudaSupport)
}
+ fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage> {
+ Err(Error::NotCompiledWithCudaSupport)
+ }
+
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithCudaSupport)
}
diff --git a/candle-core/src/dummy_metal_backend.rs b/candle-core/src/dummy_metal_backend.rs
index 791ec153..33c6c9fe 100644
--- a/candle-core/src/dummy_metal_backend.rs
+++ b/candle-core/src/dummy_metal_backend.rs
@@ -226,6 +226,10 @@ impl crate::backend::BackendDevice for MetalDevice {
Err(Error::NotCompiledWithMetalSupport)
}
+ fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage> {
+ Err(Error::NotCompiledWithMetalSupport)
+ }
+
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index acc6c445..c4245652 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -1867,6 +1867,10 @@ impl BackendDevice for MetalDevice {
))
}
+ fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<Self::Storage> {
+ self.storage_from_cpu_storage(&storage)
+ }
+
fn rand_uniform(
&self,
shape: &Shape,