Diffstat (limited to 'src/tensor.rs')
-rw-r--r-- | src/tensor.rs | 148
1 file changed, 143 insertions, 5 deletions
diff --git a/src/tensor.rs b/src/tensor.rs
index e8e01d5c..09e5d66c 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -147,11 +147,16 @@ impl Tensor {
     pub fn new_impl<A: crate::device::NdArray>(
         array: A,
+        shape: Shape,
         device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
-        let shape = array.shape()?;
-        let storage = device.tensor(array)?;
+        let n: usize = shape.elem_count();
+        let buffer_size: usize = array.shape()?.elem_count();
+        if buffer_size != n {
+            return Err(Error::ShapeMismatch { buffer_size, shape });
+        }
+        let storage = device.storage(array)?;
         let stride = shape.stride_contiguous();
         let tensor_ = Tensor_ {
             id: TensorId::new(),
@@ -165,11 +170,29 @@ impl Tensor {
     }
 
     pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
-        Self::new_impl(array, device, false)
+        let shape = array.shape()?;
+        Self::new_impl(array, shape, device, false)
     }
 
     pub fn var<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
-        Self::new_impl(array, device, true)
+        let shape = array.shape()?;
+        Self::new_impl(array, shape, device, true)
+    }
+
+    pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
+        array: &[D],
+        shape: S,
+        device: &Device,
+    ) -> Result<Self> {
+        Self::new_impl(array, shape.into(), device, false)
+    }
+
+    pub fn var_from_slice<S: Into<Shape>, D: crate::WithDType>(
+        array: &[D],
+        shape: S,
+        device: &Device,
+    ) -> Result<Self> {
+        Self::new_impl(array, shape.into(), device, true)
     }
 
     pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
@@ -234,10 +257,65 @@ impl Tensor {
         Ok(Self(Arc::new(tensor_)))
     }
 
+    pub fn matmul(&self, rhs: &Self) -> Result<Self> {
+        let a_dims = self.shape().dims();
+        let b_dims = rhs.shape().dims();
+
+        let dim = a_dims.len();
+
+        if dim < 2 || b_dims.len() != dim {
+            return Err(Error::ShapeMismatchBinaryOp {
+                lhs: self.shape().clone(),
+                rhs: rhs.shape().clone(),
+                op: "matmul",
+            });
+        }
+
+        let m = a_dims[dim - 2];
+        let k = a_dims[dim - 1];
+        let k2 = b_dims[dim - 2];
+        let n = b_dims[dim - 1];
+        if k != k2 {
+            return Err(Error::ShapeMismatchBinaryOp {
+                lhs: self.shape().clone(),
+                rhs: rhs.shape().clone(),
+                op: "matmul",
+            });
+        }
+
+        let mut c_shape: Vec<_> = a_dims[..dim - 2].into();
+        c_shape.extend(&[m, n]);
+        let c_shape = Shape(c_shape);
+        let batching: usize = a_dims[..dim - 2].iter().product();
+
+        let storage = self.storage.matmul_impl(
+            &rhs.storage,
+            (batching, m, n, k),
+            self.stride(),
+            rhs.stride(),
+        )?;
+        let tensor_ = Tensor_ {
+            id: TensorId::new(),
+            storage,
+            shape: c_shape.clone(),
+            stride: c_shape.stride_contiguous(),
+            op: Some(Op::Matmul(self.clone(), rhs.clone())),
+            is_variable: false,
+        };
+        Ok(Self(Arc::new(tensor_)))
+    }
+
     pub(crate) fn strided_index(&self) -> crate::StridedIndex {
         crate::StridedIndex::new(self.dims(), self.stride())
     }
 
+    pub fn as_slice<S: crate::WithDType>(&self) -> Result<&[S]> {
+        match &self.storage {
+            Storage::Cpu(cpu_storage) => S::cpu_storage_as_slice(cpu_storage),
+            Storage::Cuda { .. } => todo!(),
+        }
+    }
+
     pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
         if self.rank() != 1 {
             return Err(Error::UnexpectedNumberOfDims {
@@ -279,6 +357,28 @@ impl Tensor {
         }
     }
 
+    pub fn to_vec3<S: crate::WithDType>(&self) -> Result<Vec<Vec<Vec<S>>>> {
+        let (dim1, dim2, dim3) = self.shape().r3()?;
+        match &self.storage {
+            Storage::Cpu(cpu_storage) => {
+                let data = S::cpu_storage_as_slice(cpu_storage)?;
+                let mut top_rows = vec![];
+                let mut src_index = self.strided_index();
+                for _idx in 0..dim1 {
+                    let mut rows = vec![];
+                    for _jdx in 0..dim2 {
+                        let row = (0..dim3).map(|_| data[src_index.next().unwrap()]).collect();
+                        rows.push(row)
+                    }
+                    top_rows.push(rows);
+                }
+                assert!(src_index.next().is_none());
+                Ok(top_rows)
+            }
+            Storage::Cuda { .. } => todo!(),
+        }
+    }
+
     pub fn dtype(&self) -> DType {
         self.storage.dtype()
     }
@@ -311,6 +411,31 @@ impl Tensor {
         self.id
     }
 
+    pub fn t(&self) -> Result<Tensor> {
+        let mut stride = self.stride().to_vec();
+        let mut shape = self.shape().clone();
+        let n = stride.len();
+        if n < 2 {
+            return Err(Error::UnexpectedNumberOfDims {
+                expected: 2,
+                got: n,
+                shape: self.shape().clone(),
+            });
+        }
+        (shape.0[n - 2], shape.0[n - 1]) = (shape.0[n - 1], shape.0[n - 2]);
+        (stride[n - 2], stride[n - 1]) = (stride[n - 1], stride[n - 2]);
+        let tensor_ = Tensor_ {
+            id: TensorId::new(),
+            storage: self.storage.clone(),
+            shape,
+            stride,
+            // TODO: this op should have a backward.
+            op: None,
+            is_variable: false,
+        };
+        Ok(Tensor(Arc::new(tensor_)))
+    }
+
     pub fn is_contiguous(&self) -> bool {
         self.shape.is_contiguous(&self.stride)
     }
@@ -340,7 +465,8 @@ impl Tensor {
             Op::Add(lhs, rhs)
             | Op::Mul(lhs, rhs)
             | Op::Sub(lhs, rhs)
-            | Op::Div(lhs, rhs) => {
+            | Op::Div(lhs, rhs)
+            | Op::Matmul(lhs, rhs) => {
                 let (tg, nodes) = walk(lhs, nodes, already_seen);
                 track_grad |= tg;
                 let (tg, nodes) = walk(rhs, nodes, already_seen);
@@ -420,6 +546,18 @@ impl Tensor {
                     let rhs_sum_grad = grads.or_insert(rhs)?;
                     *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                 }
+                Op::Matmul(lhs, rhs) => {
+                    // The forward pass already validated the shapes, so the
+                    // matmul size checks can be skipped here.
+
+                    let lhs_grad = grad.matmul(&rhs.t()?)?;
+                    let lhs_sum_grad = grads.or_insert(lhs)?;
+                    *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
+
+                    let rhs_grad = lhs.t()?.matmul(&grad)?;
+                    let rhs_sum_grad = grads.or_insert(rhs)?;
+                    *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+                }
                 Op::Affine { arg, mul, .. } => {
                     let arg_grad = grad.affine(*mul, 0.)?;
                     let sum_grad = grads.or_insert(arg)?;
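
Note: the gradient rule implemented in the final hunk is the standard matrix-calculus identity. With C = A B and incoming gradient G = dL/dC (a math note for reference, not part of the diff):

    \frac{\partial L}{\partial A} = G \, B^{\top},
    \qquad
    \frac{\partial L}{\partial B} = A^{\top} G

which is exactly what grad.matmul(&rhs.t()?)? and lhs.t()?.matmul(&grad)? compute; accumulating into the gradient store via add covers tensors that feed more than one op.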
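As a quick sanity check of the new API surface, a minimal usage sketch follows. It is an assumption-laden illustration rather than code from this commit: it presumes Device::Cpu exists, that Tensor and Device (and a Result alias) are importable from the crate root, and that small tuples convert into Shape, as the S: Into<Shape> bounds above suggest.

    // Hypothetical in-crate test; the crate-root re-exports are assumptions.
    #[cfg(test)]
    mod matmul_smoke_test {
        use crate::{Device, Tensor};

        #[test]
        fn matmul_and_views() -> crate::Result<()> {
            let dev = Device::Cpu;

            // from_slice takes an explicit shape; the new buffer_size check
            // in new_impl rejects a slice whose length does not match it.
            let a = Tensor::from_slice(&[1f32, 2., 3., 4., 5., 6.], (2, 3), &dev)?;
            let b = Tensor::from_slice(&[1f32, 0., 0., 1., 1., 1.], (3, 2), &dev)?;

            // (2, 3) x (3, 2) -> (2, 2).
            let c = a.matmul(&b)?;
            assert_eq!(c.dims(), &[2, 2]);

            // t() swaps the two trailing dims by permuting shape and stride
            // only; the storage is shared, so no data is copied.
            let at = a.t()?;
            assert_eq!(at.dims(), &[3, 2]);

            // Rank-3 round trip through the new to_vec3.
            let cube = Tensor::from_slice(&[0f32, 1., 2., 3., 4., 5., 6., 7.], (2, 2, 2), &dev)?;
            let v: Vec<Vec<Vec<f32>>> = cube.to_vec3()?;
            assert_eq!(v[1][0][1], 5.);

            Ok(())
        }
    }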