Diffstat (limited to 'candle-core/src/cpu_backend.rs')
-rw-r--r-- | candle-core/src/cpu_backend.rs | 265
1 file changed, 251 insertions(+), 14 deletions(-)
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index ed3dd3fc..4e808b34 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -2,6 +2,10 @@ use crate::backend::{BackendDevice, BackendStorage};
 use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
 use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
 use half::{bf16, f16};
+use rayon::prelude::*;
 
+const USE_IM2COL_CONV1D: bool = true;
+const USE_IM2COL_CONV2D: bool = true;
+
 // TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator
 // intercept the oom errors to avoid panicking and provide a proper error.
@@ -445,7 +449,7 @@ pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U
 }
 
 // This function maps over two strided index sequences.
-fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
+pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
     lhs_l: &Layout,
     rhs_l: &Layout,
     lhs: &[T],
@@ -525,7 +529,7 @@ fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
 }
 
 // Similar to binary_map but with vectorized variants.
-fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
+pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
     lhs_l: &Layout,
     rhs_l: &Layout,
     lhs: &[T],
@@ -723,6 +727,36 @@ impl Map1 for MaxPool2D {
     }
 }
 
+struct UpsampleNearest1D(usize);
+
+impl Map1 for UpsampleNearest1D {
+    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
+        // TODO: Specialized implementation for the case 2*sz?
+        let dst_sz = self.0;
+        let (b_sz, c, src_sz) = layout.shape().dims3()?;
+        let stride = layout.stride();
+        let stride_sz = stride[2];
+        let src_index = layout.start_offset();
+        let scale_sz = src_sz as f64 / dst_sz as f64;
+        let mut dst = vec![T::zero(); b_sz * c * dst_sz];
+        let src_idxs = (0..dst_sz)
+            .map(|idx| usize::min(src_sz - 1, (idx as f64 * scale_sz) as usize))
+            .collect::<Vec<_>>();
+        for b_idx in 0..b_sz {
+            let dst = &mut dst[b_idx * c * dst_sz..];
+            let src_index = src_index + b_idx * stride[0];
+            for c_idx in 0..c {
+                let dst = &mut dst[c_idx * dst_sz..];
+                let src_index = src_index + c_idx * stride[1];
+                for (idx, src_idx) in src_idxs.iter().enumerate() {
+                    dst[idx] = src[src_index + src_idx * stride_sz]
+                }
+            }
+        }
+        Ok(dst)
+    }
+}
+
 struct UpsampleNearest2D(usize, usize);
 
 impl Map1 for UpsampleNearest2D {
@@ -1052,10 +1086,8 @@ impl<'a> Map2 for Conv1D<'a> {
             }
         }
 
-        let num_threads = crate::utils::get_num_threads();
-
         for offset in 0..p.k_size {
-            crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
+            (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
                 let dst_idx = dst_c_idx * l_out;
                 let k_cont = (0..p.c_in)
                     .map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2])
@@ -1090,6 +1122,140 @@ impl<'a> Map2 for Conv1D<'a> {
     }
 }
 
+struct Im2Col1D {
+    l_k: usize,
+    stride: usize,
+    dilation: usize,
+    padding: usize,
+}
+
+impl Im2Col1D {
+    fn l_out(&self, l: usize) -> usize {
+        (l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
+    }
+}
+
+impl Map1 for Im2Col1D {
+    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
+        let &Self {
+            l_k,
+            stride,
+            dilation,
+            padding,
+        } = self;
+        let (b, c, l) = layout.shape().dims3()?;
+        let l_out = self.l_out(l);
+        let src = &vs[layout.start_offset()..];
+        let mut dst = vec![T::zero(); b * l_out * c * l_k];
+        let (src_s0, src_s1, src_s2) = {
+            let s = layout.stride();
+            (s[0], s[1], s[2])
+        };
+        // TODO: provide specialized kernels for the common use cases.
+        // - l_k = 1
+        // - padding = 0
+        // - stride = 1
+        // - dilation = 1
+        for b_idx in 0..b {
+            let src_idx = b_idx * src_s0;
+            let dst_idx = b_idx * l_out * c * l_k;
+            for l_idx in 0..l_out {
+                let dst_idx = dst_idx + l_idx * c * l_k;
+                for c_idx in 0..c {
+                    let dst_idx = dst_idx + c_idx * l_k;
+                    let src_idx = c_idx * src_s1 + src_idx;
+                    for l_k_idx in 0..l_k {
+                        let src_l = l_idx * stride + l_k_idx * dilation;
+                        if padding != 0 && (src_l < padding || src_l >= l + padding) {
+                            continue;
+                        }
+                        let src_l = src_l - padding;
+                        let src_idx = src_idx + src_l * src_s2;
+                        let dst_idx = dst_idx + l_k_idx;
+                        dst[dst_idx] = src[src_idx]
+                    }
+                }
+            }
+        }
+        Ok(dst)
+    }
+}
+
+struct Im2Col {
+    h_k: usize,
+    w_k: usize,
+    stride: usize,
+    dilation: usize,
+    padding: usize,
+}
+
+impl Im2Col {
+    fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
+        let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
+        let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
+        (h_out, w_out)
+    }
+}
+
+impl Map1 for Im2Col {
+    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
+        let &Self {
+            h_k,
+            w_k,
+            stride,
+            dilation,
+            padding,
+        } = self;
+        let (b, c, h, w) = layout.shape().dims4()?;
+        let (h_out, w_out) = self.hw_out(h, w);
+        let src = &vs[layout.start_offset()..];
+        let mut dst = vec![T::zero(); b * h_out * w_out * c * h_k * w_k];
+        let (src_s0, src_s1, src_s2, src_s3) = {
+            let s = layout.stride();
+            (s[0], s[1], s[2], s[3])
+        };
+        // TODO: provide specialized kernels for the common use cases.
+        // - h_k = w_k = 1
+        // - padding = 0
+        // - stride = 1
+        // - dilation = 1
+        for b_idx in 0..b {
+            let src_idx = b_idx * src_s0;
+            let dst_idx = b_idx * h_out * w_out * c * h_k * w_k;
+            for h_idx in 0..h_out {
+                let dst_idx = dst_idx + h_idx * w_out * c * h_k * w_k;
+                for w_idx in 0..w_out {
+                    let dst_idx = dst_idx + w_idx * c * h_k * w_k;
+                    for c_idx in 0..c {
+                        let dst_idx = dst_idx + c_idx * h_k * w_k;
+                        let src_idx = c_idx * src_s1 + src_idx;
+                        for h_k_idx in 0..h_k {
+                            let src_h = h_idx * stride + h_k_idx * dilation;
+                            if padding != 0 && (src_h < padding || src_h >= h + padding) {
+                                continue;
+                            }
+                            let src_h = src_h - padding;
+                            let src_idx = src_idx + src_h * src_s2;
+                            let dst_idx = dst_idx + h_k_idx * w_k;
+                            for w_k_idx in 0..w_k {
+                                let src_w = w_idx * stride + w_k_idx * dilation;
+                                if padding != 0 && (src_w < padding || src_w >= w + padding) {
+                                    continue;
+                                }
+                                let src_w = src_w - padding;
+                                let src_idx = src_idx + src_w * src_s3;
+                                let dst_idx = dst_idx + w_k_idx;
+                                dst[dst_idx] = src[src_idx]
+                            }
+                        }
+                    }
+                }
+            }
+        }
+        Ok(dst)
+    }
+}
+
 struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
 
 impl<'a> Map2 for Conv2D<'a> {
@@ -1123,11 +1289,9 @@ impl<'a> Map2 for Conv2D<'a> {
             }
         }
 
-        let num_threads = crate::utils::get_num_threads();
-
         for offset_h in 0..p.k_h {
             for offset_w in 0..p.k_w {
-                crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
+                (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
                     let dst_idx = dst_c_idx * out_w * out_h;
                     let k_cont = (0..p.c_in)
                         .map(|c_in_idx| {
@@ -1216,11 +1380,10 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
                 }
             }
         }
-        let num_threads = crate::utils::get_num_threads();
 
        for k_y in 0..p.k_h {
             for k_x in 0..p.k_w {
-                crate::cpu::kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
+                (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
                     let k_cont = (0..p.c_in)
                         .map(|c_in_idx| {
                             k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_y * k_s2 + k_x * k_s3]
@@ -1298,8 +1461,9 @@ impl Map2 for MatMul {
     ) -> Result<Vec<T>> {
         use gemm::{gemm, Parallelism};
 
-        if T::DTYPE == DType::BF16 {
-            return Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?;
+        match T::DTYPE {
+            DType::F16 | DType::F32 | DType::F64 => {}
+            _ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?,
         }
 
         let (b, m, n, k) = self.0;
@@ -2003,6 +2167,10 @@ impl BackendStorage for CpuStorage {
         MaxPool2D(kernel_size, stride).map(self, layout)
     }
 
+    fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
+        UpsampleNearest1D(sz).map(self, layout)
+    }
+
     fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
         UpsampleNearest2D(h, w).map(self, layout)
     }
@@ -2231,7 +2399,40 @@ impl BackendStorage for CpuStorage {
         kernel_l: &Layout,
         params: &crate::conv::ParamsConv1D,
     ) -> Result<Self> {
-        Conv1D(params).map(self, l, kernel, kernel_l)
+        if !USE_IM2COL_CONV1D {
+            return Conv1D(params).map(self, l, kernel, kernel_l);
+        }
+        let op = Im2Col1D {
+            l_k: params.k_size,
+            padding: params.padding,
+            stride: params.stride,
+            dilation: params.dilation,
+        };
+        let col = op.map(self, l)?;
+        let b = params.b_size;
+        let n = params.c_out;
+        let l_out = params.l_out();
+        let k = op.l_k * params.c_in;
+        let m = l_out;
+        let col_l = Layout::contiguous((b, m, k));
+        let res = if kernel_l.is_contiguous() {
+            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+        } else {
+            // Make the kernel contiguous if not already the case.
+            let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+            kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
+            let kernel_l = Layout::contiguous_with_offset((1, n, k), 0)
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(&kernel_c, (b, m, n, k), &col_l, &kernel_l)?
+        };
+        let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
+        let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+        res.copy_strided_src(&mut res_t, 0, &res_l)?;
+        Ok(res_t)
     }
 
     fn conv2d(
@@ -2241,7 +2442,43 @@ impl BackendStorage for CpuStorage {
         kernel_l: &Layout,
         params: &crate::conv::ParamsConv2D,
     ) -> Result<Self> {
-        Conv2D(params).map(self, l, kernel, kernel_l)
+        if !USE_IM2COL_CONV2D {
+            return Conv2D(params).map(self, l, kernel, kernel_l);
+        }
+        let op = Im2Col {
+            h_k: params.k_h,
+            w_k: params.k_w,
+            padding: params.padding,
+            stride: params.stride,
+            dilation: params.dilation,
+        };
+        let col = op.map(self, l)?;
+        let b = params.b_size;
+        let n = params.c_out;
+        let (h_out, w_out) = (params.out_h(), params.out_w());
+        let k = op.h_k * op.w_k * params.c_in;
+        let m = h_out * w_out;
+        let col_l = Layout::contiguous((b, m, k));
+        let res = if kernel_l.is_contiguous() {
+            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
+        } else {
+            // Make the kernel contiguous if not already the case.
+            let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
+            kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
+            let kernel_l = Layout::contiguous_with_offset((1, n, k), 0)
+                .transpose(1, 2)?
+                .broadcast_as((b, k, n))?;
+            col.matmul(&kernel_c, (b, m, n, k), &col_l, &kernel_l)?
+        };
+        let res_l = Layout::contiguous((b, h_out, w_out, params.c_out))
+            .transpose(1, 2)?
+            .transpose(1, 3)?;
+        let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
+        res.copy_strided_src(&mut res_t, 0, &res_l)?;
+        Ok(res_t)
    }
 
     fn conv_transpose2d(
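
The conv1d and conv2d paths above stop looping over kernel offsets and instead unroll each receptive field into a row of a (b, l_out, c_in * l_k) buffer (im2col), then run a single batched matmul against the kernel viewed as a (c_in * l_k, c_out) matrix. Below is a minimal self-contained sketch of that equivalence in plain Rust, assuming a single batch and channel, unit dilation, and no padding; the helper names (l_out, conv1d_naive, conv1d_im2col) are illustrative, not candle's API.

    // All names here are illustrative, not part of candle.
    fn l_out(l: usize, l_k: usize, stride: usize, dilation: usize, padding: usize) -> usize {
        // Same output-length formula as Im2Col1D::l_out in the diff.
        (l + 2 * padding - dilation * (l_k - 1) - 1) / stride + 1
    }

    // Direct convolution: slide the kernel and accumulate.
    fn conv1d_naive(src: &[f32], k: &[f32], stride: usize) -> Vec<f32> {
        let n = l_out(src.len(), k.len(), stride, 1, 0);
        (0..n)
            .map(|i| k.iter().enumerate().map(|(j, kv)| kv * src[i * stride + j]).sum())
            .collect()
    }

    // im2col: copy each output position's window into a row, after which the
    // convolution is a plain (n x l_k) * (l_k,) matrix-vector product.
    fn conv1d_im2col(src: &[f32], k: &[f32], stride: usize) -> Vec<f32> {
        let (n, l_k) = (l_out(src.len(), k.len(), stride, 1, 0), k.len());
        let mut col = vec![0f32; n * l_k];
        for i in 0..n {
            for j in 0..l_k {
                col[i * l_k + j] = src[i * stride + j];
            }
        }
        (0..n)
            .map(|i| (0..l_k).map(|j| col[i * l_k + j] * k[j]).sum())
            .collect()
    }

    fn main() {
        let src: Vec<f32> = (0..10).map(|v| v as f32).collect();
        let k = [1f32, -2.0, 1.0];
        assert_eq!(conv1d_naive(&src, &k, 2), conv1d_im2col(&src, &k, 2));
    }

The payoff is that the inner reduction is handed to an optimized gemm instead of the per-offset accumulation loops the diff removes; the cost is the extra im2col buffer of size b * l_out * c_in * l_k.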
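
For the 2D path, the gemm dimensions passed to matmul follow directly from the im2col shapes: m = h_out * w_out, k = c_in * h_k * w_k, and n = c_out. A worked shape check with made-up sizes (all values illustrative):

    fn main() {
        // Hypothetical conv2d: 3x224x224 input, 64 output channels, 3x3 kernel,
        // stride 1, padding 1, dilation 1. All sizes are made up for the check.
        let (c_in, h, w) = (3usize, 224usize, 224usize);
        let (c_out, h_k, w_k) = (64usize, 3usize, 3usize);
        let (stride, padding, dilation) = (1usize, 1usize, 1usize);

        // Same formulas as Im2Col::hw_out in the diff.
        let h_out = (h + 2 * padding - dilation * (h_k - 1) - 1) / stride + 1;
        let w_out = (w + 2 * padding - dilation * (w_k - 1) - 1) / stride + 1;
        assert_eq!((h_out, w_out), (224, 224));

        // The (b, m, n, k) tuple handed to matmul in the conv2d hunk.
        let (m, n, k) = (h_out * w_out, c_out, c_in * h_k * w_k);
        assert_eq!((m, n, k), (50176, 64, 27));
    }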
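
The three direct convolution kernels also swap crate::cpu::kernels::par_range, which split 0..p.c_out over a fixed number of threads, for rayon's into_par_iter. A safe sketch of the same per-output-channel parallel split, assuming a rayon = "1" dependency; the kernels in the diff write into a shared output buffer rather than using par_chunks_mut, so this is an analogue of the pattern, not a copy of it:

    use rayon::prelude::*;

    fn main() {
        let (c_out, l_out) = (8usize, 4usize);
        let mut dst = vec![0f32; c_out * l_out];
        // One disjoint l_out-sized chunk of dst per output channel, filled in
        // parallel by rayon's work-stealing thread pool.
        dst.par_chunks_mut(l_out)
            .enumerate()
            .for_each(|(dst_c_idx, chunk)| {
                for v in chunk.iter_mut() {
                    *v = dst_c_idx as f32;
                }
            });
        assert_eq!(dst[l_out], 1.0);
    }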
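
UpsampleNearest1D maps destination index idx to source index floor(idx * src_sz / dst_sz), clamped to src_sz - 1, and precomputes that table once before the batch and channel loops. A small sketch of just the index computation:

    // Mirrors the src_idxs table built in UpsampleNearest1D::f.
    fn nearest_src_idxs(src_sz: usize, dst_sz: usize) -> Vec<usize> {
        let scale_sz = src_sz as f64 / dst_sz as f64;
        (0..dst_sz)
            .map(|idx| usize::min(src_sz - 1, (idx as f64 * scale_sz) as usize))
            .collect()
    }

    fn main() {
        // Upsampling 3 -> 7: each source element is read two or three times.
        assert_eq!(nearest_src_idxs(3, 7), vec![0, 0, 0, 1, 1, 2, 2]);
    }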