Diffstat (limited to 'candle-core/src/cpu_backend.rs')
-rw-r--r--  candle-core/src/cpu_backend.rs  78
1 file changed, 70 insertions(+), 8 deletions(-)
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index c336dfef..91ccd972 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -148,6 +148,48 @@ fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut
     }
 }
 
+fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
+    vs: &[T],
+    layout: &Layout,
+    mut f: F,
+    mut f_vec: FV,
+) -> Vec<U> {
+    match layout.strided_blocks() {
+        crate::StridedBlocks::SingleBlock { start_offset, len } => {
+            let mut ys: Vec<U> = Vec::with_capacity(len);
+            let ys_to_set = ys.spare_capacity_mut();
+            let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
+            f_vec(&vs[start_offset..start_offset + len], ys_to_set);
+            // SAFETY: values are all set by f_vec.
+            unsafe { ys.set_len(len) };
+            ys
+        }
+        crate::StridedBlocks::MultipleBlocks {
+            block_start_index,
+            block_len,
+        } => {
+            let mut result = vec![];
+            result.reserve(layout.shape().elem_count());
+            // Specialize the case where block_len is one to avoid the second loop.
+            if block_len == 1 {
+                for index in block_start_index {
+                    let v = unsafe { vs.get_unchecked(index) };
+                    result.push(f(*v))
+                }
+            } else {
+                // TODO: Use f_vec here.
+                for index in block_start_index {
+                    for offset in 0..block_len {
+                        let v = unsafe { vs.get_unchecked(index + offset) };
+                        result.push(f(*v))
+                    }
+                }
+            }
+            result
+        }
+    }
+}
+
 // This function maps over two strided index sequences.
 fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
     lhs_l: &Layout,
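
The SingleBlock branch of `unary_map_vec` above hands the uninitialized spare capacity of the output vector to the slice-level kernel `f_vec` and only afterwards marks the elements as initialized with `set_len`. The following is a minimal standalone sketch of that pattern, not part of the commit: `map_contiguous` and `sqr_vec` are hypothetical stand-ins for `unary_map_vec`'s fast path and for an op's `*_vec` kernel, assuming the kernel writes every output element.

use std::mem::MaybeUninit;

// Hypothetical slice-level kernel, in the shape expected for `f_vec`:
// read every input element and write every output element.
fn sqr_vec(src: &[f32], dst: &mut [f32]) {
    for (d, s) in dst.iter_mut().zip(src.iter()) {
        *d = s * s;
    }
}

// Sketch of the SingleBlock fast path: reserve the output, let the kernel
// fill the uninitialized tail, then declare the elements initialized.
fn map_contiguous(vs: &[f32], f_vec: impl FnOnce(&[f32], &mut [f32])) -> Vec<f32> {
    let len = vs.len();
    let mut ys: Vec<f32> = Vec::with_capacity(len);
    // The reserved-but-uninitialized part of `ys`, truncated to `len` elements.
    let spare = &mut ys.spare_capacity_mut()[..len];
    // Same trick as the commit: view the spare capacity as `&mut [f32]`.
    // It is only ever written to; `f_vec` must set every element.
    let dst = unsafe { std::mem::transmute::<&mut [MaybeUninit<f32>], &mut [f32]>(spare) };
    f_vec(vs, dst);
    // SAFETY: all `len` elements were written by `f_vec` above.
    unsafe { ys.set_len(len) };
    ys
}

fn main() {
    let xs = [1.0f32, 2.0, 3.0, 4.0];
    assert_eq!(map_contiguous(&xs, sqr_vec), vec![1.0, 4.0, 9.0, 16.0]);
}

The point of the pattern is to avoid zero-initializing the output buffer before the kernel overwrites it, so the kernel performs the only pass over the output memory.
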
@@ -864,20 +906,40 @@ impl BackendStorage for CpuStorage {
     fn unary_impl<B: UnaryOp>(&self, layout: &Layout) -> Result<Self> {
         match self {
             Self::BF16(storage) => {
-                let data = unary_map(storage, layout, B::bf16);
-                Ok(Self::BF16(data))
+                if B::BF16_VEC {
+                    let data = unary_map_vec(storage, layout, B::bf16, B::bf16_vec);
+                    Ok(Self::BF16(data))
+                } else {
+                    let data = unary_map(storage, layout, B::bf16);
+                    Ok(Self::BF16(data))
+                }
             }
             Self::F16(storage) => {
-                let data = unary_map(storage, layout, B::f16);
-                Ok(Self::F16(data))
+                if B::F16_VEC {
+                    let data = unary_map_vec(storage, layout, B::f16, B::f16_vec);
+                    Ok(Self::F16(data))
+                } else {
+                    let data = unary_map(storage, layout, B::f16);
+                    Ok(Self::F16(data))
+                }
             }
             Self::F32(storage) => {
-                let data = unary_map(storage, layout, B::f32);
-                Ok(Self::F32(data))
+                if B::F32_VEC {
+                    let data = unary_map_vec(storage, layout, B::f32, B::f32_vec);
+                    Ok(Self::F32(data))
+                } else {
+                    let data = unary_map(storage, layout, B::f32);
+                    Ok(Self::F32(data))
+                }
             }
             Self::F64(storage) => {
-                let data = unary_map(storage, layout, B::f64);
-                Ok(Self::F64(data))
+                if B::F64_VEC {
+                    let data = unary_map_vec(storage, layout, B::f64, B::f64_vec);
+                    Ok(Self::F64(data))
+                } else {
+                    let data = unary_map(storage, layout, B::f64);
+                    Ok(Self::F64(data))
+                }
             }
             Self::U8(storage) => {
                 let data = unary_map(storage, layout, B::u8);
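
The second hunk gates each float dtype on an associated const (`B::BF16_VEC`, `B::F16_VEC`, `B::F32_VEC`, `B::F64_VEC`) and falls back to the scalar `unary_map` when no vectorized kernel is opted in. The sketch below, restricted to f32 for brevity, shows the trait shape this dispatch assumes; it is illustrative only, `UnaryOpSketch` and `Sqr` are made-up names, and it is not the crate's actual `UnaryOp` definition.

// Illustrative trait shape (not candle's actual definition): a scalar op plus
// an opt-in slice-level variant, selected by a per-dtype const like the
// `B::F32_VEC` check above.
trait UnaryOpSketch {
    const F32_VEC: bool;
    fn f32(v: f32) -> f32;
    // Default vectorized form: fall back to the scalar op element by element.
    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
        for (y, x) in ys.iter_mut().zip(xs.iter()) {
            *y = Self::f32(*x);
        }
    }
}

struct Sqr;

impl UnaryOpSketch for Sqr {
    // Opting in to the vectorized path.
    const F32_VEC: bool = true;

    fn f32(v: f32) -> f32 {
        v * v
    }

    // A hand-written slice kernel; a real backend could call into a SIMD or
    // BLAS-style routine here instead of a plain loop.
    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
        for (y, x) in ys.iter_mut().zip(xs.iter()) {
            *y = x * x;
        }
    }
}

fn main() {
    let xs = [1.0f32, 2.0, 3.0];
    let mut ys = vec![0.0f32; xs.len()];
    // Mirrors the per-dtype dispatch in the hunk above.
    if Sqr::F32_VEC {
        Sqr::f32_vec(&xs, &mut ys);
    } else {
        for (y, x) in ys.iter_mut().zip(xs.iter()) {
            *y = Sqr::f32(*x);
        }
    }
    assert_eq!(ys, vec![1.0, 4.0, 9.0]);
}

An op that only defines the scalar method leaves the const as false and is handled exactly as before the change.
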