diff options
Diffstat (limited to 'src/cpu_backend.rs')
-rw-r--r-- | src/cpu_backend.rs | 109 |
1 files changed, 82 insertions, 27 deletions
diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs index c71536ed..0eb4270a 100644 --- a/src/cpu_backend.rs +++ b/src/cpu_backend.rs @@ -1,6 +1,6 @@ use crate::storage::{BinaryOp, UnaryOp}; use crate::{DType, Error, Result, Shape, StridedIndex}; -use ggblas::batched_sgemm; +use gemm::{gemm, Parallelism}; // TODO: Think about whether we would be better off with a dtype and // a buffer as an owned slice of bytes. @@ -113,28 +113,83 @@ impl CpuStorage { lhs_stride: &[usize], rhs_stride: &[usize], ) -> Result<Self> { - println!("rhs {rhs:?}"); - println!("lhs_stride {lhs_stride:?}"); - println!("rhs_stride {rhs_stride:?}"); - // todo!("matmul"); let a_skip: usize = m * k; let b_skip: usize = n * k; let c_skip: usize = m * n; - let mut c = Self::F32(vec![0.0; b * m * n]); - - batched_sgemm( - self.as_slice()?, - a_skip, - rhs.as_slice()?, - b_skip, - c.as_mut_slice()?, - c_skip, - m, - n, - k, - b, - ); + let rank = lhs_stride.len(); + let lhs_cs = lhs_stride[rank - 1]; + let lhs_rs = lhs_stride[rank - 2]; + + let rhs_cs = rhs_stride[rank - 1]; + let rhs_rs = rhs_stride[rank - 2]; + + if lhs_stride.len() > 2 { + let lhs_batch_stride = &lhs_stride[..rank - 2]; + let rhs_batch_stride = &rhs_stride[..rank - 2]; + + if lhs_batch_stride != &[a_skip] || rhs_batch_stride != &[b_skip] { + // Temporary error before we support abitrary striding. + return Err(Error::UnexpectedStriding); + } + } + + let mut dst = vec![0.0; b * m * n]; + + let dst_shape: Shape = (m, n).into(); + let dst_strides = dst_shape.stride_contiguous(); + let dst_rs = dst_strides[0]; + let dst_cs = dst_strides[1]; + + for step in 0..b { + let lhs_p = &self.as_slice::<f32>()?[step * a_skip..]; + let rhs_p = &rhs.as_slice::<f32>()?[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + gemm( + // m: usize, + m, + // n: usize, + n, + // k: usize, + k, + // dst: *mut T, + dst_p.as_mut_ptr(), + // dst_cs: isize, + dst_cs as isize, + // dst_rs: isize, + dst_rs as isize, + // read_dst: bool, + false, + // lhs: *const T, + lhs_p.as_ptr(), + // lhs_cs: isize, + lhs_cs as isize, + // lhs_rs: isize, + lhs_rs as isize, + // rhs: *const T, + rhs_p.as_ptr(), + // rhs_cs: isize, + rhs_cs as isize, + // rhs_rs: isize, + rhs_rs as isize, + // alpha: T, + 1.0, + // beta: T, + 1.0, + // conj_dst: bool, + false, + // conj_lhs: bool, + false, + // conj_rhs: bool, + true, + // parallelism: Parallelism + Parallelism::None, + ) + } + } + + let c = Self::F32(dst); Ok(c) } @@ -175,31 +230,31 @@ mod tests { #[test] fn simple_matmul() -> Result<()> { let data = vec![1.0f32, 2.0, 3.0, 4.0]; - let a = Tensor::from_slice(&data, (2, 2), Device::Cpu)?; + let a = Tensor::from_slice(&data, (2, 2), &Device::Cpu)?; let data = vec![1.0f32, 2.0, 3.0, 4.0]; - let b = Tensor::from_slice(&data, (2, 2), Device::Cpu)?; + let b = Tensor::from_slice(&data, (2, 2), &Device::Cpu)?; let c = a.matmul(&b)?; assert_eq!(c.to_vec2::<f32>()?, &[&[7.0f32, 10.0], &[15.0, 22.0]]); let data = vec![1.0f32, 2.0]; - let a = Tensor::from_slice(&data, (2, 1), Device::Cpu)?; + let a = Tensor::from_slice(&data, (2, 1), &Device::Cpu)?; let data = vec![3.0f32, 4.0]; - let b = Tensor::from_slice(&data, (1, 2), Device::Cpu)?; + let b = Tensor::from_slice(&data, (1, 2), &Device::Cpu)?; let c = a.matmul(&b)?; assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]); let data: Vec<_> = (0..6).map(|i| i as f32).collect(); - let a = Tensor::from_slice(&data, (2, 3), Device::Cpu)?; + let a = Tensor::from_slice(&data, (2, 3), &Device::Cpu)?; let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect(); - let b = Tensor::from_slice(&data, (3, 2), Device::Cpu)?; + let b = Tensor::from_slice(&data, (3, 2), &Device::Cpu)?; let c = a.matmul(&b)?; assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]); let data: Vec<_> = (0..12).map(|i| i as f32).collect(); - let a = Tensor::from_slice(&data, (2, 2, 3), Device::Cpu)?; + let a = Tensor::from_slice(&data, (2, 2, 3), &Device::Cpu)?; let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect(); - let b = Tensor::from_slice(&data, (2, 3, 2), Device::Cpu)?; + let b = Tensor::from_slice(&data, (2, 3, 2), &Device::Cpu)?; let c = a.matmul(&b)?; assert_eq!( c.to_vec3::<f32>()?, |