summaryrefslogtreecommitdiff
path: root/src/cpu_backend.rs
diff options
context:
space:
mode:
authorNicolas Patry <patry.nicolas@protonmail.com>2023-06-21 16:52:35 +0200
committerNicolas Patry <patry.nicolas@protonmail.com>2023-06-22 12:25:58 +0200
commitce977b489e4863cf5c53495a093c0efef2d41013 (patch)
tree350fe63201c1d717c0b2e9fab460ae9ec2ec53f1 /src/cpu_backend.rs
parent87a37b3bf3b6fd5034269c10c21c8f91e0223eb0 (diff)
downloadcandle-ce977b489e4863cf5c53495a093c0efef2d41013.tar.gz
candle-ce977b489e4863cf5c53495a093c0efef2d41013.tar.bz2
candle-ce977b489e4863cf5c53495a093c0efef2d41013.zip
Adding matmul?
Diffstat (limited to 'src/cpu_backend.rs')
-rw-r--r--src/cpu_backend.rs83
1 files changed, 83 insertions, 0 deletions
diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs
index 01c17245..c71536ed 100644
--- a/src/cpu_backend.rs
+++ b/src/cpu_backend.rs
@@ -1,5 +1,6 @@
use crate::storage::{BinaryOp, UnaryOp};
use crate::{DType, Error, Result, Shape, StridedIndex};
+use ggblas::batched_sgemm;
// TODO: Think about whether we would be better off with a dtype and
// a buffer as an owned slice of bytes.
@@ -17,6 +18,14 @@ impl CpuStorage {
}
}
+ /// Borrow the storage's underlying buffer as a typed slice.
+ /// Delegates to the `WithDType` trait; presumably returns an error when
+ /// `D` does not match the storage's actual dtype — TODO confirm against
+ /// `cpu_storage_as_slice`'s implementation.
+ pub fn as_slice<D: crate::WithDType>(&self) -> Result<&[D]> {
+ D::cpu_storage_as_slice(self)
+ }
+
+ /// Mutable counterpart of [`as_slice`]: borrow the buffer as `&mut [D]`.
+ pub fn as_mut_slice<D: crate::WithDType>(&mut self) -> Result<&mut [D]> {
+ D::cpu_storage_as_mut_slice(self)
+ }
+
+
pub(crate) fn affine_impl(
&self,
shape: &Shape,
@@ -97,6 +106,38 @@ impl CpuStorage {
}
}
+ /// Batched matrix multiplication on CPU storage: for each of the `b`
+ /// batches, computes the (m, n) product of an (m, k) lhs and a (k, n)
+ /// rhs via `ggblas::batched_sgemm`.
+ ///
+ /// NOTE(review): the stride arguments are currently ignored — a
+ /// contiguous row-major layout is assumed — and the output is
+ /// hard-coded to `F32` regardless of the input dtype; both need
+ /// handling before non-contiguous or non-f32 inputs are supported.
+ pub(crate) fn matmul_impl(
+ &self,
+ rhs: &Self,
+ (b, m, n, k): (usize, usize, usize, usize),
+ _lhs_stride: &[usize],
+ _rhs_stride: &[usize],
+ ) -> Result<Self> {
+ // Per-batch element offsets into the flat lhs/rhs/output buffers.
+ let a_skip: usize = m * k;
+ let b_skip: usize = n * k;
+ let c_skip: usize = m * n;
+
+ // Zero-initialized output buffer for all `b` batches.
+ let mut c = Self::F32(vec![0.0; b * m * n]);
+
+ batched_sgemm(
+ self.as_slice()?,
+ a_skip,
+ rhs.as_slice()?,
+ b_skip,
+ c.as_mut_slice()?,
+ c_skip,
+ m,
+ n,
+ k,
+ b,
+ );
+ Ok(c)
+ }
+
+
pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {
let elem_count = shape.elem_count();
match dtype {
@@ -125,3 +166,45 @@ impl CpuStorage {
}
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::{Device, Tensor};
+
+ // End-to-end check of matmul through the Tensor API, with hand-computed
+ // expected values for square, outer-product, rectangular and batched
+ // (rank-3) shapes.
+ #[test]
+ fn simple_matmul() -> Result<()> {
+ // (2, 2) x (2, 2): classic [[1,2],[3,4]]^2 = [[7,10],[15,22]].
+ let data = vec![1.0f32, 2.0, 3.0, 4.0];
+ let a = Tensor::from_slice(&data, (2, 2), Device::Cpu)?;
+ let data = vec![1.0f32, 2.0, 3.0, 4.0];
+ let b = Tensor::from_slice(&data, (2, 2), Device::Cpu)?;
+
+ let c = a.matmul(&b)?;
+ assert_eq!(c.to_vec2::<f32>()?, &[&[7.0f32, 10.0], &[15.0, 22.0]]);
+
+ // (2, 1) x (1, 2): outer product of [1,2] and [3,4].
+ let data = vec![1.0f32, 2.0];
+ let a = Tensor::from_slice(&data, (2, 1), Device::Cpu)?;
+ let data = vec![3.0f32, 4.0];
+ let b = Tensor::from_slice(&data, (1, 2), Device::Cpu)?;
+ let c = a.matmul(&b)?;
+ assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);
+
+ // (2, 3) x (3, 2): rectangular, non-square inner dimension.
+ let data: Vec<_> = (0..6).map(|i| i as f32).collect();
+ let a = Tensor::from_slice(&data, (2, 3), Device::Cpu)?;
+ let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
+ let b = Tensor::from_slice(&data, (3, 2), Device::Cpu)?;
+ let c = a.matmul(&b)?;
+ assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);
+
+ // Batched (2, 2, 3) x (2, 3, 2): first batch matches the rectangular
+ // case above; second batch exercises the per-batch skip offsets.
+ let data: Vec<_> = (0..12).map(|i| i as f32).collect();
+ let a = Tensor::from_slice(&data, (2, 2, 3), Device::Cpu)?;
+ let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
+ let b = Tensor::from_slice(&data, (2, 3, 2), Device::Cpu)?;
+ let c = a.matmul(&b)?;
+ assert_eq!(
+ c.to_vec3::<f32>()?,
+ &[&[&[16., 19.], &[52., 64.]], &[&[214., 235.], &[304., 334.]]]
+ );
+ Ok(())
+ }
+}