summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.pre-commit-config.yaml15
-rw-r--r--Cargo.toml3
-rw-r--r--src/cpu_backend.rs138
-rw-r--r--src/device.rs2
-rw-r--r--src/dtype.rs11
-rw-r--r--src/error.rs9
-rw-r--r--src/op.rs1
-rw-r--r--src/storage.rs18
-rw-r--r--src/tensor.rs148
-rw-r--r--tests/grad_tests.rs26
10 files changed, 363 insertions, 8 deletions
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 00000000..83e3e688
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,15 @@
+repos:
+ - repo: https://github.com/Narsil/pre-commit-rust
+ rev: 2eed6366172ef2a5186e8785ec0e67243d7d73d0
+ hooks:
+ - id: fmt
+ name: "Rust (fmt)"
+ - id: clippy
+ name: "Rust (clippy)"
+ args:
+ [
+ "--tests",
+ "--examples",
+ "--",
+ "-Dwarnings",
+ ]
diff --git a/Cargo.toml b/Cargo.toml
index 883664fc..72eb00ce 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -20,12 +20,13 @@ safetensors = "0.3.1"
thiserror = "1"
cudarc = { version = "0.9.9", optional = true }
candle-kernels = { path = "kernels", optional = true }
+gemm = "0.15.4"
[dev-dependencies]
anyhow = "1"
clap = { version = "4.2.4", features = ["derive"] }
rand = "0.8.5"
-tokenizers = "0.13.3"
+tokenizers = { version = "0.13.3", default-features=false, features=["onig"] }
[features]
default = []
diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs
index 01c17245..2c708389 100644
--- a/src/cpu_backend.rs
+++ b/src/cpu_backend.rs
@@ -1,5 +1,6 @@
use crate::storage::{BinaryOp, UnaryOp};
use crate::{DType, Error, Result, Shape, StridedIndex};
+use gemm::{gemm, Parallelism};
// TODO: Think about whether we would be better off with a dtype and
// a buffer as an owned slice of bytes.
@@ -17,6 +18,14 @@ impl CpuStorage {
}
}
+ pub fn as_slice<D: crate::WithDType>(&self) -> Result<&[D]> {
+ D::cpu_storage_as_slice(self)
+ }
+
+ pub fn as_mut_slice<D: crate::WithDType>(&mut self) -> Result<&mut [D]> {
+ D::cpu_storage_as_mut_slice(self)
+ }
+
pub(crate) fn affine_impl(
&self,
shape: &Shape,
@@ -97,6 +106,93 @@ impl CpuStorage {
}
}
+ pub(crate) fn matmul_impl(
+ &self,
+ rhs: &Self,
+ (b, m, n, k): (usize, usize, usize, usize),
+ lhs_stride: &[usize],
+ rhs_stride: &[usize],
+ ) -> Result<Self> {
+ let a_skip: usize = m * k;
+ let b_skip: usize = n * k;
+ let c_skip: usize = m * n;
+
+ let rank = lhs_stride.len();
+ let lhs_cs = lhs_stride[rank - 1];
+ let lhs_rs = lhs_stride[rank - 2];
+
+ let rhs_cs = rhs_stride[rank - 1];
+ let rhs_rs = rhs_stride[rank - 2];
+
+ if lhs_stride.len() > 2 {
+ let lhs_batch_stride = &lhs_stride[..rank - 2];
+ let rhs_batch_stride = &rhs_stride[..rank - 2];
+
+ if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] {
+ // Temporary error before we support abitrary striding.
+ return Err(Error::UnexpectedStriding);
+ }
+ }
+
+ let mut dst = vec![0.0; b * m * n];
+
+ let dst_shape: Shape = (m, n).into();
+ let dst_strides = dst_shape.stride_contiguous();
+ let dst_rs = dst_strides[0];
+ let dst_cs = dst_strides[1];
+
+ for step in 0..b {
+ let lhs_p = &self.as_slice::<f32>()?[step * a_skip..];
+ let rhs_p = &rhs.as_slice::<f32>()?[step * b_skip..];
+ let dst_p = &mut dst[step * c_skip..];
+ unsafe {
+ gemm(
+ // m: usize,
+ m,
+ // n: usize,
+ n,
+ // k: usize,
+ k,
+ // dst: *mut T,
+ dst_p.as_mut_ptr(),
+ // dst_cs: isize,
+ dst_cs as isize,
+ // dst_rs: isize,
+ dst_rs as isize,
+ // read_dst: bool,
+ false,
+ // lhs: *const T,
+ lhs_p.as_ptr(),
+ // lhs_cs: isize,
+ lhs_cs as isize,
+ // lhs_rs: isize,
+ lhs_rs as isize,
+ // rhs: *const T,
+ rhs_p.as_ptr(),
+ // rhs_cs: isize,
+ rhs_cs as isize,
+ // rhs_rs: isize,
+ rhs_rs as isize,
+ // alpha: T,
+ 1.0,
+ // beta: T,
+ 1.0,
+ // conj_dst: bool,
+ false,
+ // conj_lhs: bool,
+ false,
+ // conj_rhs: bool,
+ true,
+ // parallelism: Parallelism
+ Parallelism::None,
+ )
+ }
+ }
+
+ let c = Self::F32(dst);
+ Ok(c)
+ }
+
pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {
let elem_count = shape.elem_count();
match dtype {
@@ -125,3 +221,45 @@ impl CpuStorage {
}
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::{Device, Tensor};
+
+ #[test]
+ fn simple_matmul() -> Result<()> {
+ let data = vec![1.0f32, 2.0, 3.0, 4.0];
+ let a = Tensor::from_slice(&data, (2, 2), &Device::Cpu)?;
+ let data = vec![1.0f32, 2.0, 3.0, 4.0];
+ let b = Tensor::from_slice(&data, (2, 2), &Device::Cpu)?;
+
+ let c = a.matmul(&b)?;
+ assert_eq!(c.to_vec2::<f32>()?, &[&[7.0f32, 10.0], &[15.0, 22.0]]);
+
+ let data = vec![1.0f32, 2.0];
+ let a = Tensor::from_slice(&data, (2, 1), &Device::Cpu)?;
+ let data = vec![3.0f32, 4.0];
+ let b = Tensor::from_slice(&data, (1, 2), &Device::Cpu)?;
+ let c = a.matmul(&b)?;
+ assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);
+
+ let data: Vec<_> = (0..6).map(|i| i as f32).collect();
+ let a = Tensor::from_slice(&data, (2, 3), &Device::Cpu)?;
+ let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
+ let b = Tensor::from_slice(&data, (3, 2), &Device::Cpu)?;
+ let c = a.matmul(&b)?;
+ assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);
+
+ let data: Vec<_> = (0..12).map(|i| i as f32).collect();
+ let a = Tensor::from_slice(&data, (2, 2, 3), &Device::Cpu)?;
+ let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
+ let b = Tensor::from_slice(&data, (2, 3, 2), &Device::Cpu)?;
+ let c = a.matmul(&b)?;
+ assert_eq!(
+ c.to_vec3::<f32>()?,
+ &[&[&[16., 19.], &[52., 64.]], &[&[214., 235.], &[304., 334.]]]
+ );
+ Ok(())
+ }
+}
diff --git a/src/device.rs b/src/device.rs
index ab7bad26..8acb1338 100644
--- a/src/device.rs
+++ b/src/device.rs
@@ -101,7 +101,7 @@ impl Device {
}
}
- pub(crate) fn tensor<A: NdArray>(&self, array: A) -> Result<Storage> {
+ pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
match self {
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
Device::Cuda(device) => {
diff --git a/src/dtype.rs b/src/dtype.rs
index fd0eaa1b..f6249ff2 100644
--- a/src/dtype.rs
+++ b/src/dtype.rs
@@ -25,6 +25,7 @@ pub trait WithDType: Sized + Copy {
}
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>;
+ fn cpu_storage_as_mut_slice(s: &mut CpuStorage) -> Result<&mut [Self]>;
}
macro_rules! with_dtype {
@@ -45,6 +46,16 @@ macro_rules! with_dtype {
}),
}
}
+
+ fn cpu_storage_as_mut_slice(s: &mut CpuStorage) -> Result<&mut [Self]> {
+ match s {
+ CpuStorage::$dtype(data) => Ok(data),
+ _ => Err(Error::UnexpectedDType {
+ expected: DType::$dtype,
+ got: s.dtype(),
+ }),
+ }
+ }
}
};
}
diff --git a/src/error.rs b/src/error.rs
index 27201cb4..723edaa1 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -12,6 +12,11 @@ pub enum Error {
#[error("the candle crate has not been built with cuda support")]
NotCompiledWithCudaSupport,
+ #[error(
+ "Shape mismatch, got buffer of size {buffer_size} which is compatible with shape {shape:?}"
+ )]
+ ShapeMismatch { buffer_size: usize, shape: Shape },
+
#[error("shape mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
ShapeMismatchBinaryOp {
lhs: Shape,
@@ -40,6 +45,10 @@ pub enum Error {
shape: Shape,
},
+ // TODO this is temporary when we support arbitrary matmul
+ #[error("temporary error where matmul doesn't support arbitrary striding")]
+ UnexpectedStriding,
+
#[error(transparent)]
Cuda(#[from] crate::CudaError),
}
diff --git a/src/op.rs b/src/op.rs
index 240ecba3..157ce3b3 100644
--- a/src/op.rs
+++ b/src/op.rs
@@ -5,6 +5,7 @@ pub(crate) enum Op {
Mul(Tensor, Tensor),
Sub(Tensor, Tensor),
Div(Tensor, Tensor),
+ Matmul(Tensor, Tensor),
#[allow(dead_code)] // add is currently unused.
Affine {
diff --git a/src/storage.rs b/src/storage.rs
index 573cf945..f1a2d5a0 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -241,4 +241,22 @@ impl Storage {
pub(crate) fn sqrt_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
self.unary_impl::<Sqrt>(shape, stride)
}
+
+ pub(crate) fn matmul_impl(
+ &self,
+ rhs: &Self,
+ bmnk: (usize, usize, usize, usize),
+ lhs_stride: &[usize],
+ rhs_stride: &[usize],
+ ) -> Result<Self> {
+ self.same_device(rhs, "matmul")?;
+ self.same_dtype(rhs, "matmul")?;
+ match (self, rhs) {
+ (Storage::Cpu(storage), Storage::Cpu(rhs_storage)) => {
+ let storage = storage.matmul_impl(rhs_storage, bmnk, lhs_stride, rhs_stride)?;
+ Ok(Self::Cpu(storage))
+ }
+ _ => todo!(),
+ }
+ }
}
diff --git a/src/tensor.rs b/src/tensor.rs
index e8e01d5c..09e5d66c 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -147,11 +147,16 @@ impl Tensor {
pub fn new_impl<A: crate::device::NdArray>(
array: A,
+ shape: Shape,
device: &Device,
is_variable: bool,
) -> Result<Self> {
- let shape = array.shape()?;
- let storage = device.tensor(array)?;
+ let n: usize = shape.elem_count();
+ let buffer_size: usize = array.shape()?.elem_count();
+ if buffer_size != n {
+ return Err(Error::ShapeMismatch { buffer_size, shape });
+ }
+ let storage = device.storage(array)?;
let stride = shape.stride_contiguous();
let tensor_ = Tensor_ {
id: TensorId::new(),
@@ -165,11 +170,29 @@ impl Tensor {
}
pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
- Self::new_impl(array, device, false)
+ let shape = array.shape()?;
+ Self::new_impl(array, shape, device, false)
}
pub fn var<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
- Self::new_impl(array, device, true)
+ let shape = array.shape()?;
+ Self::new_impl(array, shape, device, true)
+ }
+
+ pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
+ array: &[D],
+ shape: S,
+ device: &Device,
+ ) -> Result<Self> {
+ Self::new_impl(array, shape.into(), device, false)
+ }
+
+ pub fn var_from_slice<S: Into<Shape>, D: crate::WithDType>(
+ array: &[D],
+ shape: S,
+ device: &Device,
+ ) -> Result<Self> {
+ Self::new_impl(array, shape.into(), device, true)
}
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
@@ -234,10 +257,65 @@ impl Tensor {
Ok(Self(Arc::new(tensor_)))
}
+ pub fn matmul(&self, rhs: &Self) -> Result<Self> {
+ let a_dims = self.shape().dims();
+ let b_dims = rhs.shape().dims();
+
+ let dim = a_dims.len();
+
+ if dim < 2 || b_dims.len() != dim {
+ return Err(Error::ShapeMismatchBinaryOp {
+ lhs: self.shape().clone(),
+ rhs: rhs.shape().clone(),
+ op: "matmul",
+ });
+ }
+
+ let m = a_dims[dim - 2];
+ let k = a_dims[dim - 1];
+ let k2 = b_dims[dim - 2];
+ let n = b_dims[dim - 1];
+ if k != k2 {
+ return Err(Error::ShapeMismatchBinaryOp {
+ lhs: self.shape().clone(),
+ rhs: rhs.shape().clone(),
+ op: "matmul",
+ });
+ }
+
+ let mut c_shape: Vec<_> = a_dims[..dim - 2].into();
+ c_shape.extend(&[m, n]);
+ let c_shape = Shape(c_shape);
+ let batching: usize = a_dims[..dim - 2].iter().product();
+
+ let storage = self.storage.matmul_impl(
+ &rhs.storage,
+ (batching, m, n, k),
+ self.stride(),
+ rhs.stride(),
+ )?;
+ let tensor_ = Tensor_ {
+ id: TensorId::new(),
+ storage,
+ shape: c_shape.clone(),
+ stride: c_shape.stride_contiguous(),
+ op: Some(Op::Matmul(self.clone(), rhs.clone())),
+ is_variable: false,
+ };
+ Ok(Self(Arc::new(tensor_)))
+ }
+
pub(crate) fn strided_index(&self) -> crate::StridedIndex {
crate::StridedIndex::new(self.dims(), self.stride())
}
+ pub fn as_slice<S: crate::WithDType>(&self) -> Result<&[S]> {
+ match &self.storage {
+ Storage::Cpu(cpu_storage) => S::cpu_storage_as_slice(cpu_storage),
+ Storage::Cuda { .. } => todo!(),
+ }
+ }
+
pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
if self.rank() != 1 {
return Err(Error::UnexpectedNumberOfDims {
@@ -279,6 +357,28 @@ impl Tensor {
}
}
+ pub fn to_vec3<S: crate::WithDType>(&self) -> Result<Vec<Vec<Vec<S>>>> {
+ let (dim1, dim2, dim3) = self.shape().r3()?;
+ match &self.storage {
+ Storage::Cpu(cpu_storage) => {
+ let data = S::cpu_storage_as_slice(cpu_storage)?;
+ let mut top_rows = vec![];
+ let mut src_index = self.strided_index();
+ for _idx in 0..dim1 {
+ let mut rows = vec![];
+ for _jdx in 0..dim2 {
+ let row = (0..dim3).map(|_| data[src_index.next().unwrap()]).collect();
+ rows.push(row)
+ }
+ top_rows.push(rows);
+ }
+ assert!(src_index.next().is_none());
+ Ok(top_rows)
+ }
+ Storage::Cuda { .. } => todo!(),
+ }
+ }
+
pub fn dtype(&self) -> DType {
self.storage.dtype()
}
@@ -311,6 +411,31 @@ impl Tensor {
self.id
}
+ pub fn t(&self) -> Result<Tensor> {
+ let mut stride = self.stride().to_vec();
+ let mut shape = self.shape().clone();
+ let n = stride.len();
+ if n < 2 {
+ return Err(Error::UnexpectedNumberOfDims {
+ expected: 2,
+ got: n,
+ shape: self.shape().clone(),
+ });
+ }
+ (shape.0[n - 2], shape.0[n - 1]) = (shape.0[n - 1], shape.0[n - 2]);
+ (stride[n - 2], stride[n - 1]) = (stride[n - 1], stride[n - 2]);
+ let tensor_ = Tensor_ {
+ id: TensorId::new(),
+ storage: self.storage.clone(),
+ shape,
+ stride,
+ // TODO The op should have a backward
+ op: None,
+ is_variable: false,
+ };
+ Ok(Tensor(Arc::new(tensor_)))
+ }
+
pub fn is_contiguous(&self) -> bool {
self.shape.is_contiguous(&self.stride)
}
@@ -340,7 +465,8 @@ impl Tensor {
Op::Add(lhs, rhs)
| Op::Mul(lhs, rhs)
| Op::Sub(lhs, rhs)
- | Op::Div(lhs, rhs) => {
+ | Op::Div(lhs, rhs)
+ | Op::Matmul(lhs, rhs) => {
let (tg, nodes) = walk(lhs, nodes, already_seen);
track_grad |= tg;
let (tg, nodes) = walk(rhs, nodes, already_seen);
@@ -420,6 +546,18 @@ impl Tensor {
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
}
+ Op::Matmul(lhs, rhs) => {
+ // Skipping checks, the op went ok, we can skip
+ // the matmul size checks for now.
+
+ let lhs_grad = grad.matmul(&rhs.t()?)?;
+ let lhs_sum_grad = grads.or_insert(lhs)?;
+ *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
+
+ let rhs_grad = lhs.t()?.matmul(&grad)?;
+ let rhs_sum_grad = grads.or_insert(rhs)?;
+ *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+ }
Op::Affine { arg, mul, .. } => {
let arg_grad = grad.affine(*mul, 0.)?;
let sum_grad = grads.or_insert(arg)?;
diff --git a/tests/grad_tests.rs b/tests/grad_tests.rs
index 56186e5d..77a32dfe 100644
--- a/tests/grad_tests.rs
+++ b/tests/grad_tests.rs
@@ -1,5 +1,5 @@
use anyhow::{Context, Result};
-use candle::{Device, Tensor};
+use candle::{Device, Shape, Tensor};
#[test]
fn simple_grad() -> Result<()> {
@@ -14,3 +14,27 @@ fn simple_grad() -> Result<()> {
assert_eq!(grad_x.to_vec1::<f32>()?, [11., 7., 13.]);
Ok(())
}
+
+#[test]
+fn matmul_grad() -> Result<()> {
+ let data: Vec<_> = (0..12).map(|i| i as f32).collect();
+ let x = Tensor::var_from_slice(&data, (2, 2, 3), &Device::Cpu)?;
+ let data: Vec<_> = (0..12).map(|i| i as f32).collect();
+ let y = Tensor::var_from_slice(&data, (2, 3, 2), &Device::Cpu)?;
+
+ let c = x.matmul(&y)?;
+ let grads = c.backward()?;
+ let grad_x = grads.get(&x).context("no grad for x")?;
+ let grad_y = grads.get(&y).context("no grad for y")?;
+ assert_eq!(grad_x.shape(), &Shape::from((2, 2, 3)));
+ assert_eq!(grad_y.shape(), &Shape::from((2, 3, 2)));
+ assert_eq!(
+ grad_x.as_slice::<f32>()?,
+ &[1., 5., 9., 1., 5., 9., 13., 17., 21., 13., 17., 21.]
+ );
+ assert_eq!(
+ grad_y.as_slice::<f32>()?,
+ &[3., 3., 5., 5., 7., 7., 15., 15., 17., 17., 19., 19.]
+ );
+ Ok(())
+}