summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/cpu_backend.rs2
-rw-r--r--src/tensor.rs10
2 files changed, 4 insertions, 8 deletions
diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs
index 0eb4270a..2c708389 100644
--- a/src/cpu_backend.rs
+++ b/src/cpu_backend.rs
@@ -128,7 +128,7 @@ impl CpuStorage {
let lhs_batch_stride = &lhs_stride[..rank - 2];
let rhs_batch_stride = &rhs_stride[..rank - 2];
- if lhs_batch_stride != &[a_skip] || rhs_batch_stride != &[b_skip] {
+ if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] {
// Temporary error before we support arbitrary striding.
return Err(Error::UnexpectedStriding);
}
diff --git a/src/tensor.rs b/src/tensor.rs
index 7274c557..571b0399 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -166,12 +166,12 @@ impl Tensor {
}
pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
- let shape = array.shape()?.clone();
+ let shape = array.shape()?;
Self::new_impl(array, shape, device, false)
}
pub fn var<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
- let shape = array.shape()?.clone();
+ let shape = array.shape()?;
Self::new_impl(array, shape, device, true)
}
@@ -259,11 +259,7 @@ impl Tensor {
let dim = a_dims.len();
- // TODO
- // if dim < 2 {
- // return Err(SmeltError::InsufficientRank { minimum_rank: 2 });
- // }
- if b_dims.len() != dim {
+ if dim < 2 || b_dims.len() != dim {
return Err(Error::ShapeMismatchBinaryOp {
lhs: self.shape().clone(),
rhs: rhs.shape().clone(),