summaryrefslogtreecommitdiff
path: root/candle-core/src/cpu/kernels.rs
blob: 97e195efe9ee37d01d099ed481657b0c187a56d7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
/// Basic vector kernels operating on raw pointers.
///
/// Every method has a portable default implementation; an implementing type may override any
/// of them with a specialised version.
pub trait VecOps: num_traits::NumAssign + PartialOrd + Copy {
    /// Computes the dot-product of the first `len` elements of `lhs` and `rhs` and stores it
    /// in `res`. When `len` is zero, `res` is set to zero.
    ///
    /// # Safety
    ///
    /// `lhs` and `rhs` must each point to at least `len` readable elements, and `res` must
    /// point to a valid element.
    #[inline(always)]
    unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
        *res = Self::zero();
        let mut idx = 0;
        while idx < len {
            let product = lhs.add(idx).read() * rhs.add(idx).read();
            *res += product;
            idx += 1;
        }
    }

    /// Adds up the first `len` elements of `xs` and stores the total in `res`. When `len` is
    /// zero, `res` is set to zero.
    ///
    /// # Safety
    ///
    /// `xs` must point to at least `len` readable elements, and `res` must point to a valid
    /// element.
    #[inline(always)]
    unsafe fn vec_reduce_sum(xs: *const Self, res: *mut Self, len: usize) {
        *res = Self::zero();
        let mut idx = 0;
        while idx < len {
            *res += xs.add(idx).read();
            idx += 1;
        }
    }

    /// Stores the largest of the first `len` elements of `xs` in `res`.
    ///
    /// The first element seeds the result; a later element replaces it only when it compares
    /// strictly greater.
    ///
    /// # Safety
    ///
    /// `len` must be positive, `xs` must point to at least `len` readable elements, and `res`
    /// must point to a valid element.
    #[inline(always)]
    unsafe fn vec_reduce_max(xs: *const Self, res: *mut Self, len: usize) {
        *res = xs.read();
        let mut idx = 1;
        while idx < len {
            let candidate = xs.add(idx).read();
            if candidate > *res {
                *res = candidate;
            }
            idx += 1;
        }
    }

    /// Stores the smallest of the first `len` elements of `xs` in `res`.
    ///
    /// The first element seeds the result; a later element replaces it only when it compares
    /// strictly less.
    ///
    /// # Safety
    ///
    /// `len` must be positive, `xs` must point to at least `len` readable elements, and `res`
    /// must point to a valid element.
    #[inline(always)]
    unsafe fn vec_reduce_min(xs: *const Self, res: *mut Self, len: usize) {
        *res = xs.read();
        let mut idx = 1;
        while idx < len {
            let candidate = xs.add(idx).read();
            if candidate < *res {
                *res = candidate;
            }
            idx += 1;
        }
    }
}

/// `f32` overrides the dot-product and sum kernels by forwarding to helpers defined in the
/// parent module; `vec_reduce_max` and `vec_reduce_min` are not overridden here and therefore
/// use the trait defaults.
impl VecOps for f32 {
    /// Dot-product of two `f32` vectors, forwarded unchanged to `super::vec_dot_f32`.
    ///
    /// # Safety
    ///
    /// The caller must uphold the contract documented on `VecOps::vec_dot`. Any additional
    /// requirements of the helper are not visible from this block — confirm in the parent module.
    #[inline(always)]
    unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
        super::vec_dot_f32(lhs, rhs, res, len)
    }

    /// Sum of an `f32` vector, forwarded unchanged to `super::vec_sum`.
    ///
    /// # Safety
    ///
    /// The caller must uphold the contract documented on `VecOps::vec_reduce_sum`. Any
    /// additional requirements of the helper are not visible from this block — confirm in the
    /// parent module.
    #[inline(always)]
    unsafe fn vec_reduce_sum(xs: *const Self, res: *mut Self, len: usize) {
        super::vec_sum(xs, res, len)
    }
}

impl VecOps for half::f16 {
    /// Dot-product of two `f16` vectors.
    ///
    /// The result is produced by `super::vec_dot_f16` into an `f32` temporary and converted to
    /// `f16` only once, when it is stored in `res`.
    ///
    /// # Safety
    ///
    /// The caller must uphold the contract documented on `VecOps::vec_dot`.
    #[inline(always)]
    unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
        let mut acc: f32 = 0.0;
        super::vec_dot_f16(lhs, rhs, &mut acc, len);
        res.write(Self::from_f32(acc));
    }
}

// The element types below provide no overrides (their impl bodies are empty), so each of them
// uses the default implementation of every `VecOps` method.
impl VecOps for f64 {}
impl VecOps for half::bf16 {}
impl VecOps for u8 {}
impl VecOps for u32 {}
impl VecOps for i64 {}

/// Calls `func` once for every index in `0..n_threads`.
///
/// With exactly one thread the call `func(0)` happens inline on the current thread. Otherwise
/// each index is spawned as its own task inside a `rayon::scope`, which returns only after all
/// spawned tasks have completed.
#[inline(always)]
pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
    if n_threads == 1 {
        return func(0);
    }
    // Share the callback by reference so each spawned task can capture a copy of the borrow.
    let func = &func;
    rayon::scope(|scope| {
        (0..n_threads).for_each(|worker| scope.spawn(move |_| func(worker)));
    })
}

/// Calls `func(i)` exactly once for every `i` in `lo..up`, spreading the work over `n_threads`
/// tasks.
///
/// * `lo` - first index to visit (inclusive).
/// * `up` - end of the range (exclusive).
/// * `n_threads` - number of tasks to split the range across. With exactly one thread the range
///   is walked in order on the current thread. With zero threads no task is spawned and `func`
///   is never called.
/// * `func` - callback invoked with each index; it may run concurrently on several threads.
///
/// In the parallel case task `t` handles the indices `lo + t, lo + t + n_threads, ...` below
/// `up`, so the tasks together cover `lo..up` without overlap. The function returns only after
/// every spawned task has finished.
#[inline(always)]
pub fn par_range(lo: usize, up: usize, n_threads: usize, func: impl Fn(usize) + Send + Sync) {
    if n_threads == 1 {
        for i in lo..up {
            func(i)
        }
    } else {
        rayon::scope(|s| {
            for thread_idx in 0..n_threads {
                let func = &func;
                s.spawn(move |_| {
                    // Offset the starting index by `lo` so that indices below `lo` are never
                    // visited, matching the serial branch. A saturating add keeps the range
                    // empty instead of wrapping if `lo + thread_idx` would overflow.
                    let first = lo.saturating_add(thread_idx);
                    for i in (first..up).step_by(n_threads) {
                        func(i)
                    }
                });
            }
        })
    }
}