summaryrefslogtreecommitdiff
path: root/candle-core/src/cuda_backend.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-10-14 10:03:59 +0100
committerGitHub <noreply@github.com>2023-10-14 10:03:59 +0100
commit9309cfc47d3a73605cc6dea8669bbea5b0a5784c (patch)
tree8a216730b5ad49da0fdb0a8ccde784e71fae534b /candle-core/src/cuda_backend.rs
parenta193bf5f603ebee8319b0f75a93edede4de9b7b9 (diff)
downloadcandle-9309cfc47d3a73605cc6dea8669bbea5b0a5784c.tar.gz
candle-9309cfc47d3a73605cc6dea8669bbea5b0a5784c.tar.bz2
candle-9309cfc47d3a73605cc6dea8669bbea5b0a5784c.zip
Create a new curand instead of reseeding. (#1089)
Diffstat (limited to 'candle-core/src/cuda_backend.rs')
-rw-r--r--candle-core/src/cuda_backend.rs4
1 files changed, 3 insertions, 1 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index f0f48327..f1c35ae1 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -224,8 +224,10 @@ impl BackendDevice for CudaDevice {
}
fn set_seed(&self, seed: u64) -> Result<()> {
+ // We do not call set_seed but instead create a new curand object. This ensures that the
+ // state will be identical and the same random numbers will be generated.
let mut curand = self.curand.lock().unwrap();
- curand.0.set_seed(seed).w()?;
+ curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
Ok(())
}