summaryrefslogtreecommitdiff
path: root/src/tensor.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/tensor.rs')
-rw-r--r--src/tensor.rs6
1 files changed, 5 insertions, 1 deletions
diff --git a/src/tensor.rs b/src/tensor.rs
index 571b0399..40b72c00 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -151,7 +151,11 @@ impl Tensor {
device: &Device,
is_variable: bool,
) -> Result<Self> {
- // let shape = array.shape()?;
+ let n: usize = shape.0.iter().product();
+ let buffer_size: usize = array.shape()?.0.iter().product();
+ if buffer_size != n {
+ return Err(Error::ShapeMismatch { buffer_size, shape });
+ }
let storage = device.storage(array)?;
let stride = shape.stride_contiguous();
let tensor_ = Tensor_ {