diff options
Diffstat (limited to 'src/tensor.rs')
-rw-r--r-- | src/tensor.rs | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/src/tensor.rs b/src/tensor.rs index 571b0399..40b72c00 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -151,7 +151,11 @@ impl Tensor { device: &Device, is_variable: bool, ) -> Result<Self> { - // let shape = array.shape()?; + let n: usize = shape.0.iter().product(); + let buffer_size: usize = array.shape()?.0.iter().product(); + if buffer_size != n { + return Err(Error::ShapeMismatch { buffer_size, shape }); + } let storage = device.storage(array)?; let stride = shape.stride_contiguous(); let tensor_ = Tensor_ { |