summaryrefslogtreecommitdiff
path: root/src/storage.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/storage.rs')
-rw-r--r--src/storage.rs18
1 files changed, 18 insertions, 0 deletions
diff --git a/src/storage.rs b/src/storage.rs
index 573cf945..f1a2d5a0 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -241,4 +241,22 @@ impl Storage {
pub(crate) fn sqrt_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
self.unary_impl::<Sqrt>(shape, stride)
}
+
+ pub(crate) fn matmul_impl(
+ &self,
+ rhs: &Self,
+ bmnk: (usize, usize, usize, usize),
+ lhs_stride: &[usize],
+ rhs_stride: &[usize],
+ ) -> Result<Self> {
+ self.same_device(rhs, "matmul")?;
+ self.same_dtype(rhs, "matmul")?;
+ match (self, rhs) {
+ (Storage::Cpu(storage), Storage::Cpu(rhs_storage)) => {
+ let storage = storage.matmul_impl(rhs_storage, bmnk, lhs_stride, rhs_stride)?;
+ Ok(Self::Cpu(storage))
+ }
+ _ => todo!(),
+ }
+ }
}