diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-10-02 21:30:58 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-10-02 21:30:58 +0200 |
commit | 7b60bda4ed8c9d861396fe74307d6c77281522ef (patch) | |
tree | 5fd41d30db814c61c74acd9a1546dd039f7c183a /candle-core | |
parent | 936300678d588c6525594ef2578737e0c19ecf07 (diff) | |
download | candle-7b60bda4ed8c9d861396fe74307d6c77281522ef.tar.gz candle-7b60bda4ed8c9d861396fe74307d6c77281522ef.tar.bz2 candle-7b60bda4ed8c9d861396fe74307d6c77281522ef.zip |
Add support for cuda streams. (#2532)
Diffstat (limited to 'candle-core')
-rw-r--r-- | candle-core/src/cuda_backend/device.rs | 14 | ||||
-rw-r--r-- | candle-core/src/device.rs | 4 | ||||
-rw-r--r-- | candle-core/src/dummy_cuda_backend.rs | 6 |
3 files changed, 24 insertions, 0 deletions
diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs
index 0aa58cac..89fe44a6 100644
--- a/candle-core/src/cuda_backend/device.rs
+++ b/candle-core/src/cuda_backend/device.rs
@@ -144,6 +144,20 @@ impl CudaDevice {
     }
 }
 
+impl CudaDevice {
+    pub fn new_with_stream(ordinal: usize) -> Result<Self> {
+        let device = cudarc::driver::CudaDevice::new_with_stream(ordinal).w()?;
+        let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
+        let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
+        Ok(Self {
+            id: DeviceId::new(),
+            device,
+            blas: Arc::new(blas),
+            curand: Arc::new(Mutex::new(CudaRng(curand))),
+        })
+    }
+}
+
 impl BackendDevice for CudaDevice {
     type Storage = CudaStorage;
 
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 91e56937..c4a8e936 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -130,6 +130,10 @@ impl Device {
         Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
     }
 
+    pub fn new_cuda_with_stream(ordinal: usize) -> Result<Self> {
+        Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?))
+    }
+
     pub fn new_metal(ordinal: usize) -> Result<Self> {
         Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
     }
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index 68eef1ef..b4f2e8aa 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -14,6 +14,12 @@ macro_rules! fail {
     };
 }
 
+impl CudaDevice {
+    pub fn new_with_stream(_: usize) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+}
+
 impl crate::backend::BackendStorage for CudaStorage {
     type Device = CudaDevice;
 