summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Cargo.toml1
-rw-r--r--candle-core/Cargo.toml1
-rw-r--r--candle-core/src/cpu_backend.rs45
-rw-r--r--candle-core/src/cpu_kernels.rs28
-rw-r--r--candle-core/src/dtype.rs8
-rw-r--r--candle-core/src/lib.rs1
6 files changed, 62 insertions, 22 deletions
diff --git a/Cargo.toml b/Cargo.toml
index 850b13ef..c0d87680 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -31,6 +31,7 @@ clap = { version = "4.2.4", features = ["derive"] }
cudarc = { version = "0.9.13", features = ["f16"] }
# TODO: Switch back to the official gemm implementation once it has caught up.
gemm = { version = "0.15.6", package = "candle-gemm" }
+ggblas = "0.1.2"
hf-hub = "0.2.0"
half = { version = "2.3.1", features = ["num-traits", "rand_distr"] }
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index 7411592e..bf57a91c 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -15,6 +15,7 @@ byteorder = { workspace = true }
candle-kernels = { path = "../candle-kernels", version = "0.1.0", optional = true }
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
+ggblas = { workspace = true }
half = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
libc = { workspace = true, optional = true }
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index d4f5fcdc..250e2721 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1023,14 +1023,7 @@ struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
impl<'a> Map2 for Conv1D<'a> {
const OP: &'static str = "conv1d";
- fn f<T: 'static + num_traits::NumAssign + Copy>(
- &self,
- inp: &[T],
- inp_l: &Layout,
- k: &[T],
- k_l: &Layout,
- ) -> Result<Vec<T>> {
- // TODO: Optimize this (proper algorithm, simd, multithread, remove bound checks, etc).
+ fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
let inp = &inp[inp_l.start_offset()..];
let k = &k[k_l.start_offset()..];
@@ -1040,25 +1033,35 @@ impl<'a> Map2 for Conv1D<'a> {
let dst_elems = p.c_out * l_out * p.b_size;
let mut dst = vec![T::zero(); dst_elems];
// The output shape is [b_size, c_out, l_out]
+ let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
for b_idx in 0..p.b_size {
- let inp_idx = b_idx * inp_s0;
- let dst_idx = b_idx * p.c_out * l_out;
+ for src_l in 0..p.l_in {
+ for src_c_idx in 0..p.c_in {
+ let inp_idx = b_idx * inp_s0 + src_c_idx * inp_s1 + src_l * inp_s2;
+ inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in + src_c_idx] = inp[inp_idx]
+ }
+ }
+ }
+ for offset in 0..p.k_size {
for dst_c_idx in 0..p.c_out {
- let dst_idx = dst_idx + dst_c_idx * l_out;
- for dst_l in 0..l_out {
- let dst_idx = dst_idx + dst_l;
- let mut d = T::zero();
- for offset in 0..p.k_size {
+ let dst_idx = dst_c_idx * l_out;
+ let k_cont = (0..p.c_in)
+ .map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2])
+ .collect::<Vec<_>>();
+ for b_idx in 0..p.b_size {
+ let dst_idx = dst_idx + b_idx * p.c_out * l_out;
+ for dst_l in 0..l_out {
+ let dst_idx = dst_idx + dst_l;
let src_l = (p.stride * dst_l + offset)
.saturating_sub(p.padding)
.min(p.l_in - 1);
- for src_c_idx in 0..p.c_in {
- let inp_idx = inp_idx + src_c_idx * inp_s1 + src_l * inp_s2;
- let k_idx = dst_c_idx * k_s0 + src_c_idx * k_s1 + offset * k_s2;
- d += inp[inp_idx] * k[k_idx]
- }
+ let inp_cont = &inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in..];
+ assert!(inp_cont.len() >= p.c_in);
+ assert!(k_cont.len() >= p.c_in);
+ let mut d = T::zero();
+ unsafe { T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in) }
+ dst[dst_idx] += d
}
- dst[dst_idx] = d
}
}
}
diff --git a/candle-core/src/cpu_kernels.rs b/candle-core/src/cpu_kernels.rs
new file mode 100644
index 00000000..187dc16b
--- /dev/null
+++ b/candle-core/src/cpu_kernels.rs
@@ -0,0 +1,28 @@
+pub trait VecDot: num_traits::NumAssign + Copy {
+ /// Dot-product of two vectors.
+ ///
+ /// # Safety
+ ///
+ /// The length of `lhs` and `rhs` have to be at least `len`. `res` has to point to a valid
+ /// element.
+ #[inline(always)]
+ unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
+ *res = Self::zero();
+ for i in 0..len {
+ *res += *lhs.add(i) * *rhs.add(i)
+ }
+ }
+}
+
+impl VecDot for f32 {
+ #[inline(always)]
+ unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
+ ggblas::ggml::vec_dot_f32(lhs, rhs, res, len)
+ }
+}
+
+impl VecDot for f64 {}
+impl VecDot for half::bf16 {}
+impl VecDot for half::f16 {}
+impl VecDot for u8 {}
+impl VecDot for u32 {}
diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs
index 92929748..5d24b08f 100644
--- a/candle-core/src/dtype.rs
+++ b/candle-core/src/dtype.rs
@@ -54,7 +54,13 @@ impl DType {
}
pub trait WithDType:
- Sized + Copy + num_traits::NumAssign + std::cmp::PartialOrd + std::fmt::Display + 'static
+ Sized
+ + Copy
+ + num_traits::NumAssign
+ + std::cmp::PartialOrd
+ + std::fmt::Display
+ + 'static
+ + crate::cpu_kernels::VecDot
{
const DTYPE: DType;
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 016d3806..aba88135 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -40,6 +40,7 @@ pub mod backprop;
mod conv;
mod convert;
pub mod cpu_backend;
+pub mod cpu_kernels;
#[cfg(feature = "cuda")]
pub mod cuda_backend;
mod device;