summaryrefslogtreecommitdiff
path: root/src/tensor.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/tensor.rs')
-rw-r--r--src/tensor.rs6
1 files changed, 3 insertions, 3 deletions
diff --git a/src/tensor.rs b/src/tensor.rs
index 40b72c00..09e5d66c 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -151,8 +151,8 @@ impl Tensor {
device: &Device,
is_variable: bool,
) -> Result<Self> {
- let n: usize = shape.0.iter().product();
- let buffer_size: usize = array.shape()?.0.iter().product();
+ let n: usize = shape.elem_count();
+ let buffer_size: usize = array.shape()?.elem_count();
if buffer_size != n {
return Err(Error::ShapeMismatch { buffer_size, shape });
}
@@ -285,7 +285,7 @@ impl Tensor {
let mut c_shape: Vec<_> = a_dims[..dim - 2].into();
c_shape.extend(&[m, n]);
- let c_shape: Shape = Shape(c_shape);
+ let c_shape = Shape(c_shape);
let batching: usize = a_dims[..dim - 2].iter().product();
let storage = self.storage.matmul_impl(