-rw-r--r--  Cargo.toml           2
-rw-r--r--  src/cpu_backend.rs   109
-rw-r--r--  src/error.rs         3
-rw-r--r--  src/tensor.rs        110
-rw-r--r--  tests/grad_tests.rs  26
5 files changed, 172 insertions, 78 deletions
diff --git a/Cargo.toml b/Cargo.toml
index 8c601e18..72eb00ce 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -16,11 +16,11 @@ members = [
]
[dependencies]
-ggblas = "0.1.0"
safetensors = "0.3.1"
thiserror = "1"
cudarc = { version = "0.9.9", optional = true }
candle-kernels = { path = "kernels", optional = true }
+gemm = "0.15.4"
[dev-dependencies]
anyhow = "1"
diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs
index c71536ed..0eb4270a 100644
--- a/src/cpu_backend.rs
+++ b/src/cpu_backend.rs
@@ -1,6 +1,6 @@
use crate::storage::{BinaryOp, UnaryOp};
use crate::{DType, Error, Result, Shape, StridedIndex};
-use ggblas::batched_sgemm;
+use gemm::{gemm, Parallelism};
// TODO: Think about whether we would be better off with a dtype and
// a buffer as an owned slice of bytes.
@@ -113,28 +113,83 @@ impl CpuStorage {
lhs_stride: &[usize],
rhs_stride: &[usize],
) -> Result<Self> {
- println!("rhs {rhs:?}");
- println!("lhs_stride {lhs_stride:?}");
- println!("rhs_stride {rhs_stride:?}");
- // todo!("matmul");
let a_skip: usize = m * k;
let b_skip: usize = n * k;
let c_skip: usize = m * n;
- let mut c = Self::F32(vec![0.0; b * m * n]);
-
- batched_sgemm(
- self.as_slice()?,
- a_skip,
- rhs.as_slice()?,
- b_skip,
- c.as_mut_slice()?,
- c_skip,
- m,
- n,
- k,
- b,
- );
+ let rank = lhs_stride.len();
+ let lhs_cs = lhs_stride[rank - 1];
+ let lhs_rs = lhs_stride[rank - 2];
+
+ let rhs_cs = rhs_stride[rank - 1];
+ let rhs_rs = rhs_stride[rank - 2];
+
+ if lhs_stride.len() > 2 {
+ let lhs_batch_stride = &lhs_stride[..rank - 2];
+ let rhs_batch_stride = &rhs_stride[..rank - 2];
+
+ if lhs_batch_stride != &[a_skip] || rhs_batch_stride != &[b_skip] {
+ // Temporary error until we support arbitrary striding.
+ return Err(Error::UnexpectedStriding);
+ }
+ }
+
+ let mut dst = vec![0.0; b * m * n];
+
+ let dst_shape: Shape = (m, n).into();
+ let dst_strides = dst_shape.stride_contiguous();
+ let dst_rs = dst_strides[0];
+ let dst_cs = dst_strides[1];
+
+ for step in 0..b {
+ let lhs_p = &self.as_slice::<f32>()?[step * a_skip..];
+ let rhs_p = &rhs.as_slice::<f32>()?[step * b_skip..];
+ let dst_p = &mut dst[step * c_skip..];
+ unsafe {
+ gemm(
+ // m: usize,
+ m,
+ // n: usize,
+ n,
+ // k: usize,
+ k,
+ // dst: *mut T,
+ dst_p.as_mut_ptr(),
+ // dst_cs: isize,
+ dst_cs as isize,
+ // dst_rs: isize,
+ dst_rs as isize,
+ // read_dst: bool,
+ false,
+ // lhs: *const T,
+ lhs_p.as_ptr(),
+ // lhs_cs: isize,
+ lhs_cs as isize,
+ // lhs_rs: isize,
+ lhs_rs as isize,
+ // rhs: *const T,
+ rhs_p.as_ptr(),
+ // rhs_cs: isize,
+ rhs_cs as isize,
+ // rhs_rs: isize,
+ rhs_rs as isize,
+ // alpha: T,
+ 1.0,
+ // beta: T,
+ 1.0,
+ // conj_dst: bool,
+ false,
+ // conj_lhs: bool,
+ false,
+ // conj_rhs: bool,
+ true,
+ // parallelism: Parallelism
+ Parallelism::None,
+ )
+ }
+ }
+
+ let c = Self::F32(dst);
Ok(c)
}
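For reference, a minimal standalone sketch (not part of the diff) of the gemm 0.15 call used above, with contiguous row-major inputs; it reproduces the (2, 3) x (3, 2) case from the simple_matmul test below. The cs/rs arguments are per-column and per-row element strides, and read_dst = false means dst is overwritten rather than accumulated into.

use gemm::{gemm, Parallelism};

fn main() {
    let (m, n, k) = (2usize, 2usize, 3usize);
    // Row-major (2, 3) lhs and (3, 2) rhs, same data as the simple_matmul test.
    let lhs: Vec<f32> = (0..6).map(|i| i as f32).collect();
    let rhs: Vec<f32> = (0..6).map(|i| (i + 2) as f32).collect();
    let mut dst = vec![0.0f32; m * n];
    unsafe {
        gemm(
            m, n, k,
            dst.as_mut_ptr(), 1, n as isize, // dst_cs, dst_rs (row-major)
            false,                           // read_dst: overwrite dst
            lhs.as_ptr(), 1, k as isize,     // lhs_cs, lhs_rs
            rhs.as_ptr(), 1, n as isize,     // rhs_cs, rhs_rs
            1.0, 1.0,                        // alpha, beta
            false, false, true,              // conj flags, as in the call above
            Parallelism::None,
        )
    }
    assert_eq!(dst, vec![16., 19., 52., 64.]);
}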
@@ -175,31 +230,31 @@ mod tests {
#[test]
fn simple_matmul() -> Result<()> {
let data = vec![1.0f32, 2.0, 3.0, 4.0];
- let a = Tensor::from_slice(&data, (2, 2), Device::Cpu)?;
+ let a = Tensor::from_slice(&data, (2, 2), &Device::Cpu)?;
let data = vec![1.0f32, 2.0, 3.0, 4.0];
- let b = Tensor::from_slice(&data, (2, 2), Device::Cpu)?;
+ let b = Tensor::from_slice(&data, (2, 2), &Device::Cpu)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[&[7.0f32, 10.0], &[15.0, 22.0]]);
let data = vec![1.0f32, 2.0];
- let a = Tensor::from_slice(&data, (2, 1), Device::Cpu)?;
+ let a = Tensor::from_slice(&data, (2, 1), &Device::Cpu)?;
let data = vec![3.0f32, 4.0];
- let b = Tensor::from_slice(&data, (1, 2), Device::Cpu)?;
+ let b = Tensor::from_slice(&data, (1, 2), &Device::Cpu)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);
let data: Vec<_> = (0..6).map(|i| i as f32).collect();
- let a = Tensor::from_slice(&data, (2, 3), Device::Cpu)?;
+ let a = Tensor::from_slice(&data, (2, 3), &Device::Cpu)?;
let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
- let b = Tensor::from_slice(&data, (3, 2), Device::Cpu)?;
+ let b = Tensor::from_slice(&data, (3, 2), &Device::Cpu)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);
let data: Vec<_> = (0..12).map(|i| i as f32).collect();
- let a = Tensor::from_slice(&data, (2, 2, 3), Device::Cpu)?;
+ let a = Tensor::from_slice(&data, (2, 2, 3), &Device::Cpu)?;
let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
- let b = Tensor::from_slice(&data, (2, 3, 2), Device::Cpu)?;
+ let b = Tensor::from_slice(&data, (2, 3, 2), &Device::Cpu)?;
let c = a.matmul(&b)?;
assert_eq!(
c.to_vec3::<f32>()?,
diff --git a/src/error.rs b/src/error.rs
index 27201cb4..6f40622c 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -40,6 +40,9 @@ pub enum Error {
shape: Shape,
},
+ #[error("temporary error where matmul doesn't support arbitrary striding")]
+ UnexpectedStriding,
+
#[error(transparent)]
Cuda(#[from] crate::CudaError),
}
diff --git a/src/tensor.rs b/src/tensor.rs
index 7607171c..7274c557 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -147,10 +147,11 @@ impl Tensor {
pub fn new_impl<A: crate::device::NdArray>(
array: A,
+ shape: Shape,
device: &Device,
is_variable: bool,
) -> Result<Self> {
- let shape = array.shape()?;
+ // let shape = array.shape()?;
let storage = device.storage(array)?;
let stride = shape.stride_contiguous();
let tensor_ = Tensor_ {
@@ -165,31 +166,29 @@ impl Tensor {
}
pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
- Self::new_impl(array, device, false)
+ let shape = array.shape()?.clone();
+ Self::new_impl(array, shape, device, false)
}
pub fn var<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
- Self::new_impl(array, device, true)
+ let shape = array.shape()?.clone();
+ Self::new_impl(array, shape, device, true)
}
pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
- a: &[D],
+ array: &[D],
shape: S,
- device: Device,
+ device: &Device,
) -> Result<Self> {
- let shape = shape.into();
- let storage = device.storage(a)?;
- let stride = shape.stride_contiguous();
- let is_variable = false;
- let tensor_ = Tensor_ {
- id: TensorId::new(),
- storage,
- shape,
- stride,
- op: None,
- is_variable,
- };
- Ok(Self(Arc::new(tensor_)))
+ Self::new_impl(array, shape.into(), device, false)
+ }
+
+ pub fn var_from_slice<S: Into<Shape>, D: crate::WithDType>(
+ array: &[D],
+ shape: S,
+ device: &Device,
+ ) -> Result<Self> {
+ Self::new_impl(array, shape.into(), device, true)
}
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
@@ -260,6 +259,7 @@ impl Tensor {
let dim = a_dims.len();
+ // TODO
// if dim < 2 {
// return Err(SmeltError::InsufficientRank { minimum_rank: 2 });
// }
@@ -309,6 +309,13 @@ impl Tensor {
crate::StridedIndex::new(self.dims(), self.stride())
}
+ pub fn as_slice<S: crate::WithDType>(&self) -> Result<&[S]> {
+ match &self.storage {
+ Storage::Cpu(cpu_storage) => S::cpu_storage_as_slice(cpu_storage),
+ Storage::Cuda { .. } => todo!(),
+ }
+ }
+
pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
if self.rank() != 1 {
return Err(Error::UnexpectedNumberOfDims {
@@ -404,6 +411,31 @@ impl Tensor {
self.id
}
+ pub fn t(&self) -> Result<Tensor> {
+ let mut stride = self.stride().to_vec();
+ let mut shape = self.shape().clone();
+ let n = stride.len();
+ if n < 2 {
+ return Err(Error::UnexpectedNumberOfDims {
+ expected: 2,
+ got: n,
+ shape: self.shape().clone(),
+ });
+ }
+ (shape.0[n - 2], shape.0[n - 1]) = (shape.0[n - 1], shape.0[n - 2]);
+ (stride[n - 2], stride[n - 1]) = (stride[n - 1], stride[n - 2]);
+ let tensor_ = Tensor_ {
+ id: TensorId::new(),
+ storage: self.storage.clone(),
+ shape,
+ stride,
+ // TODO The op should have a backward
+ op: None,
+ is_variable: false,
+ };
+ Ok(Tensor(Arc::new(tensor_)))
+ }
+
pub fn is_contiguous(&self) -> bool {
self.shape.is_contiguous(&self.stride)
}
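A quick hedged illustration (not part of the diff) of how the t() method added above transposes by swapping metadata only: a contiguous (2, 3) tensor has stride [3, 1]; after t() the shape becomes (3, 2) and the stride [1, 3], so index (i, j) of the transpose resolves to the same buffer offset as index (j, i) of the original, and no data is moved.

let data: Vec<f32> = (0..6).map(|i| i as f32).collect(); // contiguous (2, 3), stride [3, 1]
let t_stride = [1usize, 3usize];                         // stride after t(), shape (3, 2)
let offset = |i: usize, j: usize| i * t_stride[0] + j * t_stride[1];
assert_eq!(data[offset(2, 1)], data[1 * 3 + 2]);         // transpose[2][1] == original[1][2]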
@@ -514,37 +546,17 @@ impl Tensor {
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
}
- Op::Matmul(_lhs, _rhs) => {
- // let (m, k) = lhs.shape;
- // let n = rhs.shape.1;
- // let strides = (m, n).strides();
- // Self::matmul(
- // (m, n, k),
- // true,
- // grad_out.as_ptr(),
- // strides,
- // rhs.data.as_ptr(),
- // [rhs.strides[1], rhs.strides[0]],
- // grad_lhs.as_mut_ptr(),
- // lhs.strides,
- // );
- // Self::matmul(
- // (k, m, n),
- // true,
- // lhs.data.as_ptr(),
- // [lhs.strides[1], lhs.strides[0]],
- // grad_out.as_ptr(),
- // strides,
- // grad_rhs.as_mut_ptr(),
- // rhs.strides,
- // );
-
- // let lhs_grad = grad.matmul(rhs)?;
- // let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
- // *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
- // let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
- // let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
- // *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+ Op::Matmul(lhs, rhs) => {
+ // The forward op already succeeded, so we can skip the
+ // matmul shape checks here for now.
+
+ let lhs_grad = grad.matmul(&rhs.t()?)?;
+ let lhs_sum_grad = grads.or_insert(lhs)?;
+ *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
+
+ let rhs_grad = lhs.t()?.matmul(&grad)?;
+ let rhs_sum_grad = grads.or_insert(rhs)?;
+ *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
}
Op::Affine { arg, mul, .. } => {
let arg_grad = grad.affine(*mul, 0.)?;
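The Op::Matmul backward branch above applies the standard matrix-product gradient rule per batch: for C = A * B with upstream gradient G = dL/dC,

    dL/dA = G * B^T    (lhs_grad = grad.matmul(&rhs.t()?))
    dL/dB = A^T * G    (rhs_grad = lhs.t()?.matmul(&grad))

where the transposes come from the new t() method, i.e. a swap of the last two dims and strides rather than a data copy.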
diff --git a/tests/grad_tests.rs b/tests/grad_tests.rs
index 56186e5d..77a32dfe 100644
--- a/tests/grad_tests.rs
+++ b/tests/grad_tests.rs
@@ -1,5 +1,5 @@
use anyhow::{Context, Result};
-use candle::{Device, Tensor};
+use candle::{Device, Shape, Tensor};
#[test]
fn simple_grad() -> Result<()> {
@@ -14,3 +14,27 @@ fn simple_grad() -> Result<()> {
assert_eq!(grad_x.to_vec1::<f32>()?, [11., 7., 13.]);
Ok(())
}
+
+#[test]
+fn matmul_grad() -> Result<()> {
+ let data: Vec<_> = (0..12).map(|i| i as f32).collect();
+ let x = Tensor::var_from_slice(&data, (2, 2, 3), &Device::Cpu)?;
+ let data: Vec<_> = (0..12).map(|i| i as f32).collect();
+ let y = Tensor::var_from_slice(&data, (2, 3, 2), &Device::Cpu)?;
+
+ let c = x.matmul(&y)?;
+ let grads = c.backward()?;
+ let grad_x = grads.get(&x).context("no grad for x")?;
+ let grad_y = grads.get(&y).context("no grad for y")?;
+ assert_eq!(grad_x.shape(), &Shape::from((2, 2, 3)));
+ assert_eq!(grad_y.shape(), &Shape::from((2, 3, 2)));
+ assert_eq!(
+ grad_x.as_slice::<f32>()?,
+ &[1., 5., 9., 1., 5., 9., 13., 17., 21., 13., 17., 21.]
+ );
+ assert_eq!(
+ grad_y.as_slice::<f32>()?,
+ &[3., 3., 5., 5., 7., 7., 15., 15., 17., 17., 19., 19.]
+ );
+ Ok(())
+}
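As a sanity check on the expected values (assuming backward() seeds the output gradient with ones, which these numbers are consistent with): every row of grad_x for batch 0 equals the row sums of y0 = [[0, 1], [2, 3], [4, 5]], i.e. [1, 5, 9], and every column of grad_y for batch 0 equals the column sums of x0 = [[0, 1, 2], [3, 4, 5]], i.e. [3, 5, 7]. The second batch gives [13, 17, 21] and [15, 17, 19] in the same way, matching the flattened slices asserted above.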