Diffstat (limited to 'src/cpu_backend.rs')
-rw-r--r--  src/cpu_backend.rs  109
1 file changed, 82 insertions, 27 deletions
diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs
index c71536ed..0eb4270a 100644
--- a/src/cpu_backend.rs
+++ b/src/cpu_backend.rs
@@ -1,6 +1,6 @@
use crate::storage::{BinaryOp, UnaryOp};
use crate::{DType, Error, Result, Shape, StridedIndex};
-use ggblas::batched_sgemm;
+use gemm::{gemm, Parallelism};

// TODO: Think about whether we would be better off with a dtype and
// a buffer as an owned slice of bytes.
@@ -113,28 +113,83 @@ impl CpuStorage {
lhs_stride: &[usize],
rhs_stride: &[usize],
) -> Result<Self> {
- println!("rhs {rhs:?}");
- println!("lhs_stride {lhs_stride:?}");
- println!("rhs_stride {rhs_stride:?}");
- // todo!("matmul");
let a_skip: usize = m * k;
let b_skip: usize = n * k;
let c_skip: usize = m * n;
- let mut c = Self::F32(vec![0.0; b * m * n]);
-
- batched_sgemm(
- self.as_slice()?,
- a_skip,
- rhs.as_slice()?,
- b_skip,
- c.as_mut_slice()?,
- c_skip,
- m,
- n,
- k,
- b,
- );
+ let rank = lhs_stride.len();
+ let lhs_cs = lhs_stride[rank - 1];
+ let lhs_rs = lhs_stride[rank - 2];
+
+ let rhs_cs = rhs_stride[rank - 1];
+ let rhs_rs = rhs_stride[rank - 2];
+
+ if lhs_stride.len() > 2 {
+ let lhs_batch_stride = &lhs_stride[..rank - 2];
+ let rhs_batch_stride = &rhs_stride[..rank - 2];
+
+ if lhs_batch_stride != &[a_skip] || rhs_batch_stride != &[b_skip] {
+            // Temporary error until we support arbitrary striding.
+ return Err(Error::UnexpectedStriding);
+ }
+ }
+
+ let mut dst = vec![0.0; b * m * n];
+
+ let dst_shape: Shape = (m, n).into();
+ let dst_strides = dst_shape.stride_contiguous();
+ let dst_rs = dst_strides[0];
+ let dst_cs = dst_strides[1];
+
+ for step in 0..b {
+ let lhs_p = &self.as_slice::<f32>()?[step * a_skip..];
+ let rhs_p = &rhs.as_slice::<f32>()?[step * b_skip..];
+ let dst_p = &mut dst[step * c_skip..];
+ unsafe {
+ gemm(
+ // m: usize,
+ m,
+ // n: usize,
+ n,
+ // k: usize,
+ k,
+ // dst: *mut T,
+ dst_p.as_mut_ptr(),
+ // dst_cs: isize,
+ dst_cs as isize,
+ // dst_rs: isize,
+ dst_rs as isize,
+ // read_dst: bool,
+ false,
+ // lhs: *const T,
+ lhs_p.as_ptr(),
+ // lhs_cs: isize,
+ lhs_cs as isize,
+ // lhs_rs: isize,
+ lhs_rs as isize,
+ // rhs: *const T,
+ rhs_p.as_ptr(),
+ // rhs_cs: isize,
+ rhs_cs as isize,
+ // rhs_rs: isize,
+ rhs_rs as isize,
+ // alpha: T,
+ 1.0,
+ // beta: T,
+ 1.0,
+ // conj_dst: bool,
+ false,
+ // conj_lhs: bool,
+ false,
+ // conj_rhs: bool,
+ true,
+ // parallelism: Parallelism
+ Parallelism::None,
+ )
+ }
+ }
+
+ let c = Self::F32(dst);
Ok(c)
}
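
Each iteration of the batch loop above is a plain single-matrix product expressed through explicit row/column strides. As a minimal standalone sketch (not part of this commit, and assuming the same gemm crate API used above), the snippet below maps a contiguous row-major 2x2 f32 product onto the same call; the expected output matches the first case of simple_matmul in the test hunk that follows.

use gemm::{gemm, Parallelism};

fn main() {
    let (m, n, k) = (2, 2, 2);
    // Row-major contiguous layout: row stride = number of columns, column stride = 1.
    let lhs = [1.0f32, 2.0, 3.0, 4.0]; // (m, k)
    let rhs = [1.0f32, 2.0, 3.0, 4.0]; // (k, n)
    let mut dst = [0.0f32; 4];         // (m, n)
    unsafe {
        gemm(
            m,
            n,
            k,
            dst.as_mut_ptr(),
            1,           // dst_cs
            n as isize,  // dst_rs
            false,       // read_dst: dst is overwritten rather than accumulated into
            lhs.as_ptr(),
            1,           // lhs_cs
            k as isize,  // lhs_rs
            rhs.as_ptr(),
            1,           // rhs_cs
            n as isize,  // rhs_rs
            1.0,         // alpha (only used when read_dst is true)
            1.0,         // beta
            false,
            false,
            false,       // conjugation is a no-op for real f32 data
            Parallelism::None,
        )
    }
    assert_eq!(dst, [7.0, 10.0, 15.0, 22.0]);
}
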
@@ -175,31 +230,31 @@ mod tests {
#[test]
fn simple_matmul() -> Result<()> {
let data = vec![1.0f32, 2.0, 3.0, 4.0];
- let a = Tensor::from_slice(&data, (2, 2), Device::Cpu)?;
+ let a = Tensor::from_slice(&data, (2, 2), &Device::Cpu)?;
let data = vec![1.0f32, 2.0, 3.0, 4.0];
- let b = Tensor::from_slice(&data, (2, 2), Device::Cpu)?;
+ let b = Tensor::from_slice(&data, (2, 2), &Device::Cpu)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[&[7.0f32, 10.0], &[15.0, 22.0]]);
let data = vec![1.0f32, 2.0];
- let a = Tensor::from_slice(&data, (2, 1), Device::Cpu)?;
+ let a = Tensor::from_slice(&data, (2, 1), &Device::Cpu)?;
let data = vec![3.0f32, 4.0];
- let b = Tensor::from_slice(&data, (1, 2), Device::Cpu)?;
+ let b = Tensor::from_slice(&data, (1, 2), &Device::Cpu)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);
let data: Vec<_> = (0..6).map(|i| i as f32).collect();
- let a = Tensor::from_slice(&data, (2, 3), Device::Cpu)?;
+ let a = Tensor::from_slice(&data, (2, 3), &Device::Cpu)?;
let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
- let b = Tensor::from_slice(&data, (3, 2), Device::Cpu)?;
+ let b = Tensor::from_slice(&data, (3, 2), &Device::Cpu)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);
let data: Vec<_> = (0..12).map(|i| i as f32).collect();
- let a = Tensor::from_slice(&data, (2, 2, 3), Device::Cpu)?;
+ let a = Tensor::from_slice(&data, (2, 2, 3), &Device::Cpu)?;
let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
- let b = Tensor::from_slice(&data, (2, 3, 2), Device::Cpu)?;
+ let b = Tensor::from_slice(&data, (2, 3, 2), &Device::Cpu)?;
let c = a.matmul(&b)?;
assert_eq!(
c.to_vec3::<f32>()?,