summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/cuda_backend.rs4
-rw-r--r--candle-core/src/device.rs7
2 files changed, 10 insertions, 1 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index f0f48327..f1c35ae1 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -224,8 +224,10 @@ impl BackendDevice for CudaDevice {
}
fn set_seed(&self, seed: u64) -> Result<()> {
+ // We do not call set_seed but instead create a new curand object. This ensures that the
+ // state will be identical and the same random numbers will be generated.
let mut curand = self.curand.lock().unwrap();
- curand.0.set_seed(seed).w()?;
+ curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
Ok(())
}
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 0ed23a18..d566ba42 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -128,6 +128,13 @@ impl Device {
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
}
+ pub fn set_seed(&self, seed: u64) -> Result<()> {
+ match self {
+ Self::Cpu => crate::cpu_backend::CpuDevice.set_seed(seed),
+ Self::Cuda(c) => c.set_seed(seed),
+ }
+ }
+
pub fn same_device(&self, rhs: &Self) -> bool {
match (self, rhs) {
(Self::Cpu, Self::Cpu) => true,