 src/error.rs  | 6 ++++++
 src/tensor.rs | 6 +++++-
 2 files changed, 11 insertions(+), 1 deletion(-)
diff --git a/src/error.rs b/src/error.rs
index 6f40622c..723edaa1 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -12,6 +12,11 @@ pub enum Error {
#[error("the candle crate has not been built with cuda support")]
NotCompiledWithCudaSupport,
+ #[error(
+ "Shape mismatch, got buffer of size {buffer_size} which is compatible with shape {shape:?}"
+ )]
+ ShapeMismatch { buffer_size: usize, shape: Shape },
+
#[error("shape mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
ShapeMismatchBinaryOp {
lhs: Shape,
@@ -40,6 +45,7 @@ pub enum Error {
         shape: Shape,
     },
 
+    // TODO: remove this variant once matmul supports arbitrary striding.
     #[error("temporary error where matmul doesn't support arbitrary striding")]
     UnexpectedStriding,
 
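For reference, here is a minimal sketch of how the new variant renders once the #[error(...)] attribute expands. It assumes the derive used in error.rs comes from the thiserror crate and substitutes a plain Vec<usize> for candle's Shape type, both purely for illustration:

```rust
// Standalone sketch, not candle's actual error.rs: a Vec<usize> stands in for
// Shape so the rendered message can be shown in isolation.
use thiserror::Error;

#[derive(Debug, Error)]
enum Error {
    #[error("shape mismatch, got buffer of size {buffer_size} which is incompatible with shape {shape:?}")]
    ShapeMismatch { buffer_size: usize, shape: Vec<usize> },
}

fn main() {
    let err = Error::ShapeMismatch { buffer_size: 6, shape: vec![2, 4] };
    // Prints: shape mismatch, got buffer of size 6 which is incompatible with shape [2, 4]
    println!("{err}");
}
```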
diff --git a/src/tensor.rs b/src/tensor.rs
index 571b0399..40b72c00 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -151,7 +151,11 @@ impl Tensor {
         device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
-        // let shape = array.shape()?;
+        let n: usize = shape.0.iter().product();
+        let buffer_size: usize = array.shape()?.0.iter().product();
+        if buffer_size != n {
+            return Err(Error::ShapeMismatch { buffer_size, shape });
+        }
         let storage = device.storage(array)?;
         let stride = shape.stride_contiguous();
         let tensor_ = Tensor_ {
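Taken together, the two hunks make the internal constructor reject a flat buffer whose element count does not match the requested shape before any storage is allocated (the removed commented-out array.shape()? line suggests the check was planned but previously unenforced). Below is a self-contained sketch of that check, with Vec<usize> standing in for candle's Shape and a free function standing in for the constructor path, both assumptions for illustration:

```rust
// Sketch of the validation the tensor.rs hunk introduces: the element count
// implied by the target shape must equal the length of the source buffer.
#[derive(Debug)]
struct ShapeMismatch {
    buffer_size: usize,
    shape: Vec<usize>,
}

fn check_buffer_matches_shape(buffer_len: usize, shape: &[usize]) -> Result<(), ShapeMismatch> {
    // Number of elements a contiguous tensor of `shape` holds: the product of its dims.
    let n: usize = shape.iter().product();
    if buffer_len != n {
        return Err(ShapeMismatch {
            buffer_size: buffer_len,
            shape: shape.to_vec(),
        });
    }
    Ok(())
}

fn main() {
    // 2 * 3 = 6 elements: the buffer and shape agree, so the check passes.
    assert!(check_buffer_matches_shape(6, &[2, 3]).is_ok());

    // 2 * 4 = 8 elements expected but only 6 supplied: rejected before any allocation.
    match check_buffer_matches_shape(6, &[2, 4]) {
        Err(e) => println!("rejected: buffer of {} elements vs shape {:?}", e.buffer_size, e.shape),
        Ok(()) => unreachable!("a 6-element buffer cannot fill a 2x4 shape"),
    }
}
```

Validating the length up front means a mismatched from_vec-style call (the name is illustrative) fails with a descriptive ShapeMismatch instead of building storage whose size disagrees with the contiguous strides derived from the shape.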