summaryrefslogtreecommitdiff
path: root/src/tensor.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/tensor.rs')
-rw-r--r--src/tensor.rs148
1 files changed, 143 insertions, 5 deletions
diff --git a/src/tensor.rs b/src/tensor.rs
index e8e01d5c..09e5d66c 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -147,11 +147,16 @@ impl Tensor {
pub fn new_impl<A: crate::device::NdArray>(
array: A,
+ shape: Shape,
device: &Device,
is_variable: bool,
) -> Result<Self> {
- let shape = array.shape()?;
- let storage = device.tensor(array)?;
+ let n: usize = shape.elem_count();
+ let buffer_size: usize = array.shape()?.elem_count();
+ if buffer_size != n {
+ return Err(Error::ShapeMismatch { buffer_size, shape });
+ }
+ let storage = device.storage(array)?;
let stride = shape.stride_contiguous();
let tensor_ = Tensor_ {
id: TensorId::new(),
@@ -165,11 +170,29 @@ impl Tensor {
}
pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
- Self::new_impl(array, device, false)
+ let shape = array.shape()?;
+ Self::new_impl(array, shape, device, false)
}
pub fn var<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
- Self::new_impl(array, device, true)
+ let shape = array.shape()?;
+ Self::new_impl(array, shape, device, true)
+ }
+
+ pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
+ array: &[D],
+ shape: S,
+ device: &Device,
+ ) -> Result<Self> {
+ Self::new_impl(array, shape.into(), device, false)
+ }
+
+ pub fn var_from_slice<S: Into<Shape>, D: crate::WithDType>(
+ array: &[D],
+ shape: S,
+ device: &Device,
+ ) -> Result<Self> {
+ Self::new_impl(array, shape.into(), device, true)
}
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
@@ -234,10 +257,65 @@ impl Tensor {
Ok(Self(Arc::new(tensor_)))
}
+ pub fn matmul(&self, rhs: &Self) -> Result<Self> {
+ let a_dims = self.shape().dims();
+ let b_dims = rhs.shape().dims();
+
+ let dim = a_dims.len();
+
+ if dim < 2 || b_dims.len() != dim {
+ return Err(Error::ShapeMismatchBinaryOp {
+ lhs: self.shape().clone(),
+ rhs: rhs.shape().clone(),
+ op: "matmul",
+ });
+ }
+
+ let m = a_dims[dim - 2];
+ let k = a_dims[dim - 1];
+ let k2 = b_dims[dim - 2];
+ let n = b_dims[dim - 1];
+ if k != k2 {
+ return Err(Error::ShapeMismatchBinaryOp {
+ lhs: self.shape().clone(),
+ rhs: rhs.shape().clone(),
+ op: "matmul",
+ });
+ }
+
+ let mut c_shape: Vec<_> = a_dims[..dim - 2].into();
+ c_shape.extend(&[m, n]);
+ let c_shape = Shape(c_shape);
+ let batching: usize = a_dims[..dim - 2].iter().product();
+
+ let storage = self.storage.matmul_impl(
+ &rhs.storage,
+ (batching, m, n, k),
+ self.stride(),
+ rhs.stride(),
+ )?;
+ let tensor_ = Tensor_ {
+ id: TensorId::new(),
+ storage,
+ shape: c_shape.clone(),
+ stride: c_shape.stride_contiguous(),
+ op: Some(Op::Matmul(self.clone(), rhs.clone())),
+ is_variable: false,
+ };
+ Ok(Self(Arc::new(tensor_)))
+ }
+
pub(crate) fn strided_index(&self) -> crate::StridedIndex {
crate::StridedIndex::new(self.dims(), self.stride())
}
+ pub fn as_slice<S: crate::WithDType>(&self) -> Result<&[S]> {
+ match &self.storage {
+ Storage::Cpu(cpu_storage) => S::cpu_storage_as_slice(cpu_storage),
+ Storage::Cuda { .. } => todo!(),
+ }
+ }
+
pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
if self.rank() != 1 {
return Err(Error::UnexpectedNumberOfDims {
@@ -279,6 +357,28 @@ impl Tensor {
}
}
+ pub fn to_vec3<S: crate::WithDType>(&self) -> Result<Vec<Vec<Vec<S>>>> {
+ let (dim1, dim2, dim3) = self.shape().r3()?;
+ match &self.storage {
+ Storage::Cpu(cpu_storage) => {
+ let data = S::cpu_storage_as_slice(cpu_storage)?;
+ let mut top_rows = vec![];
+ let mut src_index = self.strided_index();
+ for _idx in 0..dim1 {
+ let mut rows = vec![];
+ for _jdx in 0..dim2 {
+ let row = (0..dim3).map(|_| data[src_index.next().unwrap()]).collect();
+ rows.push(row)
+ }
+ top_rows.push(rows);
+ }
+ assert!(src_index.next().is_none());
+ Ok(top_rows)
+ }
+ Storage::Cuda { .. } => todo!(),
+ }
+ }
+
pub fn dtype(&self) -> DType {
self.storage.dtype()
}
@@ -311,6 +411,31 @@ impl Tensor {
self.id
}
+ pub fn t(&self) -> Result<Tensor> {
+ let mut stride = self.stride().to_vec();
+ let mut shape = self.shape().clone();
+ let n = stride.len();
+ if n < 2 {
+ return Err(Error::UnexpectedNumberOfDims {
+ expected: 2,
+ got: n,
+ shape: self.shape().clone(),
+ });
+ }
+ (shape.0[n - 2], shape.0[n - 1]) = (shape.0[n - 1], shape.0[n - 2]);
+ (stride[n - 2], stride[n - 1]) = (stride[n - 1], stride[n - 2]);
+ let tensor_ = Tensor_ {
+ id: TensorId::new(),
+ storage: self.storage.clone(),
+ shape,
+ stride,
+ // TODO The op should have a backward
+ op: None,
+ is_variable: false,
+ };
+ Ok(Tensor(Arc::new(tensor_)))
+ }
+
pub fn is_contiguous(&self) -> bool {
self.shape.is_contiguous(&self.stride)
}
@@ -340,7 +465,8 @@ impl Tensor {
Op::Add(lhs, rhs)
| Op::Mul(lhs, rhs)
| Op::Sub(lhs, rhs)
- | Op::Div(lhs, rhs) => {
+ | Op::Div(lhs, rhs)
+ | Op::Matmul(lhs, rhs) => {
let (tg, nodes) = walk(lhs, nodes, already_seen);
track_grad |= tg;
let (tg, nodes) = walk(rhs, nodes, already_seen);
@@ -420,6 +546,18 @@ impl Tensor {
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
}
+ Op::Matmul(lhs, rhs) => {
+ // Skipping checks, the op went ok, we can skip
+ // the matmul size checks for now.
+
+ let lhs_grad = grad.matmul(&rhs.t()?)?;
+ let lhs_sum_grad = grads.or_insert(lhs)?;
+ *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
+
+ let rhs_grad = lhs.t()?.matmul(&grad)?;
+ let rhs_sum_grad = grads.or_insert(rhs)?;
+ *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+ }
Op::Affine { arg, mul, .. } => {
let arg_grad = grad.affine(*mul, 0.)?;
let sum_grad = grads.or_insert(arg)?;