diff options
Diffstat (limited to 'src/tensor.rs')
-rw-r--r-- | src/tensor.rs | 10 |
1 files changed, 3 insertions, 7 deletions
diff --git a/src/tensor.rs b/src/tensor.rs index 7274c557..571b0399 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -166,12 +166,12 @@ impl Tensor { } pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> { - let shape = array.shape()?.clone(); + let shape = array.shape()?; Self::new_impl(array, shape, device, false) } pub fn var<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> { - let shape = array.shape()?.clone(); + let shape = array.shape()?; Self::new_impl(array, shape, device, true) } @@ -259,11 +259,7 @@ impl Tensor { let dim = a_dims.len(); - // TODO - // if dim < 2 { - // return Err(SmeltError::InsufficientRank { minimum_rank: 2 }); - // } - if b_dims.len() != dim { + if dim < 2 || b_dims.len() != dim { return Err(Error::ShapeMismatchBinaryOp { lhs: self.shape().clone(), rhs: rhs.shape().clone(), |