summary | refs | log | tree | commit | diff
path: root/candle-core/src/cpu_backend.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/cpu_backend.rs')
-rw-r--r-- candle-core/src/cpu_backend.rs | 55
1 files changed, 41 insertions, 14 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 8d38b158..83c7080f 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -369,8 +369,7 @@ pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
block_start_index,
block_len,
} => {
- let mut result = vec![];
- result.reserve(layout.shape().elem_count());
+ let mut result = Vec::with_capacity(layout.shape().elem_count());
// Specialize the case where block_len is one to avoid the second loop.
if block_len == 1 {
for index in block_start_index {
@@ -1843,12 +1842,27 @@ impl BackendDevice for CpuDevice {
let elem_count = shape.elem_count();
let mut rng = rand::thread_rng();
match dtype {
- DType::U8 | DType::U32 | DType::BF16 | DType::F16 => {
- Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())
+ DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt()),
+ DType::BF16 => {
+ let mut data = Vec::with_capacity(elem_count);
+ let uniform =
+ rand::distributions::Uniform::new(bf16::from_f64(min), bf16::from_f64(max));
+ for _i in 0..elem_count {
+ data.push(rng.sample::<bf16, _>(uniform))
+ }
+ Ok(CpuStorage::BF16(data))
+ }
+ DType::F16 => {
+ let mut data = Vec::with_capacity(elem_count);
+ let uniform =
+ rand::distributions::Uniform::new(f16::from_f64(min), f16::from_f64(max));
+ for _i in 0..elem_count {
+ data.push(rng.sample::<f16, _>(uniform))
+ }
+ Ok(CpuStorage::F16(data))
}
DType::F32 => {
- let mut data = Vec::new();
- data.reserve(elem_count);
+ let mut data = Vec::with_capacity(elem_count);
let uniform = rand::distributions::Uniform::new(min as f32, max as f32);
for _i in 0..elem_count {
data.push(rng.sample::<f32, _>(uniform))
@@ -1856,8 +1870,7 @@ impl BackendDevice for CpuDevice {
Ok(CpuStorage::F32(data))
}
DType::F64 => {
- let mut data = Vec::new();
- data.reserve(elem_count);
+ let mut data = Vec::with_capacity(elem_count);
let uniform = rand::distributions::Uniform::new(min, max);
for _i in 0..elem_count {
data.push(rng.sample::<f64, _>(uniform))
@@ -1873,12 +1886,27 @@ impl BackendDevice for CpuDevice {
let elem_count = shape.elem_count();
let mut rng = rand::thread_rng();
match dtype {
- DType::U8 | DType::U32 | DType::BF16 | DType::F16 => {
- Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())
+ DType::U8 | DType::U32 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()),
+ DType::BF16 => {
+ let mut data = Vec::with_capacity(elem_count);
+ let std = bf16::from_f64(std);
+ let mean = bf16::from_f64(mean);
+ for _i in 0..elem_count {
+ data.push(rng.sample::<bf16, _>(rand::distributions::Standard) * std + mean)
+ }
+ Ok(CpuStorage::BF16(data))
+ }
+ DType::F16 => {
+ let mut data = Vec::with_capacity(elem_count);
+ let std = f16::from_f64(std);
+ let mean = f16::from_f64(mean);
+ for _i in 0..elem_count {
+ data.push(rng.sample::<f16, _>(rand::distributions::Standard) * std + mean)
+ }
+ Ok(CpuStorage::F16(data))
}
DType::F32 => {
- let mut data = Vec::new();
- data.reserve(elem_count);
+ let mut data = Vec::with_capacity(elem_count);
let std = std as f32;
let mean = mean as f32;
for _i in 0..elem_count {
@@ -1887,8 +1915,7 @@ impl BackendDevice for CpuDevice {
Ok(CpuStorage::F32(data))
}
DType::F64 => {
- let mut data = Vec::new();
- data.reserve(elem_count);
+ let mut data = Vec::with_capacity(elem_count);
for _i in 0..elem_count {
data.push(rng.sample::<f64, _>(rand::distributions::Standard) * std + mean)
}