summary | refs | log | tree | commit | diff
path: root/candle-core/src/cpu
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2023-09-05 16:22:27 +0200
committer: GitHub <noreply@github.com> 2023-09-05 15:22:27 +0100
commit: 6615daf2425bbf33f0e1f97d2a18534e6bdb9fc3 (patch)
tree: 9c61a62040ee8e2e384a97ce647d810a85a1f9b1 /candle-core/src/cpu
parent: 1c9e5394a5056aadc948f9330ea31fea4972e65e (diff)
download: candle-6615daf2425bbf33f0e1f97d2a18534e6bdb9fc3.tar.gz
candle-6615daf2425bbf33f0e1f97d2a18534e6bdb9fc3.tar.bz2
candle-6615daf2425bbf33f0e1f97d2a18534e6bdb9fc3.zip
Tweaks to softmax. (#745)
Diffstat (limited to 'candle-core/src/cpu')
-rw-r--r-- candle-core/src/cpu/kernels.rs 95
1 files changed, 81 insertions, 14 deletions
diff --git a/candle-core/src/cpu/kernels.rs b/candle-core/src/cpu/kernels.rs
index 97e195ef..527646d6 100644
--- a/candle-core/src/cpu/kernels.rs
+++ b/candle-core/src/cpu/kernels.rs
@@ -1,4 +1,7 @@
-pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy {
+pub trait VecOps: num_traits::NumAssign + Copy {
+ fn min(self, rhs: Self) -> Self;
+ fn max(self, rhs: Self) -> Self;
+
/// Dot-product of two vectors.
///
/// # Safety
@@ -37,10 +40,7 @@ pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy {
unsafe fn vec_reduce_max(xs: *const Self, res: *mut Self, len: usize) {
*res = *xs;
for i in 1..len {
- let x = *xs.add(i);
- if x > *res {
- *res = x
- }
+ *res = (*res).max(*xs.add(i))
}
}
@@ -54,16 +54,23 @@ pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy {
unsafe fn vec_reduce_min(xs: *const Self, res: *mut Self, len: usize) {
*res = *xs;
for i in 1..len {
- let x = *xs.add(i);
- if x < *res {
- *res = x
- }
+ *res = (*res).min(*xs.add(i))
}
}
}
impl VecOps for f32 {
#[inline(always)]
+ fn min(self, other: Self) -> Self {
+ Self::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ Self::max(self, other)
+ }
+
+ #[inline(always)]
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
super::vec_dot_f32(lhs, rhs, res, len)
}
@@ -76,6 +83,16 @@ impl VecOps for f32 {
impl VecOps for half::f16 {
#[inline(always)]
+ fn min(self, other: Self) -> Self {
+ Self::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ Self::max(self, other)
+ }
+
+ #[inline(always)]
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
let mut res_f32 = 0f32;
super::vec_dot_f16(lhs, rhs, &mut res_f32, len);
@@ -83,11 +100,61 @@ impl VecOps for half::f16 {
}
}
-impl VecOps for f64 {}
-impl VecOps for half::bf16 {}
-impl VecOps for u8 {}
-impl VecOps for u32 {}
-impl VecOps for i64 {}
+impl VecOps for f64 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+ Self::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ Self::max(self, other)
+ }
+}
+impl VecOps for half::bf16 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+ Self::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ Self::max(self, other)
+ }
+}
+impl VecOps for u8 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+ <Self as Ord>::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ <Self as Ord>::max(self, other)
+ }
+}
+impl VecOps for u32 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+ <Self as Ord>::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ <Self as Ord>::max(self, other)
+ }
+}
+impl VecOps for i64 {
+ #[inline(always)]
+ fn min(self, other: Self) -> Self {
+ <Self as Ord>::min(self, other)
+ }
+
+ #[inline(always)]
+ fn max(self, other: Self) -> Self {
+ <Self as Ord>::max(self, other)
+ }
+}
#[inline(always)]
pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {