diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-08 09:32:36 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-08 09:32:36 +0100 |
commit | 9abeddd750fe13632136a9807fcb0b6d1c999bd3 (patch) | |
tree | ee1da54238865501c0d66ffff381ceea8835b105 /candle-core | |
parent | 2e5fb0b2518aa7f7c666967fe4160462578cf8d0 (diff) | |
download | candle-9abeddd750fe13632136a9807fcb0b6d1c999bd3.tar.gz candle-9abeddd750fe13632136a9807fcb0b6d1c999bd3.tar.bz2 candle-9abeddd750fe13632136a9807fcb0b6d1c999bd3.zip |
Make the cuda rng seedable. (#1056)
Diffstat (limited to 'candle-core')
-rw-r--r-- | candle-core/src/backend.rs | 2 | ||||
-rw-r--r-- | candle-core/src/cpu_backend.rs | 4 | ||||
-rw-r--r-- | candle-core/src/cuda_backend.rs | 6 | ||||
-rw-r--r-- | candle-core/src/dummy_cuda_backend.rs | 4 |
4 files changed, 16 insertions, 0 deletions
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs index 03a07434..7f0e2fc7 100644 --- a/candle-core/src/backend.rs +++ b/candle-core/src/backend.rs @@ -111,4 +111,6 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone { fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>; fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>; + + fn set_seed(&self, _: u64) -> Result<()>; } diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 4e808b34..86cbeb78 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -2603,6 +2603,10 @@ impl BackendDevice for CpuDevice { Ok(Self) } + fn set_seed(&self, _seed: u64) -> Result<()> { + crate::bail!("cannot seed the CPU rng with set_seed") + } + fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result<CpuStorage> { use rand::prelude::*; diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index f7518067..f0f48327 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -223,6 +223,12 @@ impl BackendDevice for CudaDevice { }) } + fn set_seed(&self, seed: u64) -> Result<()> { + let mut curand = self.curand.lock().unwrap(); + curand.0.set_seed(seed).w()?; + Ok(()) + } + fn location(&self) -> crate::DeviceLocation { crate::DeviceLocation::Cuda { gpu_id: self.device.ordinal(), diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index 5cc9c6d8..53574458 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -167,6 +167,10 @@ impl crate::backend::BackendDevice for CudaDevice { Err(Error::NotCompiledWithCudaSupport) } + fn set_seed(&self, _: u64) -> Result<()> { + Err(Error::NotCompiledWithCudaSupport) + } + fn location(&self) -> crate::DeviceLocation { fail!() } |