diff options
Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/error.rs | 8 | ||||
-rw-r--r-- | candle-core/src/tensor.rs | 41 |
2 files changed, 35 insertions, 14 deletions
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs index 83d3e66d..637fd8b7 100644 --- a/candle-core/src/error.rs +++ b/candle-core/src/error.rs @@ -10,6 +10,14 @@ pub enum Error { got: DType, }, + #[error("invalid args for narrow: {shape:?}, dim: {dim}, start: {start}, len:{len}")] + NarrowInvalidArgs { + shape: Shape, + dim: usize, + start: usize, + len: usize, + }, + #[error("{op} only supports contiguous tensors")] RequiresContiguous { op: &'static str }, diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 6586834c..2f05094b 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -349,21 +349,34 @@ impl Tensor { } /// Returns a new tensor that is a narrowed version of the input, the dimension `dim` - /// ranges from `start` to `start + length`. - pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> { - let op = if self.track_op() { - Some(Op::Narrow(self.clone(), dim, start, length)) + /// ranges from `start` to `start + len`. + pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> { + let dims = self.dims(); + if dim >= dims.len() || start + len > dims[dim] { + Err(Error::NarrowInvalidArgs { + shape: self.shape().clone(), + dim, + start, + len, + })? + } + if start == 0 && dims[dim] == len { + Ok(self.clone()) } else { - None - }; - let tensor_ = Tensor_ { - id: TensorId::new(), - storage: self.storage.clone(), - layout: self.layout().narrow(dim, start, length)?, - op, - is_variable: false, - }; - Ok(Tensor(Arc::new(tensor_))) + let op = if self.track_op() { + Some(Op::Narrow(self.clone(), dim, start, len)) + } else { + None + }; + let tensor_ = Tensor_ { + id: TensorId::new(), + storage: self.storage.clone(), + layout: self.layout().narrow(dim, start, len)?, + op, + is_variable: false, + }; + Ok(Tensor(Arc::new(tensor_))) + } } pub fn softmax(&self, dim: usize) -> Result<Self> { |