summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Cargo.toml1
-rw-r--r--candle-core/Cargo.toml1
-rw-r--r--candle-core/src/cpu_backend.rs34
-rw-r--r--candle-core/src/cpu_kernels.rs34
-rw-r--r--candle-core/src/dtype.rs2
5 files changed, 64 insertions, 8 deletions
diff --git a/Cargo.toml b/Cargo.toml
index 5b17f336..915e6314 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -43,6 +43,7 @@ num_cpus = "1.15.0"
num-traits = "0.2.15"
rand = "0.8.5"
rand_distr = "0.4.3"
+rayon = "1.7.0"
safetensors = "0.3.1"
serde = { version = "1.0.171", features = ["derive"] }
serde_json = "1.0.99"
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index b5d74e12..36e018b9 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -24,6 +24,7 @@ num-traits = { workspace = true }
num_cpus = { workspace = true }
rand = { workspace = true }
rand_distr = { workspace = true }
+rayon = { workspace = true }
safetensors = { workspace = true }
thiserror = { workspace = true }
zip = { workspace = true }
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 6d129680..86f14e32 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1032,7 +1032,7 @@ impl<'a> Map2 for Conv1D<'a> {
let l_out = p.l_out();
let dst_elems = p.c_out * l_out * p.b_size;
// The output shape is [b_size, c_out, l_out]
- let mut dst = vec![T::zero(); dst_elems];
+ let dst = vec![T::zero(); dst_elems];
// TODO: Avoid making this copy if `inp` already has the appropriate layout.
let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
@@ -1045,8 +1045,10 @@ impl<'a> Map2 for Conv1D<'a> {
}
}
+ let num_threads = crate::utils::get_num_threads();
+
for offset in 0..p.k_size {
- for dst_c_idx in 0..p.c_out {
+ crate::cpu_kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
let dst_idx = dst_c_idx * l_out;
let k_cont = (0..p.c_in)
.map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2])
@@ -1063,10 +1065,17 @@ impl<'a> Map2 for Conv1D<'a> {
assert!(k_cont.len() >= p.c_in);
let mut d = T::zero();
unsafe { T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in) }
- dst[dst_idx] += d
+ let dst_p = dst.as_ptr();
+ // Safety: dst_idx are uniques per dst_c_idx which is used to parallelise
+ // the different tasks so no two threads can try to write at the same
+ // location.
+ unsafe {
+ let ptr = dst_p.add(dst_idx) as *mut T;
+ *ptr += d
+ }
}
}
- }
+ })
}
Ok(dst)
}
@@ -1085,7 +1094,7 @@ impl<'a> Map2 for Conv2D<'a> {
let (out_h, out_w) = (p.out_h(), p.out_w());
// Output shape: [b_size, c_out, out_h, out_w].
- let mut dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
+ let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
// TODO: Avoid making this copy if `inp` already has the appropriate layout.
let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
@@ -1105,9 +1114,11 @@ impl<'a> Map2 for Conv2D<'a> {
}
}
+ let num_threads = crate::utils::get_num_threads();
+
for offset_h in 0..p.k_h {
for offset_w in 0..p.k_w {
- for dst_c_idx in 0..p.c_out {
+ crate::cpu_kernels::par_range(0, p.c_out, num_threads, |dst_c_idx| {
let dst_idx = dst_c_idx * out_w * out_h;
let k_cont = (0..p.c_in)
.map(|c_in_idx| {
@@ -1137,11 +1148,18 @@ impl<'a> Map2 for Conv2D<'a> {
unsafe {
T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
}
- dst[dst_idx] += d
+ let dst_p = dst.as_ptr();
+ // Safety: dst_idx are uniques per dst_c_idx which is used to parallelise
+ // the different tasks so no two threads can try to write at the same
+ // location.
+ unsafe {
+ let ptr = dst_p.add(dst_idx) as *mut T;
+ *ptr += d
+ }
}
}
}
- }
+ });
}
}
diff --git a/candle-core/src/cpu_kernels.rs b/candle-core/src/cpu_kernels.rs
index 187dc16b..75509ba9 100644
--- a/candle-core/src/cpu_kernels.rs
+++ b/candle-core/src/cpu_kernels.rs
@@ -26,3 +26,37 @@ impl VecDot for half::bf16 {}
impl VecDot for half::f16 {}
impl VecDot for u8 {}
impl VecDot for u32 {}
+
+#[inline(always)]
+pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
+ if n_threads == 1 {
+ func(0)
+ } else {
+ rayon::scope(|s| {
+ for thread_idx in 0..n_threads {
+ let func = &func;
+ s.spawn(move |_| func(thread_idx));
+ }
+ })
+ }
+}
+
+#[inline(always)]
+pub fn par_range(lo: usize, up: usize, n_threads: usize, func: impl Fn(usize) + Send + Sync) {
+ if n_threads == 1 {
+ for i in lo..up {
+ func(i)
+ }
+ } else {
+ rayon::scope(|s| {
+ for thread_idx in 0..n_threads {
+ let func = &func;
+ s.spawn(move |_| {
+ for i in (thread_idx..up).step_by(n_threads) {
+ func(i)
+ }
+ });
+ }
+ })
+ }
+}
diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs
index 5d24b08f..94318b7f 100644
--- a/candle-core/src/dtype.rs
+++ b/candle-core/src/dtype.rs
@@ -60,6 +60,8 @@ pub trait WithDType:
+ std::cmp::PartialOrd
+ std::fmt::Display
+ 'static
+ + Send
+ + Sync
+ crate::cpu_kernels::VecDot
{
const DTYPE: DType;