diff options
-rw-r--r-- | candle-core/src/cuda_backend.rs | 4 | ||||
-rw-r--r-- | candle-core/src/device.rs | 7 |
2 files changed, 10 insertions, 1 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index f0f48327..f1c35ae1 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -224,8 +224,10 @@ impl BackendDevice for CudaDevice { } fn set_seed(&self, seed: u64) -> Result<()> { + // We do not call set_seed but instead create a new curand object. This ensures that the + // state will be identical and the same random numbers will be generated. let mut curand = self.curand.lock().unwrap(); - curand.0.set_seed(seed).w()?; + curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?; Ok(()) } diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index 0ed23a18..d566ba42 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -128,6 +128,13 @@ impl Device { Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?)) } + pub fn set_seed(&self, seed: u64) -> Result<()> { + match self { + Self::Cpu => crate::cpu_backend::CpuDevice.set_seed(seed), + Self::Cuda(c) => c.set_seed(seed), + } + } + pub fn same_device(&self, rhs: &Self) -> bool { match (self, rhs) { (Self::Cpu, Self::Cpu) => true, |