summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorNicolas Patry <patry.nicolas@protonmail.com>2023-06-22 12:39:33 +0200
committerNicolas Patry <patry.nicolas@protonmail.com>2023-06-22 12:39:33 +0200
commita8b6c848e010c6cea8710a03b479eb7458d82b52 (patch)
tree116657fb140e7b996ffaad616bca180910b337da /src
parent04cf14f35ae9773d9600ed98c39ada56c726338f (diff)
downloadcandle-a8b6c848e010c6cea8710a03b479eb7458d82b52.tar.gz
candle-a8b6c848e010c6cea8710a03b479eb7458d82b52.tar.bz2
candle-a8b6c848e010c6cea8710a03b479eb7458d82b52.zip
Final updates.
Diffstat (limited to 'src')
-rw-r--r--src/cpu_backend.rs2
-rw-r--r--src/tensor.rs10
2 files changed, 4 insertions, 8 deletions
diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs
index 0eb4270a..2c708389 100644
--- a/src/cpu_backend.rs
+++ b/src/cpu_backend.rs
@@ -128,7 +128,7 @@ impl CpuStorage {
let lhs_batch_stride = &lhs_stride[..rank - 2];
let rhs_batch_stride = &rhs_stride[..rank - 2];
- if lhs_batch_stride != &[a_skip] || rhs_batch_stride != &[b_skip] {
+ if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] {
// Temporary error before we support abitrary striding.
return Err(Error::UnexpectedStriding);
}
diff --git a/src/tensor.rs b/src/tensor.rs
index 7274c557..571b0399 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -166,12 +166,12 @@ impl Tensor {
}
pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
- let shape = array.shape()?.clone();
+ let shape = array.shape()?;
Self::new_impl(array, shape, device, false)
}
pub fn var<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
- let shape = array.shape()?.clone();
+ let shape = array.shape()?;
Self::new_impl(array, shape, device, true)
}
@@ -259,11 +259,7 @@ impl Tensor {
let dim = a_dims.len();
- // TODO
- // if dim < 2 {
- // return Err(SmeltError::InsufficientRank { minimum_rank: 2 });
- // }
- if b_dims.len() != dim {
+ if dim < 2 || b_dims.len() != dim {
return Err(Error::ShapeMismatchBinaryOp {
lhs: self.shape().clone(),
rhs: rhs.shape().clone(),