diff options
Diffstat (limited to 'src/cpu_backend.rs')
-rw-r--r-- | src/cpu_backend.rs | 83 |
1 files changed, 83 insertions, 0 deletions
diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs index 01c17245..c71536ed 100644 --- a/src/cpu_backend.rs +++ b/src/cpu_backend.rs @@ -1,5 +1,6 @@ use crate::storage::{BinaryOp, UnaryOp}; use crate::{DType, Error, Result, Shape, StridedIndex}; +use ggblas::batched_sgemm; // TODO: Think about whether we would be better off with a dtype and // a buffer as an owned slice of bytes. @@ -17,6 +18,14 @@ impl CpuStorage { } } + pub fn as_slice<D: crate::WithDType>(&self) -> Result<&[D]> { + D::cpu_storage_as_slice(self) + } + + pub fn as_mut_slice<D: crate::WithDType>(&mut self) -> Result<&mut [D]> { + D::cpu_storage_as_mut_slice(self) + } + pub(crate) fn affine_impl( &self, shape: &Shape, @@ -97,6 +106,38 @@ impl CpuStorage { } } + pub(crate) fn matmul_impl( + &self, + rhs: &Self, + (b, m, n, k): (usize, usize, usize, usize), + lhs_stride: &[usize], + rhs_stride: &[usize], + ) -> Result<Self> { + println!("rhs {rhs:?}"); + println!("lhs_stride {lhs_stride:?}"); + println!("rhs_stride {rhs_stride:?}"); + // todo!("matmul"); + let a_skip: usize = m * k; + let b_skip: usize = n * k; + let c_skip: usize = m * n; + + let mut c = Self::F32(vec![0.0; b * m * n]); + + batched_sgemm( + self.as_slice()?, + a_skip, + rhs.as_slice()?, + b_skip, + c.as_mut_slice()?, + c_skip, + m, + n, + k, + b, + ); + Ok(c) + } + pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self { let elem_count = shape.elem_count(); match dtype { @@ -125,3 +166,45 @@ impl CpuStorage { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{Device, Tensor}; + + #[test] + fn simple_matmul() -> Result<()> { + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let a = Tensor::from_slice(&data, (2, 2), Device::Cpu)?; + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let b = Tensor::from_slice(&data, (2, 2), Device::Cpu)?; + + let c = a.matmul(&b)?; + assert_eq!(c.to_vec2::<f32>()?, &[&[7.0f32, 10.0], &[15.0, 22.0]]); + + let data = vec![1.0f32, 2.0]; + let a = Tensor::from_slice(&data, (2, 1), Device::Cpu)?; + let data = vec![3.0f32, 4.0]; + let b = Tensor::from_slice(&data, (1, 2), Device::Cpu)?; + let c = a.matmul(&b)?; + assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]); + + let data: Vec<_> = (0..6).map(|i| i as f32).collect(); + let a = Tensor::from_slice(&data, (2, 3), Device::Cpu)?; + let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect(); + let b = Tensor::from_slice(&data, (3, 2), Device::Cpu)?; + let c = a.matmul(&b)?; + assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]); + + let data: Vec<_> = (0..12).map(|i| i as f32).collect(); + let a = Tensor::from_slice(&data, (2, 2, 3), Device::Cpu)?; + let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect(); + let b = Tensor::from_slice(&data, (2, 3, 2), Device::Cpu)?; + let c = a.matmul(&b)?; + assert_eq!( + c.to_vec3::<f32>()?, + &[&[&[16., 19.], &[52., 64.]], &[&[214., 235.], &[304., 334.]]] + ); + Ok(()) + } +} |