diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-27 07:40:36 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-27 07:40:36 +0100 |
commit | 6475bfadfebcd02dd58adc60452f492f8dc11a39 (patch) | |
tree | 60e7ba3b02b1ebd26854f4e6adfec76df0617f48 /candle-core/src/cpu_backend.rs | |
parent | 89ba005962495f2bfbda286e185e9c3c7f5300a3 (diff) | |
download | candle-6475bfadfebcd02dd58adc60452f492f8dc11a39.tar.gz candle-6475bfadfebcd02dd58adc60452f492f8dc11a39.tar.bz2 candle-6475bfadfebcd02dd58adc60452f492f8dc11a39.zip |
Simplify Tensor::randn. (#255)
* Simplify Tensor::randn.
* Also switch Tensor::rand to use a generic dtype.
* Support sampling for f16.
* Cleanup.
Diffstat (limited to 'candle-core/src/cpu_backend.rs')
-rw-r--r-- | candle-core/src/cpu_backend.rs | 55 |
1 files changed, 41 insertions, 14 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 8d38b158..83c7080f 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -369,8 +369,7 @@ pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>( block_start_index, block_len, } => { - let mut result = vec![]; - result.reserve(layout.shape().elem_count()); + let mut result = Vec::with_capacity(layout.shape().elem_count()); // Specialize the case where block_len is one to avoid the second loop. if block_len == 1 { for index in block_start_index { @@ -1843,12 +1842,27 @@ impl BackendDevice for CpuDevice { let elem_count = shape.elem_count(); let mut rng = rand::thread_rng(); match dtype { - DType::U8 | DType::U32 | DType::BF16 | DType::F16 => { - Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()) + DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt()), + DType::BF16 => { + let mut data = Vec::with_capacity(elem_count); + let uniform = + rand::distributions::Uniform::new(bf16::from_f64(min), bf16::from_f64(max)); + for _i in 0..elem_count { + data.push(rng.sample::<bf16, _>(uniform)) + } + Ok(CpuStorage::BF16(data)) + } + DType::F16 => { + let mut data = Vec::with_capacity(elem_count); + let uniform = + rand::distributions::Uniform::new(f16::from_f64(min), f16::from_f64(max)); + for _i in 0..elem_count { + data.push(rng.sample::<f16, _>(uniform)) + } + Ok(CpuStorage::F16(data)) } DType::F32 => { - let mut data = Vec::new(); - data.reserve(elem_count); + let mut data = Vec::with_capacity(elem_count); let uniform = rand::distributions::Uniform::new(min as f32, max as f32); for _i in 0..elem_count { data.push(rng.sample::<f32, _>(uniform)) @@ -1856,8 +1870,7 @@ impl BackendDevice for CpuDevice { Ok(CpuStorage::F32(data)) } DType::F64 => { - let mut data = Vec::new(); - data.reserve(elem_count); + let mut data = Vec::with_capacity(elem_count); let uniform = rand::distributions::Uniform::new(min, max); for _i in 0..elem_count { data.push(rng.sample::<f64, _>(uniform)) @@ -1873,12 +1886,27 @@ impl BackendDevice for CpuDevice { let elem_count = shape.elem_count(); let mut rng = rand::thread_rng(); match dtype { - DType::U8 | DType::U32 | DType::BF16 | DType::F16 => { - Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()) + DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()), + DType::BF16 => { + let mut data = Vec::with_capacity(elem_count); + let std = bf16::from_f64(std); + let mean = bf16::from_f64(mean); + for _i in 0..elem_count { + data.push(rng.sample::<bf16, _>(rand::distributions::Standard) * std + mean) + } + Ok(CpuStorage::BF16(data)) + } + DType::F16 => { + let mut data = Vec::with_capacity(elem_count); + let std = f16::from_f64(std); + let mean = f16::from_f64(mean); + for _i in 0..elem_count { + data.push(rng.sample::<f16, _>(rand::distributions::Standard) * std + mean) + } + Ok(CpuStorage::F16(data)) } DType::F32 => { - let mut data = Vec::new(); - data.reserve(elem_count); + let mut data = Vec::with_capacity(elem_count); let std = std as f32; let mean = mean as f32; for _i in 0..elem_count { @@ -1887,8 +1915,7 @@ impl BackendDevice for CpuDevice { Ok(CpuStorage::F32(data)) } DType::F64 => { - let mut data = Vec::new(); - data.reserve(elem_count); + let mut data = Vec::with_capacity(elem_count); for _i in 0..elem_count { data.push(rng.sample::<f64, _>(rand::distributions::Standard) * std + mean) } |