summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/cpu_backend.rs83
-rw-r--r--src/device.rs2
-rw-r--r--src/dtype.rs11
-rw-r--r--src/op.rs1
-rw-r--r--src/storage.rs18
-rw-r--r--src/tensor.rs130
6 files changed, 242 insertions, 3 deletions
diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs
index 01c17245..c71536ed 100644
--- a/src/cpu_backend.rs
+++ b/src/cpu_backend.rs
@@ -1,5 +1,6 @@
use crate::storage::{BinaryOp, UnaryOp};
use crate::{DType, Error, Result, Shape, StridedIndex};
+use ggblas::batched_sgemm;
// TODO: Think about whether we would be better off with a dtype and
// a buffer as an owned slice of bytes.
@@ -17,6 +18,14 @@ impl CpuStorage {
}
}
+ pub fn as_slice<D: crate::WithDType>(&self) -> Result<&[D]> {
+ D::cpu_storage_as_slice(self)
+ }
+
+ pub fn as_mut_slice<D: crate::WithDType>(&mut self) -> Result<&mut [D]> {
+ D::cpu_storage_as_mut_slice(self)
+ }
+
pub(crate) fn affine_impl(
&self,
shape: &Shape,
@@ -97,6 +106,38 @@ impl CpuStorage {
}
}
+ pub(crate) fn matmul_impl(
+ &self,
+ rhs: &Self,
+ (b, m, n, k): (usize, usize, usize, usize),
+ lhs_stride: &[usize],
+ rhs_stride: &[usize],
+ ) -> Result<Self> {
+ println!("rhs {rhs:?}");
+ println!("lhs_stride {lhs_stride:?}");
+ println!("rhs_stride {rhs_stride:?}");
+ // todo!("matmul");
+ let a_skip: usize = m * k;
+ let b_skip: usize = n * k;
+ let c_skip: usize = m * n;
+
+ let mut c = Self::F32(vec![0.0; b * m * n]);
+
+ batched_sgemm(
+ self.as_slice()?,
+ a_skip,
+ rhs.as_slice()?,
+ b_skip,
+ c.as_mut_slice()?,
+ c_skip,
+ m,
+ n,
+ k,
+ b,
+ );
+ Ok(c)
+ }
+
pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {
let elem_count = shape.elem_count();
match dtype {
@@ -125,3 +166,45 @@ impl CpuStorage {
}
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::{Device, Tensor};
+
+ #[test]
+ fn simple_matmul() -> Result<()> {
+ let data = vec![1.0f32, 2.0, 3.0, 4.0];
+ let a = Tensor::from_slice(&data, (2, 2), Device::Cpu)?;
+ let data = vec![1.0f32, 2.0, 3.0, 4.0];
+ let b = Tensor::from_slice(&data, (2, 2), Device::Cpu)?;
+
+ let c = a.matmul(&b)?;
+ assert_eq!(c.to_vec2::<f32>()?, &[&[7.0f32, 10.0], &[15.0, 22.0]]);
+
+ let data = vec![1.0f32, 2.0];
+ let a = Tensor::from_slice(&data, (2, 1), Device::Cpu)?;
+ let data = vec![3.0f32, 4.0];
+ let b = Tensor::from_slice(&data, (1, 2), Device::Cpu)?;
+ let c = a.matmul(&b)?;
+ assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);
+
+ let data: Vec<_> = (0..6).map(|i| i as f32).collect();
+ let a = Tensor::from_slice(&data, (2, 3), Device::Cpu)?;
+ let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
+ let b = Tensor::from_slice(&data, (3, 2), Device::Cpu)?;
+ let c = a.matmul(&b)?;
+ assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);
+
+ let data: Vec<_> = (0..12).map(|i| i as f32).collect();
+ let a = Tensor::from_slice(&data, (2, 2, 3), Device::Cpu)?;
+ let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
+ let b = Tensor::from_slice(&data, (2, 3, 2), Device::Cpu)?;
+ let c = a.matmul(&b)?;
+ assert_eq!(
+ c.to_vec3::<f32>()?,
+ &[&[&[16., 19.], &[52., 64.]], &[&[214., 235.], &[304., 334.]]]
+ );
+ Ok(())
+ }
+}
diff --git a/src/device.rs b/src/device.rs
index ab7bad26..8acb1338 100644
--- a/src/device.rs
+++ b/src/device.rs
@@ -101,7 +101,7 @@ impl Device {
}
}
- pub(crate) fn tensor<A: NdArray>(&self, array: A) -> Result<Storage> {
+ pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
match self {
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
Device::Cuda(device) => {
diff --git a/src/dtype.rs b/src/dtype.rs
index fd0eaa1b..f6249ff2 100644
--- a/src/dtype.rs
+++ b/src/dtype.rs
@@ -25,6 +25,7 @@ pub trait WithDType: Sized + Copy {
}
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>;
+ fn cpu_storage_as_mut_slice(s: &mut CpuStorage) -> Result<&mut [Self]>;
}
macro_rules! with_dtype {
@@ -45,6 +46,16 @@ macro_rules! with_dtype {
}),
}
}
+
+ fn cpu_storage_as_mut_slice(s: &mut CpuStorage) -> Result<&mut [Self]> {
+ match s {
+ CpuStorage::$dtype(data) => Ok(data),
+ _ => Err(Error::UnexpectedDType {
+ expected: DType::$dtype,
+ got: s.dtype(),
+ }),
+ }
+ }
}
};
}
diff --git a/src/op.rs b/src/op.rs
index 240ecba3..157ce3b3 100644
--- a/src/op.rs
+++ b/src/op.rs
@@ -5,6 +5,7 @@ pub(crate) enum Op {
Mul(Tensor, Tensor),
Sub(Tensor, Tensor),
Div(Tensor, Tensor),
+ Matmul(Tensor, Tensor),
#[allow(dead_code)] // add is currently unused.
Affine {
diff --git a/src/storage.rs b/src/storage.rs
index 573cf945..f1a2d5a0 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -241,4 +241,22 @@ impl Storage {
pub(crate) fn sqrt_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
self.unary_impl::<Sqrt>(shape, stride)
}
+
+ pub(crate) fn matmul_impl(
+ &self,
+ rhs: &Self,
+ bmnk: (usize, usize, usize, usize),
+ lhs_stride: &[usize],
+ rhs_stride: &[usize],
+ ) -> Result<Self> {
+ self.same_device(rhs, "matmul")?;
+ self.same_dtype(rhs, "matmul")?;
+ match (self, rhs) {
+ (Storage::Cpu(storage), Storage::Cpu(rhs_storage)) => {
+ let storage = storage.matmul_impl(rhs_storage, bmnk, lhs_stride, rhs_stride)?;
+ Ok(Self::Cpu(storage))
+ }
+ _ => todo!(),
+ }
+ }
}
diff --git a/src/tensor.rs b/src/tensor.rs
index e8e01d5c..e55050c6 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -151,7 +151,7 @@ impl Tensor {
is_variable: bool,
) -> Result<Self> {
let shape = array.shape()?;
- let storage = device.tensor(array)?;
+ let storage = device.storage(array)?;
let stride = shape.stride_contiguous();
let tensor_ = Tensor_ {
id: TensorId::new(),
@@ -172,6 +172,26 @@ impl Tensor {
Self::new_impl(array, device, true)
}
+ pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
+ a: &[D],
+ shape: S,
+ device: Device,
+ ) -> Result<Self> {
+ let shape = shape.into();
+ let storage = device.storage(a);
+ let stride = shape.stride_contiguous();
+ let is_variable = false;
+ let tensor_ = Tensor_ {
+ id: TensorId::new(),
+ storage,
+ shape,
+ stride,
+ op: None,
+ is_variable,
+ };
+ Ok(Self(Arc::new(tensor_)))
+ }
+
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
let lhs = self.shape();
let rhs = rhs.shape();
@@ -234,6 +254,57 @@ impl Tensor {
Ok(Self(Arc::new(tensor_)))
}
+ pub fn matmul(&self, rhs: &Self) -> Result<Self> {
+ let a_dims = self.shape().dims();
+ let b_dims = rhs.shape().dims();
+
+ let dim = a_dims.len();
+
+ // if dim < 2 {
+ // return Err(SmeltError::InsufficientRank { minimum_rank: 2 });
+ // }
+ if b_dims.len() != dim {
+ return Err(Error::ShapeMismatchBinaryOp {
+ lhs: self.shape().clone(),
+ rhs: rhs.shape().clone(),
+ op: "matmul",
+ });
+ }
+
+ let m = a_dims[dim - 2];
+ let k = a_dims[dim - 1];
+ let k2 = b_dims[dim - 2];
+ let n = b_dims[dim - 1];
+ if k != k2 {
+ return Err(Error::ShapeMismatchBinaryOp {
+ lhs: self.shape().clone(),
+ rhs: rhs.shape().clone(),
+ op: "matmul",
+ });
+ }
+
+ let mut c_shape: Vec<_> = a_dims[..dim - 2].into();
+ c_shape.extend(&[m, n]);
+ let c_shape: Shape = Shape(c_shape);
+ let batching: usize = a_dims[..dim - 2].iter().product();
+
+ let storage = self.storage.matmul_impl(
+ &rhs.storage,
+ (batching, m, n, k),
+ self.stride(),
+ rhs.stride(),
+ )?;
+ let tensor_ = Tensor_ {
+ id: TensorId::new(),
+ storage,
+ shape: c_shape.clone(),
+ stride: c_shape.stride_contiguous(),
+ op: Some(Op::Matmul(self.clone(), rhs.clone())),
+ is_variable: false,
+ };
+ Ok(Self(Arc::new(tensor_)))
+ }
+
pub(crate) fn strided_index(&self) -> crate::StridedIndex {
crate::StridedIndex::new(self.dims(), self.stride())
}
@@ -279,6 +350,28 @@ impl Tensor {
}
}
+ pub fn to_vec3<S: crate::WithDType>(&self) -> Result<Vec<Vec<Vec<S>>>> {
+ let (dim1, dim2, dim3) = self.shape().r3()?;
+ match &self.storage {
+ Storage::Cpu(cpu_storage) => {
+ let data = S::cpu_storage_as_slice(cpu_storage)?;
+ let mut top_rows = vec![];
+ let mut src_index = self.strided_index();
+ for _idx in 0..dim1 {
+ let mut rows = vec![];
+ for _jdx in 0..dim2 {
+ let row = (0..dim3).map(|_| data[src_index.next().unwrap()]).collect();
+ rows.push(row)
+ }
+ top_rows.push(rows);
+ }
+ assert!(src_index.next().is_none());
+ Ok(top_rows)
+ }
+ Storage::Cuda { .. } => todo!(),
+ }
+ }
+
pub fn dtype(&self) -> DType {
self.storage.dtype()
}
@@ -340,7 +433,8 @@ impl Tensor {
Op::Add(lhs, rhs)
| Op::Mul(lhs, rhs)
| Op::Sub(lhs, rhs)
- | Op::Div(lhs, rhs) => {
+ | Op::Div(lhs, rhs)
+ | Op::Matmul(lhs, rhs) => {
let (tg, nodes) = walk(lhs, nodes, already_seen);
track_grad |= tg;
let (tg, nodes) = walk(rhs, nodes, already_seen);
@@ -420,6 +514,38 @@ impl Tensor {
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
}
+ Op::Matmul(lhs, rhs) => {
+ // let (m, k) = lhs.shape;
+ // let n = rhs.shape.1;
+ // let strides = (m, n).strides();
+ // Self::matmul(
+ // (m, n, k),
+ // true,
+ // grad_out.as_ptr(),
+ // strides,
+ // rhs.data.as_ptr(),
+ // [rhs.strides[1], rhs.strides[0]],
+ // grad_lhs.as_mut_ptr(),
+ // lhs.strides,
+ // );
+ // Self::matmul(
+ // (k, m, n),
+ // true,
+ // lhs.data.as_ptr(),
+ // [lhs.strides[1], lhs.strides[0]],
+ // grad_out.as_ptr(),
+ // strides,
+ // grad_rhs.as_mut_ptr(),
+ // rhs.strides,
+ // );
+
+ let lhs_grad = grad.matmul(rhs)?;
+ let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
+ *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
+ let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
+ let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
+ *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+ }
Op::Affine { arg, mul, .. } => {
let arg_grad = grad.affine(*mul, 0.)?;
let sum_grad = grads.or_insert(arg)?;