diff options
Diffstat (limited to 'src/storage.rs')
-rw-r--r-- | src/storage.rs | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/src/storage.rs b/src/storage.rs index 573cf945..f1a2d5a0 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -241,4 +241,22 @@ impl Storage { pub(crate) fn sqrt_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> { self.unary_impl::<Sqrt>(shape, stride) } + + pub(crate) fn matmul_impl( + &self, + rhs: &Self, + bmnk: (usize, usize, usize, usize), + lhs_stride: &[usize], + rhs_stride: &[usize], + ) -> Result<Self> { + self.same_device(rhs, "matmul")?; + self.same_dtype(rhs, "matmul")?; + match (self, rhs) { + (Storage::Cpu(storage), Storage::Cpu(rhs_storage)) => { + let storage = storage.matmul_impl(rhs_storage, bmnk, lhs_stride, rhs_stride)?; + Ok(Self::Cpu(storage)) + } + _ => todo!(), + } + } } |