path: root/candle-core/src/cpu_backend.rs
author    laurent <laurent.mazare@gmail.com>    2023-06-28 21:38:01 +0100
committer laurent <laurent.mazare@gmail.com>    2023-06-28 21:38:01 +0100
commit    c583ee0f2cd62d1d820a57e248d5851c5f18145d (patch)
tree      5664cb8bbfdc50c6bedd85a3e23be6ded7ef487e /candle-core/src/cpu_backend.rs
parent    46c07b924c90dc1fb6ff2d432e6fe16c3da09d72 (diff)
Add map2.
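For orientation, the sketch below shows the shape of the pattern this commit introduces, reduced to a self-contained toy (the Storage enum, the ElemAdd implementor, and the simplified Copy + Add bound are illustrative stand-ins; the real trait works over CpuStorage, Layout, and WithDType + num_traits::Num as in the diff). Each binary operation only writes a kernel f that is generic over the element type, while the provided map method performs the dtype dispatch and mismatch error once for every implementor.

use std::ops::Add;

// Illustrative stand-in for CpuStorage: one variant per supported dtype.
enum Storage {
    F32(Vec<f32>),
    F64(Vec<f64>),
}

trait Map2 {
    const OP: &'static str;

    // Per-element-type kernel supplied by each operation.
    fn f<T: Copy + Add<Output = T>>(&self, v1: &[T], v2: &[T]) -> Vec<T>;

    // Provided method: match on the storage variants, run `f`, and report
    // a dtype mismatch for heterogeneous inputs.
    fn map(&self, v1: &Storage, v2: &Storage) -> Result<Storage, String> {
        match (v1, v2) {
            (Storage::F32(a), Storage::F32(b)) => Ok(Storage::F32(self.f(a, b))),
            (Storage::F64(a), Storage::F64(b)) => Ok(Storage::F64(self.f(a, b))),
            _ => Err(format!("dtype mismatch in {}", Self::OP)),
        }
    }
}

// A minimal implementor: element-wise addition.
struct ElemAdd;

impl Map2 for ElemAdd {
    const OP: &'static str = "add";
    fn f<T: Copy + Add<Output = T>>(&self, v1: &[T], v2: &[T]) -> Vec<T> {
        v1.iter().zip(v2.iter()).map(|(&a, &b)| a + b).collect()
    }
}

A call such as ElemAdd.map(&lhs, &rhs) then returns a Storage of the matching dtype or the mismatch error; WCond and MatMul in the diff below are driven through map in exactly this way.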
Diffstat (limited to 'candle-core/src/cpu_backend.rs')
-rw-r--r--  candle-core/src/cpu_backend.rs  269
1 file changed, 132 insertions, 137 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 7409a90a..7170e470 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -32,33 +32,66 @@ trait Map1 {
}
}
-fn wcond<T: Copy>(
- pred: &[u32],
- layout: &Layout,
- t: &[T],
- layout_t: &Layout,
- f: &[T],
- layout_f: &Layout,
-) -> Vec<T> {
- match (
- layout.contiguous_offsets(),
- layout_t.contiguous_offsets(),
- layout_f.contiguous_offsets(),
- ) {
- (Some((o1, o2)), Some((o_t1, o_t2)), Some((o_f1, o_f2))) => {
- let pred = &pred[o1..o2];
- let t = &t[o_t1..o_t2];
- let f = &f[o_f1..o_f2];
- pred.iter()
- .zip(t.iter().zip(f.iter()))
- .map(|(&p, (&t, &f))| if p > 0 { t } else { f })
- .collect::<Vec<_>>()
+type C = CpuStorage;
+trait Map2 {
+ const OP: &'static str;
+ fn f<T: WithDType + Copy + num_traits::Num + 'static>(
+ &self,
+ v1: &[T],
+ l1: &Layout,
+ v2: &[T],
+ l2: &Layout,
+ ) -> Result<Vec<T>>;
+
+ fn map(
+ &self,
+ v1: &CpuStorage,
+ l1: &Layout,
+ v2: &CpuStorage,
+ l2: &Layout,
+ ) -> Result<CpuStorage> {
+ match (v1, v2) {
+ (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
+ (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
+ (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
+ (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
+ (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
+ _ => Err(Error::DTypeMismatchBinaryOp {
+ lhs: v1.dtype(),
+ rhs: v2.dtype(),
+ op: Self::OP,
+ }),
}
- _ => layout
- .strided_index()
- .zip(layout_t.strided_index().zip(layout_f.strided_index()))
- .map(|(i_p, (i_t, i_f))| if pred[i_p] > 0 { t[i_t] } else { f[i_f] })
- .collect::<Vec<_>>(),
+ }
+}
+
+struct WCond<'a>(&'a [u32], &'a Layout);
+
+impl<'a> Map2 for WCond<'a> {
+ const OP: &'static str = "where";
+ fn f<T: WithDType>(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result<Vec<T>> {
+ let vs = match (
+ self.1.contiguous_offsets(),
+ t_l.contiguous_offsets(),
+ f_l.contiguous_offsets(),
+ ) {
+ (Some((o1, o2)), Some((o_t1, o_t2)), Some((o_f1, o_f2))) => {
+ let pred = &self.0[o1..o2];
+ let t = &t[o_t1..o_t2];
+ let f = &f[o_f1..o_f2];
+ pred.iter()
+ .zip(t.iter().zip(f.iter()))
+ .map(|(&p, (&t, &f))| if p > 0 { t } else { f })
+ .collect::<Vec<_>>()
+ }
+ _ => self
+ .1
+ .strided_index()
+ .zip(t_l.strided_index().zip(f_l.strided_index()))
+ .map(|(i_p, (i_t, i_f))| if self.0[i_p] > 0 { t[i_t] } else { f[i_f] })
+ .collect::<Vec<_>>(),
+ };
+ Ok(vs)
}
}
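As a small worked illustration of the semantics implemented above (the values are made up): with pred = [1, 0, 2], t = [10, 20, 30] and f = [-1, -2, -3], the contiguous fast path zips the three slices and yields [10, -2, 30], since a non-zero predicate entry selects from t and a zero entry selects from f. The strided_index fallback produces the same result when any of the three layouts is non-contiguous.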
@@ -184,73 +217,79 @@ fn copy_strided_src_<T: Copy + std::fmt::Display>(
}
}
-fn matmul<T: 'static + num_traits::Num + Copy>(
- lhs: &[T],
- rhs: &[T],
- (b, m, n, k): (usize, usize, usize, usize),
- lhs_l: &Layout,
- rhs_l: &Layout,
-) -> Result<Vec<T>> {
- let lhs = &lhs[lhs_l.start_offset()..];
- let rhs = &rhs[rhs_l.start_offset()..];
- let a_skip: usize = m * k;
- let b_skip: usize = n * k;
- let c_skip: usize = m * n;
-
- let lhs_stride = lhs_l.stride();
- let rhs_stride = rhs_l.stride();
- let rank = lhs_stride.len();
- let lhs_cs = lhs_stride[rank - 1];
- let lhs_rs = lhs_stride[rank - 2];
-
- let rhs_cs = rhs_stride[rank - 1];
- let rhs_rs = rhs_stride[rank - 2];
-
- if lhs_stride.len() > 2 {
- let lhs_batch_stride = &lhs_stride[..rank - 2];
- let rhs_batch_stride = &rhs_stride[..rank - 2];
-
- if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] {
- // Temporary error before we support arbitrary striding.
- return Err(Error::UnexpectedStriding);
+struct MatMul((usize, usize, usize, usize));
+
+impl Map2 for MatMul {
+ const OP: &'static str = "mat_mul";
+ fn f<T: 'static + num_traits::Num + Copy>(
+ &self,
+ lhs: &[T],
+ lhs_l: &Layout,
+ rhs: &[T],
+ rhs_l: &Layout,
+ ) -> Result<Vec<T>> {
+ let (b, m, n, k) = self.0;
+ let lhs = &lhs[lhs_l.start_offset()..];
+ let rhs = &rhs[rhs_l.start_offset()..];
+ let a_skip: usize = m * k;
+ let b_skip: usize = n * k;
+ let c_skip: usize = m * n;
+
+ let lhs_stride = lhs_l.stride();
+ let rhs_stride = rhs_l.stride();
+ let rank = lhs_stride.len();
+ let lhs_cs = lhs_stride[rank - 1];
+ let lhs_rs = lhs_stride[rank - 2];
+
+ let rhs_cs = rhs_stride[rank - 1];
+ let rhs_rs = rhs_stride[rank - 2];
+
+ if lhs_stride.len() > 2 {
+ let lhs_batch_stride = &lhs_stride[..rank - 2];
+ let rhs_batch_stride = &rhs_stride[..rank - 2];
+
+ if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] {
+ // Temporary error before we support arbitrary striding.
+ return Err(Error::UnexpectedStriding);
+ }
}
- }
- let dst_shape: Shape = (m, n).into();
- let dst_strides = dst_shape.stride_contiguous();
- let dst_rs = dst_strides[0];
- let dst_cs = dst_strides[1];
-
- let mut dst = vec![T::zero(); b * m * n];
- for step in 0..b {
- let lhs_p = &lhs[step * a_skip..];
- let rhs_p = &rhs[step * b_skip..];
- let dst_p = &mut dst[step * c_skip..];
- unsafe {
- gemm(
- /* m: usize = */ m,
- /* n: usize = */ n,
- /* k: usize = */ k,
- /* dst: *mut T = */ dst_p.as_mut_ptr(),
- /* dst_cs: isize = */ dst_cs as isize,
- /* dst_rs: isize = */ dst_rs as isize,
- /* read_dst: bool = */ false,
- /* lhs: *const T = */ lhs_p.as_ptr(),
- /* lhs_cs: isize = */ lhs_cs as isize,
- /* lhs_rs: isize = */ lhs_rs as isize,
- /* rhs: *const T = */ rhs_p.as_ptr(),
- /* rhs_cs: isize = */ rhs_cs as isize,
- /* rhs_rs: isize = */ rhs_rs as isize,
- /* alpha: T = */ T::zero(),
- /* beta: T = */ T::one(),
- /* conj_dst: bool = */ false,
- /* conj_lhs: bool = */ false,
- /* conj_rhs: bool = */ false,
- Parallelism::Rayon(crate::utils::get_num_threads()),
- )
+ let dst_shape: Shape = (m, n).into();
+ let dst_strides = dst_shape.stride_contiguous();
+ let dst_rs = dst_strides[0];
+ let dst_cs = dst_strides[1];
+
+ let mut dst = vec![T::zero(); b * m * n];
+ for step in 0..b {
+ let lhs_p = &lhs[step * a_skip..];
+ let rhs_p = &rhs[step * b_skip..];
+ let dst_p = &mut dst[step * c_skip..];
+ unsafe {
+ gemm(
+ /* m: usize = */ m,
+ /* n: usize = */ n,
+ /* k: usize = */ k,
+ /* dst: *mut T = */ dst_p.as_mut_ptr(),
+ /* dst_cs: isize = */ dst_cs as isize,
+ /* dst_rs: isize = */ dst_rs as isize,
+ /* read_dst: bool = */ false,
+ /* lhs: *const T = */ lhs_p.as_ptr(),
+ /* lhs_cs: isize = */ lhs_cs as isize,
+ /* lhs_rs: isize = */ lhs_rs as isize,
+ /* rhs: *const T = */ rhs_p.as_ptr(),
+ /* rhs_cs: isize = */ rhs_cs as isize,
+ /* rhs_rs: isize = */ rhs_rs as isize,
+ /* alpha: T = */ T::zero(),
+ /* beta: T = */ T::one(),
+ /* conj_dst: bool = */ false,
+ /* conj_lhs: bool = */ false,
+ /* conj_rhs: bool = */ false,
+ Parallelism::Rayon(crate::utils::get_num_threads()),
+ )
+ }
}
+ Ok(dst)
}
- Ok(dst)
}
impl CpuStorage {
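A brief note on the stride bookkeeping in the MatMul kernel above, with made-up sizes: for a contiguous (m, n) = (2, 3) destination, dst_shape.stride_contiguous() returns [3, 1], so gemm receives dst_rs = 3 (elements between consecutive rows) and dst_cs = 1 (elements between consecutive columns). The lhs/rhs row and column strides are read from the last two entries of each layout's stride in the same way, which is what allows transposed or otherwise row/column-strided operands, while batched inputs must for now keep the default contiguous batch stride (anything else returns Error::UnexpectedStriding).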
@@ -574,39 +613,13 @@ impl CpuStorage {
&self,
layout: &Layout,
t: &Self,
- layout_t: &Layout,
+ t_l: &Layout,
f: &Self,
- layout_f: &Layout,
+ f_l: &Layout,
) -> Result<Self> {
// TODO: Support types that could be cast to a boolean.
let pred = self.as_slice::<u32>()?;
- match (t, f) {
- (Self::BF16(t), Self::BF16(f)) => {
- let data = wcond(pred, layout, t, layout_t, f, layout_f);
- Ok(Self::BF16(data))
- }
- (Self::F16(t), Self::F16(f)) => {
- let data = wcond(pred, layout, t, layout_t, f, layout_f);
- Ok(Self::F16(data))
- }
- (Self::F32(t), Self::F32(f)) => {
- let data = wcond(pred, layout, t, layout_t, f, layout_f);
- Ok(Self::F32(data))
- }
- (Self::F64(t), Self::F64(f)) => {
- let data = wcond(pred, layout, t, layout_t, f, layout_f);
- Ok(Self::F64(data))
- }
- (Self::U32(t), Self::U32(f)) => {
- let data = wcond(pred, layout, t, layout_t, f, layout_f);
- Ok(Self::U32(data))
- }
- _ => Err(Error::DTypeMismatchBinaryOp {
- lhs: t.dtype(),
- rhs: f.dtype(),
- op: "where_cond",
- }),
- }
+ WCond(pred, layout).map(t, t_l, f, f_l)
}
pub(crate) fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
@@ -628,25 +641,7 @@ impl CpuStorage {
lhs_l: &Layout,
rhs_l: &Layout,
) -> Result<Self> {
- match (self, rhs) {
- (CpuStorage::F16(lhs), CpuStorage::F16(rhs)) => {
- let dst = matmul(lhs, rhs, bmnk, lhs_l, rhs_l)?;
- Ok(Self::F16(dst))
- }
- (CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => {
- let dst = matmul(lhs, rhs, bmnk, lhs_l, rhs_l)?;
- Ok(Self::F32(dst))
- }
- (CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
- let dst = matmul(lhs, rhs, bmnk, lhs_l, rhs_l)?;
- Ok(Self::F64(dst))
- }
- _ => Err(Error::DTypeMismatchBinaryOp {
- lhs: self.dtype(),
- rhs: rhs.dtype(),
- op: "matmul",
- }),
- }
+ MatMul(bmnk).map(self, lhs_l, rhs, rhs_l)
}
pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {