summaryrefslogtreecommitdiff
path: root/candle-core/src/layout.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/layout.rs')
-rw-r--r--candle-core/src/layout.rs20
1 files changed, 13 insertions, 7 deletions
diff --git a/candle-core/src/layout.rs b/candle-core/src/layout.rs
index 3f629d50..79d40cfc 100644
--- a/candle-core/src/layout.rs
+++ b/candle-core/src/layout.rs
@@ -60,20 +60,26 @@ impl Layout {
self.shape.is_fortran_contiguous(&self.stride)
}
- pub(crate) fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
+ pub(crate) fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
let dims = self.shape().dims();
if dim >= dims.len() {
- Err(Error::UnexpectedNumberOfDims {
- expected: dim + 1,
- got: dims.len(),
+ Err(Error::DimOutOfRange {
shape: self.shape().clone(),
+ dim: dim as i32,
+ op: "narrow",
})?
}
- if start + length > dims[dim] {
- todo!("add a proper error: out of bounds for narrow {dim} {start} {length} {dims:?}")
+ if start + len > dims[dim] {
+ Err(Error::NarrowInvalidArgs {
+ shape: self.shape.clone(),
+ dim,
+ start,
+ len,
+ msg: "start + len > dim_len",
+ })?
}
let mut dims = dims.to_vec();
- dims[dim] = length;
+ dims[dim] = len;
Ok(Self {
shape: Shape::from(dims),
stride: self.stride.clone(),