diff options
Diffstat (limited to 'candle-core/src/layout.rs')
-rw-r--r-- | candle-core/src/layout.rs | 20 |
1 files changed, 13 insertions, 7 deletions
diff --git a/candle-core/src/layout.rs b/candle-core/src/layout.rs index 3f629d50..79d40cfc 100644 --- a/candle-core/src/layout.rs +++ b/candle-core/src/layout.rs @@ -60,20 +60,26 @@ impl Layout { self.shape.is_fortran_contiguous(&self.stride) } - pub(crate) fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> { + pub(crate) fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> { let dims = self.shape().dims(); if dim >= dims.len() { - Err(Error::UnexpectedNumberOfDims { - expected: dim + 1, - got: dims.len(), + Err(Error::DimOutOfRange { shape: self.shape().clone(), + dim: dim as i32, + op: "narrow", })? } - if start + length > dims[dim] { - todo!("add a proper error: out of bounds for narrow {dim} {start} {length} {dims:?}") + if start + len > dims[dim] { + Err(Error::NarrowInvalidArgs { + shape: self.shape.clone(), + dim, + start, + len, + msg: "start + len > dim_len", + })? } let mut dims = dims.to_vec(); - dims[dim] = length; + dims[dim] = len; Ok(Self { shape: Shape::from(dims), stride: self.stride.clone(), |