summaryrefslogtreecommitdiff
path: root/src/tensor.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/tensor.rs')
-rw-r--r--src/tensor.rs10
1 files changed, 3 insertions, 7 deletions
diff --git a/src/tensor.rs b/src/tensor.rs
index 7274c557..571b0399 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -166,12 +166,12 @@ impl Tensor {
}
pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
- let shape = array.shape()?.clone();
+ let shape = array.shape()?;
Self::new_impl(array, shape, device, false)
}
pub fn var<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
- let shape = array.shape()?.clone();
+ let shape = array.shape()?;
Self::new_impl(array, shape, device, true)
}
@@ -259,11 +259,7 @@ impl Tensor {
let dim = a_dims.len();
- // TODO
- // if dim < 2 {
- // return Err(SmeltError::InsufficientRank { minimum_rank: 2 });
- // }
- if b_dims.len() != dim {
+ if dim < 2 || b_dims.len() != dim {
return Err(Error::ShapeMismatchBinaryOp {
lhs: self.shape().clone(),
rhs: rhs.shape().clone(),