diff options
-rw-r--r-- | Cargo.toml | 1 | ||||
-rw-r--r-- | candle-core/Cargo.toml | 1 | ||||
-rw-r--r-- | candle-core/src/cpu_backend.rs | 30 | ||||
-rw-r--r-- | candle-core/tests/tensor_tests.rs | 18 |
4 files changed, 40 insertions, 10 deletions
@@ -41,6 +41,7 @@ memmap2 = "0.7.1"
 num_cpus = "1.15.0"
 num-traits = "0.2.15"
 rand = "0.8.5"
+rand_distr = "0.4.3"
 safetensors = "0.3.1"
 serde = { version = "1.0.171", features = ["derive"] }
 serde_json = "1.0.99"
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index af77a0e0..7411592e 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -22,6 +22,7 @@ memmap2 = { workspace = true }
 num-traits = { workspace = true }
 num_cpus = { workspace = true }
 rand = { workspace = true }
+rand_distr = { workspace = true }
 safetensors = { workspace = true }
 thiserror = { workspace = true }
 zip = { workspace = true }
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 238a9a69..a59a959a 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -2070,35 +2070,45 @@ impl BackendDevice for CpuDevice {
             DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()),
             DType::BF16 => {
                 let mut data = Vec::with_capacity(elem_count);
-                let std = bf16::from_f64(std);
-                let mean = bf16::from_f64(mean);
+                let normal = match rand_distr::Normal::new(mean, std) {
+                    Ok(n) => n,
+                    Err(e) => Err(Error::wrap(e))?,
+                };
                 for _i in 0..elem_count {
-                    data.push(rng.sample::<bf16, _>(rand::distributions::Standard) * std + mean)
+                    data.push(bf16::from_f64(normal.sample(&mut rng)))
                 }
                 Ok(CpuStorage::BF16(data))
             }
             DType::F16 => {
                 let mut data = Vec::with_capacity(elem_count);
-                let std = f16::from_f64(std);
-                let mean = f16::from_f64(mean);
+                let normal = match rand_distr::Normal::new(mean, std) {
+                    Ok(n) => n,
+                    Err(e) => Err(Error::wrap(e))?,
+                };
                 for _i in 0..elem_count {
-                    data.push(rng.sample::<f16, _>(rand::distributions::Standard) * std + mean)
+                    data.push(f16::from_f64(normal.sample(&mut rng)))
                 }
                 Ok(CpuStorage::F16(data))
             }
             DType::F32 => {
                 let mut data = Vec::with_capacity(elem_count);
-                let std = std as f32;
-                let mean = mean as f32;
+                let normal = match rand_distr::Normal::new(mean, std) {
+                    Ok(n) => n,
+                    Err(e) => Err(Error::wrap(e))?,
+                };
                 for _i in 0..elem_count {
-                    data.push(rng.sample::<f32, _>(rand::distributions::Standard) * std + mean)
+                    data.push(normal.sample(&mut rng) as f32)
                 }
                 Ok(CpuStorage::F32(data))
             }
             DType::F64 => {
                 let mut data = Vec::with_capacity(elem_count);
+                let normal = match rand_distr::Normal::new(mean, std) {
+                    Ok(n) => n,
+                    Err(e) => Err(Error::wrap(e))?,
+                };
                 for _i in 0..elem_count {
-                    data.push(rng.sample::<f64, _>(rand::distributions::Standard) * std + mean)
+                    data.push(normal.sample(&mut rng))
                 }
                 Ok(CpuStorage::F64(data))
             }
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 599c2665..aec86482 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -9,6 +9,23 @@ fn zeros(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn randn_hasneg(device: &Device) -> Result<()> {
+    let s = 200;
+    let t = Tensor::randn(
+        0f32,
+        1f32, s
+        as usize,
+        &Device::Cpu
+    )?
+    .to_vec1::<f32>()?;
+    for i in t {
+        if i < 0. {
+            return Ok(())
+        }
+    }
+    panic!("randn failed to generate a negative number")
+}
+
 fn add_mul(device: &Device) -> Result<()> {
     let tensor = Tensor::new(&[3f32, 1., 4.], device)?;
     let dim1 = tensor.dims1()?;
@@ -849,6 +866,7 @@ fn broadcasting(device: &Device) -> Result<()> {
 }
 
 test_device!(zeros, zeros_cpu, zeros_gpu);
+test_device!(randn_hasneg, randn_hasneg_cpu, randn_hasneg_gpu);
 test_device!(add_mul, add_mul_cpu, add_mul_gpu);
 test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
 test_device!(narrow, narrow_cpu, narrow_gpu);
|