summary | refs | log | tree | commit | diff
path: root/candle-core
diff options
context:
space:
mode:
author Laurent Mazare <laurent.mazare@gmail.com> 2024-10-02 21:30:58 +0200
committer GitHub <noreply@github.com> 2024-10-02 21:30:58 +0200
commit 7b60bda4ed8c9d861396fe74307d6c77281522ef (patch)
tree 5fd41d30db814c61c74acd9a1546dd039f7c183a /candle-core
parent 936300678d588c6525594ef2578737e0c19ecf07 (diff)
download candle-7b60bda4ed8c9d861396fe74307d6c77281522ef.tar.gz
candle-7b60bda4ed8c9d861396fe74307d6c77281522ef.tar.bz2
candle-7b60bda4ed8c9d861396fe74307d6c77281522ef.zip
Add support for cuda streams. (#2532)
Diffstat (limited to 'candle-core')
-rw-r--r-- candle-core/src/cuda_backend/device.rs | 14
-rw-r--r-- candle-core/src/device.rs | 4
-rw-r--r-- candle-core/src/dummy_cuda_backend.rs | 6
3 files changed, 24 insertions, 0 deletions
diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs
index 0aa58cac..89fe44a6 100644
--- a/candle-core/src/cuda_backend/device.rs
+++ b/candle-core/src/cuda_backend/device.rs
@@ -144,6 +144,20 @@ impl CudaDevice {
}
}
+impl CudaDevice {
+ pub fn new_with_stream(ordinal: usize) -> Result<Self> {
+ let device = cudarc::driver::CudaDevice::new_with_stream(ordinal).w()?;
+ let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
+ let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
+ Ok(Self {
+ id: DeviceId::new(),
+ device,
+ blas: Arc::new(blas),
+ curand: Arc::new(Mutex::new(CudaRng(curand))),
+ })
+ }
+}
+
impl BackendDevice for CudaDevice {
type Storage = CudaStorage;
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 91e56937..c4a8e936 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -130,6 +130,10 @@ impl Device {
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
}
+ pub fn new_cuda_with_stream(ordinal: usize) -> Result<Self> {
+ Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?))
+ }
+
pub fn new_metal(ordinal: usize) -> Result<Self> {
Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
}
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index 68eef1ef..b4f2e8aa 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -14,6 +14,12 @@ macro_rules! fail {
};
}
+impl CudaDevice {
+ pub fn new_with_stream(_: usize) -> Result<Self> {
+ Err(Error::NotCompiledWithCudaSupport)
+ }
+}
+
impl crate::backend::BackendStorage for CudaStorage {
type Device = CudaDevice;