diff options
Diffstat (limited to 'candle-core/src/cpu_backend.rs')
-rw-r--r-- | candle-core/src/cpu_backend.rs | 55 |
1 files changed, 41 insertions, 14 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 8d38b158..83c7080f 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -369,8 +369,7 @@ pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>( block_start_index, block_len, } => { - let mut result = vec![]; - result.reserve(layout.shape().elem_count()); + let mut result = Vec::with_capacity(layout.shape().elem_count()); // Specialize the case where block_len is one to avoid the second loop. if block_len == 1 { for index in block_start_index { @@ -1843,12 +1842,27 @@ impl BackendDevice for CpuDevice { let elem_count = shape.elem_count(); let mut rng = rand::thread_rng(); match dtype { - DType::U8 | DType::U32 | DType::BF16 | DType::F16 => { - Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()) + DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt()), + DType::BF16 => { + let mut data = Vec::with_capacity(elem_count); + let uniform = + rand::distributions::Uniform::new(bf16::from_f64(min), bf16::from_f64(max)); + for _i in 0..elem_count { + data.push(rng.sample::<bf16, _>(uniform)) + } + Ok(CpuStorage::BF16(data)) + } + DType::F16 => { + let mut data = Vec::with_capacity(elem_count); + let uniform = + rand::distributions::Uniform::new(f16::from_f64(min), f16::from_f64(max)); + for _i in 0..elem_count { + data.push(rng.sample::<f16, _>(uniform)) + } + Ok(CpuStorage::F16(data)) } DType::F32 => { - let mut data = Vec::new(); - data.reserve(elem_count); + let mut data = Vec::with_capacity(elem_count); let uniform = rand::distributions::Uniform::new(min as f32, max as f32); for _i in 0..elem_count { data.push(rng.sample::<f32, _>(uniform)) @@ -1856,8 +1870,7 @@ impl BackendDevice for CpuDevice { Ok(CpuStorage::F32(data)) } DType::F64 => { - let mut data = Vec::new(); - data.reserve(elem_count); + let mut data = Vec::with_capacity(elem_count); let uniform = rand::distributions::Uniform::new(min, max); for _i in 0..elem_count { data.push(rng.sample::<f64, _>(uniform)) @@ -1873,12 +1886,27 @@ impl BackendDevice for CpuDevice { let elem_count = shape.elem_count(); let mut rng = rand::thread_rng(); match dtype { - DType::U8 | DType::U32 | DType::BF16 | DType::F16 => { - Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()) + DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()), + DType::BF16 => { + let mut data = Vec::with_capacity(elem_count); + let std = bf16::from_f64(std); + let mean = bf16::from_f64(mean); + for _i in 0..elem_count { + data.push(rng.sample::<bf16, _>(rand::distributions::Standard) * std + mean) + } + Ok(CpuStorage::BF16(data)) + } + DType::F16 => { + let mut data = Vec::with_capacity(elem_count); + let std = f16::from_f64(std); + let mean = f16::from_f64(mean); + for _i in 0..elem_count { + data.push(rng.sample::<f16, _>(rand::distributions::Standard) * std + mean) + } + Ok(CpuStorage::F16(data)) } DType::F32 => { - let mut data = Vec::new(); - data.reserve(elem_count); + let mut data = Vec::with_capacity(elem_count); let std = std as f32; let mean = mean as f32; for _i in 0..elem_count { @@ -1887,8 +1915,7 @@ impl BackendDevice for CpuDevice { Ok(CpuStorage::F32(data)) } DType::F64 => { - let mut data = Vec::new(); - data.reserve(elem_count); + let mut data = Vec::with_capacity(elem_count); for _i in 0..elem_count { data.push(rng.sample::<f64, _>(rand::distributions::Standard) * std + mean) } |