diff options
Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/cpu_backend.rs | 45 | ||||
-rw-r--r-- | candle-core/src/cpu_kernels.rs | 28 | ||||
-rw-r--r-- | candle-core/src/dtype.rs | 8 | ||||
-rw-r--r-- | candle-core/src/lib.rs | 1 |
4 files changed, 60 insertions, 22 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index d4f5fcdc..250e2721 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -1023,14 +1023,7 @@ struct Conv1D<'a>(&'a crate::conv::ParamsConv1D); impl<'a> Map2 for Conv1D<'a> { const OP: &'static str = "conv1d"; - fn f<T: 'static + num_traits::NumAssign + Copy>( - &self, - inp: &[T], - inp_l: &Layout, - k: &[T], - k_l: &Layout, - ) -> Result<Vec<T>> { - // TODO: Optimize this (proper algorithm, simd, multithread, remove bound checks, etc). + fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> { let p = self.0; let inp = &inp[inp_l.start_offset()..]; let k = &k[k_l.start_offset()..]; @@ -1040,25 +1033,35 @@ impl<'a> Map2 for Conv1D<'a> { let dst_elems = p.c_out * l_out * p.b_size; let mut dst = vec![T::zero(); dst_elems]; // The output shape is [b_size, c_out, l_out] + let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in]; for b_idx in 0..p.b_size { - let inp_idx = b_idx * inp_s0; - let dst_idx = b_idx * p.c_out * l_out; + for src_l in 0..p.l_in { + for src_c_idx in 0..p.c_in { + let inp_idx = b_idx * inp_s0 + src_c_idx * inp_s1 + src_l * inp_s2; + inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in + src_c_idx] = inp[inp_idx] + } + } + } + for offset in 0..p.k_size { for dst_c_idx in 0..p.c_out { - let dst_idx = dst_idx + dst_c_idx * l_out; - for dst_l in 0..l_out { - let dst_idx = dst_idx + dst_l; - let mut d = T::zero(); - for offset in 0..p.k_size { + let dst_idx = dst_c_idx * l_out; + let k_cont = (0..p.c_in) + .map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2]) + .collect::<Vec<_>>(); + for b_idx in 0..p.b_size { + let dst_idx = dst_idx + b_idx * p.c_out * l_out; + for dst_l in 0..l_out { + let dst_idx = dst_idx + dst_l; let src_l = (p.stride * dst_l + offset) .saturating_sub(p.padding) .min(p.l_in - 1); - for src_c_idx in 0..p.c_in { - let inp_idx = inp_idx + src_c_idx * inp_s1 + src_l * inp_s2; - let k_idx = dst_c_idx * k_s0 + src_c_idx * k_s1 + offset * k_s2; - d += inp[inp_idx] * k[k_idx] - } + let inp_cont = &inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in..]; + assert!(inp_cont.len() >= p.c_in); + assert!(k_cont.len() >= p.c_in); + let mut d = T::zero(); + unsafe { T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in) } + dst[dst_idx] += d } - dst[dst_idx] = d } } } diff --git a/candle-core/src/cpu_kernels.rs b/candle-core/src/cpu_kernels.rs new file mode 100644 index 00000000..187dc16b --- /dev/null +++ b/candle-core/src/cpu_kernels.rs @@ -0,0 +1,28 @@ +pub trait VecDot: num_traits::NumAssign + Copy { + /// Dot-product of two vectors. + /// + /// # Safety + /// + /// The length of `lhs` and `rhs` have to be at least `len`. `res` has to point to a valid + /// element. + #[inline(always)] + unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) { + *res = Self::zero(); + for i in 0..len { + *res += *lhs.add(i) * *rhs.add(i) + } + } +} + +impl VecDot for f32 { + #[inline(always)] + unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) { + ggblas::ggml::vec_dot_f32(lhs, rhs, res, len) + } +} + +impl VecDot for f64 {} +impl VecDot for half::bf16 {} +impl VecDot for half::f16 {} +impl VecDot for u8 {} +impl VecDot for u32 {} diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs index 92929748..5d24b08f 100644 --- a/candle-core/src/dtype.rs +++ b/candle-core/src/dtype.rs @@ -54,7 +54,13 @@ impl DType { } pub trait WithDType: - Sized + Copy + num_traits::NumAssign + std::cmp::PartialOrd + std::fmt::Display + 'static + Sized + + Copy + + num_traits::NumAssign + + std::cmp::PartialOrd + + std::fmt::Display + + 'static + + crate::cpu_kernels::VecDot { const DTYPE: DType; diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 016d3806..aba88135 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -40,6 +40,7 @@ pub mod backprop; mod conv; mod convert; pub mod cpu_backend; +pub mod cpu_kernels; #[cfg(feature = "cuda")] pub mod cuda_backend; mod device; |