summaryrefslogtreecommitdiff
path: root/candle-core
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core')
-rw-r--r--candle-core/src/backend.rs2
-rw-r--r--candle-core/src/cpu_backend.rs4
-rw-r--r--candle-core/src/cuda_backend.rs6
-rw-r--r--candle-core/src/dummy_cuda_backend.rs4
4 files changed, 16 insertions, 0 deletions
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index 03a07434..7f0e2fc7 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -111,4 +111,6 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
+
+ fn set_seed(&self, _: u64) -> Result<()>;
}
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 4e808b34..86cbeb78 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -2603,6 +2603,10 @@ impl BackendDevice for CpuDevice {
Ok(Self)
}
+ fn set_seed(&self, _seed: u64) -> Result<()> {
+ crate::bail!("cannot seed the CPU rng with set_seed")
+ }
+
fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result<CpuStorage> {
use rand::prelude::*;
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index f7518067..f0f48327 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -223,6 +223,12 @@ impl BackendDevice for CudaDevice {
})
}
+ fn set_seed(&self, seed: u64) -> Result<()> {
+ let mut curand = self.curand.lock().unwrap();
+ curand.0.set_seed(seed).w()?;
+ Ok(())
+ }
+
fn location(&self) -> crate::DeviceLocation {
crate::DeviceLocation::Cuda {
gpu_id: self.device.ordinal(),
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index 5cc9c6d8..53574458 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -167,6 +167,10 @@ impl crate::backend::BackendDevice for CudaDevice {
Err(Error::NotCompiledWithCudaSupport)
}
+ fn set_seed(&self, _: u64) -> Result<()> {
+ Err(Error::NotCompiledWithCudaSupport)
+ }
+
fn location(&self) -> crate::DeviceLocation {
fail!()
}