diff options
Diffstat (limited to 'candle-core/src/cpu_backend.rs')
-rw-r--r-- | candle-core/src/cpu_backend.rs | 78 |
1 files changed, 70 insertions, 8 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index c336dfef..91ccd972 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -148,6 +148,48 @@ fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut } } +fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>( + vs: &[T], + layout: &Layout, + mut f: F, + mut f_vec: FV, +) -> Vec<U> { + match layout.strided_blocks() { + crate::StridedBlocks::SingleBlock { start_offset, len } => { + let mut ys: Vec<U> = Vec::with_capacity(len); + let ys_to_set = ys.spare_capacity_mut(); + let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) }; + f_vec(&vs[start_offset..start_offset + len], ys_to_set); + // SAFETY: values are all set by f_vec. + unsafe { ys.set_len(len) }; + ys + } + crate::StridedBlocks::MultipleBlocks { + block_start_index, + block_len, + } => { + let mut result = vec![]; + result.reserve(layout.shape().elem_count()); + // Specialize the case where block_len is one to avoid the second loop. + if block_len == 1 { + for index in block_start_index { + let v = unsafe { vs.get_unchecked(index) }; + result.push(f(*v)) + } + } else { + // TODO: Use f_vec here. + for index in block_start_index { + for offset in 0..block_len { + let v = unsafe { vs.get_unchecked(index + offset) }; + result.push(f(*v)) + } + } + } + result + } + } +} + // This function maps over two strided index sequences. fn binary_map<T: Copy, F: FnMut(T, T) -> T>( lhs_l: &Layout, @@ -864,20 +906,40 @@ impl BackendStorage for CpuStorage { fn unary_impl<B: UnaryOp>(&self, layout: &Layout) -> Result<Self> { match self { Self::BF16(storage) => { - let data = unary_map(storage, layout, B::bf16); - Ok(Self::BF16(data)) + if B::BF16_VEC { + let data = unary_map_vec(storage, layout, B::bf16, B::bf16_vec); + Ok(Self::BF16(data)) + } else { + let data = unary_map(storage, layout, B::bf16); + Ok(Self::BF16(data)) + } } Self::F16(storage) => { - let data = unary_map(storage, layout, B::f16); - Ok(Self::F16(data)) + if B::F16_VEC { + let data = unary_map_vec(storage, layout, B::f16, B::f16_vec); + Ok(Self::F16(data)) + } else { + let data = unary_map(storage, layout, B::f16); + Ok(Self::F16(data)) + } } Self::F32(storage) => { - let data = unary_map(storage, layout, B::f32); - Ok(Self::F32(data)) + if B::F32_VEC { + let data = unary_map_vec(storage, layout, B::f32, B::f32_vec); + Ok(Self::F32(data)) + } else { + let data = unary_map(storage, layout, B::f32); + Ok(Self::F32(data)) + } } Self::F64(storage) => { - let data = unary_map(storage, layout, B::f64); - Ok(Self::F64(data)) + if B::F64_VEC { + let data = unary_map_vec(storage, layout, B::f64, B::f64_vec); + Ok(Self::F64(data)) + } else { + let data = unary_map(storage, layout, B::f64); + Ok(Self::F64(data)) + } } Self::U8(storage) => { let data = unary_map(storage, layout, B::u8); |