summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- Cargo.toml 1
-rw-r--r-- candle-core/Cargo.toml 1
-rw-r--r-- candle-core/src/cpu_backend.rs 30
-rw-r--r-- candle-core/tests/tensor_tests.rs 18
4 files changed, 40 insertions, 10 deletions
diff --git a/Cargo.toml b/Cargo.toml
index 4bc0058b..850b13ef 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -41,6 +41,7 @@ memmap2 = "0.7.1"
num_cpus = "1.15.0"
num-traits = "0.2.15"
rand = "0.8.5"
+rand_distr = "0.4.3"
safetensors = "0.3.1"
serde = { version = "1.0.171", features = ["derive"] }
serde_json = "1.0.99"
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index af77a0e0..7411592e 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -22,6 +22,7 @@ memmap2 = { workspace = true }
num-traits = { workspace = true }
num_cpus = { workspace = true }
rand = { workspace = true }
+rand_distr = { workspace = true }
safetensors = { workspace = true }
thiserror = { workspace = true }
zip = { workspace = true }
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 238a9a69..a59a959a 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -2070,35 +2070,45 @@ impl BackendDevice for CpuDevice {
DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()),
DType::BF16 => {
let mut data = Vec::with_capacity(elem_count);
- let std = bf16::from_f64(std);
- let mean = bf16::from_f64(mean);
+ let normal = match rand_distr::Normal::new(mean, std) {
+ Ok(n) => n,
+ Err(e) => Err(Error::wrap(e))?,
+ };
for _i in 0..elem_count {
- data.push(rng.sample::<bf16, _>(rand::distributions::Standard) * std + mean)
+ data.push(bf16::from_f64(normal.sample(&mut rng)))
}
Ok(CpuStorage::BF16(data))
}
DType::F16 => {
let mut data = Vec::with_capacity(elem_count);
- let std = f16::from_f64(std);
- let mean = f16::from_f64(mean);
+ let normal = match rand_distr::Normal::new(mean, std) {
+ Ok(n) => n,
+ Err(e) => Err(Error::wrap(e))?,
+ };
for _i in 0..elem_count {
- data.push(rng.sample::<f16, _>(rand::distributions::Standard) * std + mean)
+ data.push(f16::from_f64(normal.sample(&mut rng)))
}
Ok(CpuStorage::F16(data))
}
DType::F32 => {
let mut data = Vec::with_capacity(elem_count);
- let std = std as f32;
- let mean = mean as f32;
+ let normal = match rand_distr::Normal::new(mean, std) {
+ Ok(n) => n,
+ Err(e) => Err(Error::wrap(e))?,
+ };
for _i in 0..elem_count {
- data.push(rng.sample::<f32, _>(rand::distributions::Standard) * std + mean)
+ data.push(normal.sample(&mut rng) as f32)
}
Ok(CpuStorage::F32(data))
}
DType::F64 => {
let mut data = Vec::with_capacity(elem_count);
+ let normal = match rand_distr::Normal::new(mean, std) {
+ Ok(n) => n,
+ Err(e) => Err(Error::wrap(e))?,
+ };
for _i in 0..elem_count {
- data.push(rng.sample::<f64, _>(rand::distributions::Standard) * std + mean)
+ data.push(normal.sample(&mut rng))
}
Ok(CpuStorage::F64(data))
}
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 599c2665..aec86482 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -9,6 +9,23 @@ fn zeros(device: &Device) -> Result<()> {
Ok(())
}
+/// Checks that `Tensor::randn` produces at least one negative sample.
+///
+/// With mean 0 and std 1 each draw is negative with probability 1/2, so 200
+/// draws without a single negative value indicate a broken sampler rather
+/// than bad luck.
+fn randn_hasneg(device: &Device) -> Result<()> {
+    let elem_count = 200usize;
+    // Sample on the device handed to the test instead of a hard-coded
+    // `Device::Cpu`, so each registered variant exercises its own backend.
+    let samples = Tensor::randn(0f32, 1f32, elem_count, device)?.to_vec1::<f32>()?;
+    assert!(
+        samples.iter().any(|&v| v < 0.),
+        "randn failed to generate a negative number"
+    );
+    Ok(())
+}
+
fn add_mul(device: &Device) -> Result<()> {
let tensor = Tensor::new(&[3f32, 1., 4.], device)?;
let dim1 = tensor.dims1()?;
@@ -849,6 +866,7 @@ fn broadcasting(device: &Device) -> Result<()> {
}
test_device!(zeros, zeros_cpu, zeros_gpu);
+test_device!(randn_hasneg, randn_hasneg_cpu, randn_hasneg_gpu);
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
test_device!(narrow, narrow_cpu, narrow_gpu);