diff options
-rw-r--r-- | candle-core/src/cpu/kernels.rs | 95 | ||||
-rw-r--r-- | candle-nn/src/ops.rs | 8 |
2 files changed, 84 insertions, 19 deletions
diff --git a/candle-core/src/cpu/kernels.rs b/candle-core/src/cpu/kernels.rs
index 97e195ef..527646d6 100644
--- a/candle-core/src/cpu/kernels.rs
+++ b/candle-core/src/cpu/kernels.rs
@@ -1,4 +1,7 @@
-pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy {
+pub trait VecOps: num_traits::NumAssign + Copy {
+    fn min(self, rhs: Self) -> Self;
+    fn max(self, rhs: Self) -> Self;
+
     /// Dot-product of two vectors.
     ///
     /// # Safety
@@ -37,10 +40,7 @@ pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy {
     unsafe fn vec_reduce_max(xs: *const Self, res: *mut Self, len: usize) {
         *res = *xs;
         for i in 1..len {
-            let x = *xs.add(i);
-            if x > *res {
-                *res = x
-            }
+            *res = (*res).max(*xs.add(i))
         }
     }
 
@@ -54,16 +54,23 @@ pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy {
     unsafe fn vec_reduce_min(xs: *const Self, res: *mut Self, len: usize) {
         *res = *xs;
         for i in 1..len {
-            let x = *xs.add(i);
-            if x < *res {
-                *res = x
-            }
+            *res = (*res).min(*xs.add(i))
         }
     }
 }
 
 impl VecOps for f32 {
     #[inline(always)]
+    fn min(self, other: Self) -> Self {
+        Self::min(self, other)
+    }
+
+    #[inline(always)]
+    fn max(self, other: Self) -> Self {
+        Self::max(self, other)
+    }
+
+    #[inline(always)]
     unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
         super::vec_dot_f32(lhs, rhs, res, len)
     }
@@ -76,6 +83,16 @@ impl VecOps for f32 {
 
 impl VecOps for half::f16 {
     #[inline(always)]
+    fn min(self, other: Self) -> Self {
+        Self::min(self, other)
+    }
+
+    #[inline(always)]
+    fn max(self, other: Self) -> Self {
+        Self::max(self, other)
+    }
+
+    #[inline(always)]
     unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
         let mut res_f32 = 0f32;
         super::vec_dot_f16(lhs, rhs, &mut res_f32, len);
@@ -83,11 +100,61 @@ impl VecOps for half::f16 {
     }
 }
 
-impl VecOps for f64 {}
-impl VecOps for half::bf16 {}
-impl VecOps for u8 {}
-impl VecOps for u32 {}
-impl VecOps for i64 {}
+impl VecOps for f64 {
+    #[inline(always)]
+    fn min(self, other: Self) -> Self {
+        Self::min(self, other)
+    }
+
+    #[inline(always)]
+    fn max(self, other: Self) -> Self {
+        Self::max(self, other)
+    }
+}
+impl VecOps for half::bf16 {
+    #[inline(always)]
+    fn min(self, other: Self) -> Self {
+        Self::min(self, other)
+    }
+
+    #[inline(always)]
+    fn max(self, other: Self) -> Self {
+        Self::max(self, other)
+    }
+}
+impl VecOps for u8 {
+    #[inline(always)]
+    fn min(self, other: Self) -> Self {
+        <Self as Ord>::min(self, other)
+    }
+
+    #[inline(always)]
+    fn max(self, other: Self) -> Self {
+        <Self as Ord>::max(self, other)
+    }
+}
+impl VecOps for u32 {
+    #[inline(always)]
+    fn min(self, other: Self) -> Self {
+        <Self as Ord>::min(self, other)
+    }
+
+    #[inline(always)]
+    fn max(self, other: Self) -> Self {
+        <Self as Ord>::max(self, other)
+    }
+}
+impl VecOps for i64 {
+    #[inline(always)]
+    fn min(self, other: Self) -> Self {
+        <Self as Ord>::min(self, other)
+    }
+
+    #[inline(always)]
+    fn max(self, other: Self) -> Self {
+        <Self as Ord>::max(self, other)
+    }
+}
 
 #[inline(always)]
 pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index 55da46f8..73214077 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -103,14 +103,12 @@ impl candle::CustomOp1 for SoftmaxLastDim {
                 .zip(dst.par_chunks_mut(dim_m1))
                 .for_each(|(src, dst)| {
                     let mut max = T::neg_infinity();
-                    for &s in src.iter() {
-                        max = T::max(s, max)
-                    }
-                    let mut sum_exp = T::zero();
+                    unsafe { T::vec_reduce_max(src.as_ptr(), &mut max, dim_m1) };
                     for (s, d) in src.iter().zip(dst.iter_mut()) {
                         *d = (*s - max).exp();
-                        sum_exp += *d
                     }
+                    let mut sum_exp = T::zero();
+                    unsafe { T::vec_reduce_sum(dst.as_ptr(), &mut sum_exp, dim_m1) };
                     for d in dst.iter_mut() {
                         *d /= sum_exp
                     }