summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/cpu_backend.rs29
-rw-r--r--candle-core/tests/tensor_tests.rs29
2 files changed, 21 insertions, 37 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index a59a959a..d4f5fcdc 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -2070,43 +2070,34 @@ impl BackendDevice for CpuDevice {
DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()),
DType::BF16 => {
let mut data = Vec::with_capacity(elem_count);
- let normal = match rand_distr::Normal::new(mean, std) {
- Ok(n) => n,
- Err(e) => Err(Error::wrap(e))?,
- };
+ let normal = rand_distr::Normal::new(bf16::from_f64(mean), bf16::from_f64(std))
+ .map_err(Error::wrap)?;
for _i in 0..elem_count {
- data.push(bf16::from_f64(normal.sample(&mut rng)))
+ data.push(normal.sample(&mut rng))
}
Ok(CpuStorage::BF16(data))
}
DType::F16 => {
let mut data = Vec::with_capacity(elem_count);
- let normal = match rand_distr::Normal::new(mean, std) {
- Ok(n) => n,
- Err(e) => Err(Error::wrap(e))?,
- };
+ let normal = rand_distr::Normal::new(f16::from_f64(mean), f16::from_f64(std))
+ .map_err(Error::wrap)?;
for _i in 0..elem_count {
- data.push(f16::from_f64(normal.sample(&mut rng)))
+ data.push(normal.sample(&mut rng))
}
Ok(CpuStorage::F16(data))
}
DType::F32 => {
let mut data = Vec::with_capacity(elem_count);
- let normal = match rand_distr::Normal::new(mean, std) {
- Ok(n) => n,
- Err(e) => Err(Error::wrap(e))?,
- };
+ let normal =
+ rand_distr::Normal::new(mean as f32, std as f32).map_err(Error::wrap)?;
for _i in 0..elem_count {
- data.push(normal.sample(&mut rng) as f32)
+ data.push(normal.sample(&mut rng))
}
Ok(CpuStorage::F32(data))
}
DType::F64 => {
let mut data = Vec::with_capacity(elem_count);
- let normal = match rand_distr::Normal::new(mean, std) {
- Ok(n) => n,
- Err(e) => Err(Error::wrap(e))?,
- };
+ let normal = rand_distr::Normal::new(mean, std).map_err(Error::wrap)?;
for _i in 0..elem_count {
data.push(normal.sample(&mut rng))
}
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index aec86482..0b77f1a5 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -9,23 +9,6 @@ fn zeros(device: &Device) -> Result<()> {
Ok(())
}
-fn randn_hasneg(device: &Device) -> Result<()> {
- let s = 200;
- let t = Tensor::randn(
- 0f32,
- 1f32, s
- as usize,
- &Device::Cpu
- )?
- .to_vec1::<f32>()?;
- for i in t {
- if i < 0. {
- return Ok(())
- }
- }
- panic!("randn failed to generate a negative number")
-}
-
fn add_mul(device: &Device) -> Result<()> {
let tensor = Tensor::new(&[3f32, 1., 4.], device)?;
let dim1 = tensor.dims1()?;
@@ -866,7 +849,6 @@ fn broadcasting(device: &Device) -> Result<()> {
}
test_device!(zeros, zeros_cpu, zeros_gpu);
-test_device!(randn_hasneg, randn_hasneg_cpu, randn_hasneg_gpu);
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
test_device!(narrow, narrow_cpu, narrow_gpu);
@@ -887,3 +869,14 @@ test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(index_add, index_add_cpu, index_add_gpu);
test_device!(gather, gather_cpu, gather_gpu);
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
+
+// There was originally a bug on the CPU implementation for randn
+// https://github.com/huggingface/candle/issues/381
+#[test]
+fn randn_hasneg() -> Result<()> {
+ let t = Tensor::randn(0f32, 1f32, 200, &Device::Cpu)?.to_vec1::<f32>()?;
+ if t.iter().all(|&v| v >= 0.) {
+ candle_core::bail!("all values in tensors are non-negative")
+ }
+ Ok(())
+}