summaryrefslogtreecommitdiff
path: root/candle-core
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-10-08 09:32:36 +0100
committerGitHub <noreply@github.com>2023-10-08 09:32:36 +0100
commit9abeddd750fe13632136a9807fcb0b6d1c999bd3 (patch)
treeee1da54238865501c0d66ffff381ceea8835b105 /candle-core
parent2e5fb0b2518aa7f7c666967fe4160462578cf8d0 (diff)
downloadcandle-9abeddd750fe13632136a9807fcb0b6d1c999bd3.tar.gz
candle-9abeddd750fe13632136a9807fcb0b6d1c999bd3.tar.bz2
candle-9abeddd750fe13632136a9807fcb0b6d1c999bd3.zip
Make the cuda rng seedable. (#1056)
Diffstat (limited to 'candle-core')
-rw-r--r--candle-core/src/backend.rs2
-rw-r--r--candle-core/src/cpu_backend.rs4
-rw-r--r--candle-core/src/cuda_backend.rs6
-rw-r--r--candle-core/src/dummy_cuda_backend.rs4
4 files changed, 16 insertions, 0 deletions
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index 03a07434..7f0e2fc7 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -111,4 +111,6 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
+
+ fn set_seed(&self, _: u64) -> Result<()>;
}
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 4e808b34..86cbeb78 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -2603,6 +2603,10 @@ impl BackendDevice for CpuDevice {
Ok(Self)
}
+ fn set_seed(&self, _seed: u64) -> Result<()> {
+ crate::bail!("cannot seed the CPU rng with set_seed")
+ }
+
fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result<CpuStorage> {
use rand::prelude::*;
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index f7518067..f0f48327 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -223,6 +223,12 @@ impl BackendDevice for CudaDevice {
})
}
+ fn set_seed(&self, seed: u64) -> Result<()> {
+ let mut curand = self.curand.lock().unwrap();
+ curand.0.set_seed(seed).w()?;
+ Ok(())
+ }
+
fn location(&self) -> crate::DeviceLocation {
crate::DeviceLocation::Cuda {
gpu_id: self.device.ordinal(),
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index 5cc9c6d8..53574458 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -167,6 +167,10 @@ impl crate::backend::BackendDevice for CudaDevice {
Err(Error::NotCompiledWithCudaSupport)
}
+ fn set_seed(&self, _: u64) -> Result<()> {
+ Err(Error::NotCompiledWithCudaSupport)
+ }
+
fn location(&self) -> crate::DeviceLocation {
fail!()
}