diff options
-rw-r--r-- | Cargo.toml | 1 | ||||
-rw-r--r-- | candle-core/Cargo.toml | 1 | ||||
-rw-r--r-- | candle-core/src/cpu_backend.rs | 34 | ||||
-rw-r--r-- | candle-core/src/cpu_kernels.rs | 34 | ||||
-rw-r--r-- | candle-core/src/dtype.rs | 2 |
5 files changed, 64 insertions, 8 deletions
@@ -43,6 +43,7 @@ num_cpus = "1.15.0" num-traits = "0.2.15" rand = "0.8.5" rand_distr = "0.4.3" +rayon = "1.7.0" safetensors = "0.3.1" serde = { version = "1.0.171", features = ["derive"] } serde_json = "1.0.99" diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index b5d74e12..36e018b9 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -24,6 +24,7 @@ num-traits = { workspace = true } num_cpus = { workspace = true } rand = { workspace = true } rand_distr = { workspace = true } +rayon = { workspace = true } safetensors = { workspace = true } thiserror = { workspace = true } zip = { workspace = true } diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 6d129680..86f14e32 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -1032,7 +1032,7 @@ impl<'a> Map2 for Conv1D<'a> { let l_out = p.l_out(); let dst_elems = p.c_out * l_out * p.b_size; // The output shape is [b_size, c_out, l_out] - let mut dst = vec![T::zero(); dst_elems]; + let dst = vec![T::zero(); dst_elems]; // TODO: Avoid making this copy if `inp` already has the appropriate layout. let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in]; @@ -1045,8 +1045,10 @@ impl<'a> Map2 for Conv1D<'a> { } } + let num_threads = crate::utils::get_num_threads(); + for offset in 0..p.k_size { - for dst_c_idx in 0..p.c_out { + crate::cpu_kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| { let dst_idx = dst_c_idx * l_out; let k_cont = (0..p.c_in) .map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2]) @@ -1063,10 +1065,17 @@ impl<'a> Map2 for Conv1D<'a> { assert!(k_cont.len() >= p.c_in); let mut d = T::zero(); unsafe { T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in) } - dst[dst_idx] += d + let dst_p = dst.as_ptr(); + // Safety: dst_idx are uniques per dst_c_idx which is used to parallelise + // the different tasks so no two threads can try to write at the same + // location. 
+ unsafe { + let ptr = dst_p.add(dst_idx) as *mut T; + *ptr += d + } } } - } + }) } Ok(dst) } @@ -1085,7 +1094,7 @@ impl<'a> Map2 for Conv2D<'a> { let (out_h, out_w) = (p.out_h(), p.out_w()); // Output shape: [b_size, c_out, out_h, out_w]. - let mut dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w]; + let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w]; // TODO: Avoid making this copy if `inp` already has the appropriate layout. let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w]; @@ -1105,9 +1114,11 @@ impl<'a> Map2 for Conv2D<'a> { } } + let num_threads = crate::utils::get_num_threads(); + for offset_h in 0..p.k_h { for offset_w in 0..p.k_w { - for dst_c_idx in 0..p.c_out { + crate::cpu_kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| { let dst_idx = dst_c_idx * out_w * out_h; let k_cont = (0..p.c_in) .map(|c_in_idx| { @@ -1137,11 +1148,18 @@ impl<'a> Map2 for Conv2D<'a> { unsafe { T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in) } - dst[dst_idx] += d + let dst_p = dst.as_ptr(); + // Safety: dst_idx are uniques per dst_c_idx which is used to parallelise + // the different tasks so no two threads can try to write at the same + // location. 
+ unsafe { let ptr = dst_p.add(dst_idx) as *mut T; *ptr += d + } } } } - } + }); } }
diff --git a/candle-core/src/cpu_kernels.rs b/candle-core/src/cpu_kernels.rs
index 187dc16b..75509ba9 100644
--- a/candle-core/src/cpu_kernels.rs
+++ b/candle-core/src/cpu_kernels.rs
@@ -26,3 +26,37 @@ impl VecDot for half::bf16 {}
 impl VecDot for half::f16 {}
 impl VecDot for u8 {}
 impl VecDot for u32 {}
+
+/// Calls `func(thread_idx)` for each `thread_idx` in `0..n_threads`, one rayon task per index.
+/// When `n_threads` is 1, `func(0)` runs on the calling thread and rayon is not used.
+#[inline(always)]
+pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
+    if n_threads == 1 {
+        func(0)
+    } else {
+        rayon::scope(|s| {
+            for thread_idx in 0..n_threads {
+                let func = &func;
+                s.spawn(move |_| func(thread_idx));
+            }
+        })
+    }
+}
+
+/// Calls `func(i)` for every `i` in `lo..up`, striding the indices over `n_threads` tasks.
+#[inline(always)]
+pub fn par_range(lo: usize, up: usize, n_threads: usize, func: impl Fn(usize) + Send + Sync) {
+    if n_threads <= 1 {
+        // With zero threads no task would be spawned, so fall back to running inline.
+        for i in lo..up {
+            func(i)
+        }
+    } else {
+        par_for_each(n_threads, |thread_idx| {
+            // Start at `lo + thread_idx` so that indices below `lo` are never visited.
+            for i in (lo + thread_idx..up).step_by(n_threads) {
+                func(i)
+            }
+        })
+    }
+}
diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs index 5d24b08f..94318b7f 100644 --- a/candle-core/src/dtype.rs +++ b/candle-core/src/dtype.rs @@ -60,6 +60,8 @@ pub trait WithDType: + std::cmp::PartialOrd + std::fmt::Display + 'static + + Send + + Sync + crate::cpu_kernels::VecDot { const DTYPE: DType; |