diff options
Diffstat (limited to 'src/tensor.rs')
-rw-r--r-- | src/tensor.rs | 110 |
1 files changed, 61 insertions, 49 deletions
diff --git a/src/tensor.rs b/src/tensor.rs index 7607171c..7274c557 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -147,10 +147,11 @@ impl Tensor { pub fn new_impl<A: crate::device::NdArray>( array: A, + shape: Shape, device: &Device, is_variable: bool, ) -> Result<Self> { - let shape = array.shape()?; + // let shape = array.shape()?; let storage = device.storage(array)?; let stride = shape.stride_contiguous(); let tensor_ = Tensor_ { @@ -165,31 +166,29 @@ impl Tensor { } pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> { - Self::new_impl(array, device, false) + let shape = array.shape()?.clone(); + Self::new_impl(array, shape, device, false) } pub fn var<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> { - Self::new_impl(array, device, true) + let shape = array.shape()?.clone(); + Self::new_impl(array, shape, device, true) } pub fn from_slice<S: Into<Shape>, D: crate::WithDType>( - a: &[D], + array: &[D], shape: S, - device: Device, + device: &Device, ) -> Result<Self> { - let shape = shape.into(); - let storage = device.storage(a)?; - let stride = shape.stride_contiguous(); - let is_variable = false; - let tensor_ = Tensor_ { - id: TensorId::new(), - storage, - shape, - stride, - op: None, - is_variable, - }; - Ok(Self(Arc::new(tensor_))) + Self::new_impl(array, shape.into(), device, false) + } + + pub fn var_from_slice<S: Into<Shape>, D: crate::WithDType>( + array: &[D], + shape: S, + device: &Device, + ) -> Result<Self> { + Self::new_impl(array, shape.into(), device, true) } pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> { @@ -260,6 +259,7 @@ impl Tensor { let dim = a_dims.len(); + // TODO // if dim < 2 { // return Err(SmeltError::InsufficientRank { minimum_rank: 2 }); // } @@ -309,6 +309,13 @@ impl Tensor { crate::StridedIndex::new(self.dims(), self.stride()) } + pub fn as_slice<S: crate::WithDType>(&self) -> Result<&[S]> { + match &self.storage { + Storage::Cpu(cpu_storage) => S::cpu_storage_as_slice(cpu_storage), + Storage::Cuda { .. } => todo!(), + } + } + pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> { if self.rank() != 1 { return Err(Error::UnexpectedNumberOfDims { @@ -404,6 +411,31 @@ impl Tensor { self.id } + pub fn t(&self) -> Result<Tensor> { + let mut stride = self.stride().to_vec(); + let mut shape = self.shape().clone(); + let n = stride.len(); + if n < 2 { + return Err(Error::UnexpectedNumberOfDims { + expected: 2, + got: n, + shape: self.shape().clone(), + }); + } + (shape.0[n - 2], shape.0[n - 1]) = (shape.0[n - 1], shape.0[n - 2]); + (stride[n - 2], stride[n - 1]) = (stride[n - 1], stride[n - 2]); + let tensor_ = Tensor_ { + id: TensorId::new(), + storage: self.storage.clone(), + shape, + stride, + // TODO The op should have a backward + op: None, + is_variable: false, + }; + Ok(Tensor(Arc::new(tensor_))) + } + pub fn is_contiguous(&self) -> bool { self.shape.is_contiguous(&self.stride) } @@ -514,37 +546,17 @@ impl Tensor { let rhs_sum_grad = grads.or_insert(rhs)?; *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?; } - Op::Matmul(_lhs, _rhs) => { - // let (m, k) = lhs.shape; - // let n = rhs.shape.1; - // let strides = (m, n).strides(); - // Self::matmul( - // (m, n, k), - // true, - // grad_out.as_ptr(), - // strides, - // rhs.data.as_ptr(), - // [rhs.strides[1], rhs.strides[0]], - // grad_lhs.as_mut_ptr(), - // lhs.strides, - // ); - // Self::matmul( - // (k, m, n), - // true, - // lhs.data.as_ptr(), - // [lhs.strides[1], lhs.strides[0]], - // grad_out.as_ptr(), - // strides, - // grad_rhs.as_mut_ptr(), - // rhs.strides, - // ); - - // let lhs_grad = grad.matmul(rhs)?; - // let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like()); - // *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?; - // let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?; - // let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like()); - // *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?; + Op::Matmul(lhs, rhs) => { + // Skipping checks, the op went ok, we can skip + // the matmul size checks for now. + + let lhs_grad = grad.matmul(&rhs.t()?)?; + let lhs_sum_grad = grads.or_insert(lhs)?; + *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?; + + let rhs_grad = lhs.t()?.matmul(&grad)?; + let rhs_sum_grad = grads.or_insert(rhs)?; + *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?; } Op::Affine { arg, mul, .. } => { let arg_grad = grad.affine(*mul, 0.)?; |