author     Nicolas Patry <patry.nicolas@protonmail.com>   2023-06-22 13:08:57 +0200
committer  Nicolas Patry <patry.nicolas@protonmail.com>   2023-06-22 13:08:57 +0200
commit     449af49b5404b96ae19e6926921571c459abb183 (patch)
tree       393fee16e4ba51cfd694b2435d44781a775b4b02 /src
parent     a8b6c848e010c6cea8710a03b479eb7458d82b52 (diff)
Adding size checking when creating a tensor from buffer + shape.
Diffstat (limited to 'src')
 src/error.rs  | 6 ++++++
 src/tensor.rs | 6 +++++-
 2 files changed, 11 insertions(+), 1 deletion(-)
diff --git a/src/error.rs b/src/error.rs
index 6f40622c..723edaa1 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -12,6 +12,11 @@ pub enum Error {
     #[error("the candle crate has not been built with cuda support")]
     NotCompiledWithCudaSupport,
 
+    #[error(
+        "Shape mismatch, got buffer of size {buffer_size} which is compatible with shape {shape:?}"
+    )]
+    ShapeMismatch { buffer_size: usize, shape: Shape },
+
     #[error("shape mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
     ShapeMismatchBinaryOp {
         lhs: Shape,
@@ -40,6 +45,7 @@ pub enum Error {
         shape: Shape,
     },
 
+    // TODO this is temporary when we support arbitrary matmul
     #[error("temporary error where matmul doesn't support arbitrary striding")]
     UnexpectedStriding,
 
diff --git a/src/tensor.rs b/src/tensor.rs
index 571b0399..40b72c00 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -151,7 +151,11 @@ impl Tensor {
         device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
-        // let shape = array.shape()?;
+        let n: usize = shape.0.iter().product();
+        let buffer_size: usize = array.shape()?.0.iter().product();
+        if buffer_size != n {
+            return Err(Error::ShapeMismatch { buffer_size, shape });
+        }
        let storage = device.storage(array)?;
        let stride = shape.stride_contiguous();
        let tensor_ = Tensor_ {
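
The added check boils down to comparing the element count implied by the requested shape with the element count of the provided buffer. A minimal standalone sketch of that logic (illustrative only; the function name, error type, and messages below are hypothetical, not candle's actual API):

// Standalone illustration of the size check introduced in this commit
// (hypothetical names, not candle's real types).
fn check_buffer_against_shape(buffer_len: usize, dims: &[usize]) -> Result<(), String> {
    // Number of elements a tensor with dimensions `dims` would hold.
    let expected: usize = dims.iter().product();
    if buffer_len != expected {
        // Plays the role of Error::ShapeMismatch { buffer_size, shape } in the diff above.
        return Err(format!(
            "shape mismatch: buffer of size {buffer_len} vs shape {dims:?} ({expected} elements)"
        ));
    }
    Ok(())
}

fn main() {
    // A buffer of 6 values is consistent with a 2x3 shape...
    assert!(check_buffer_against_shape(6, &[2, 3]).is_ok());
    // ...but not with 2x4, which the constructor would previously have accepted silently.
    assert!(check_buffer_against_shape(6, &[2, 4]).is_err());
}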