diff options
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r-- | candle-core/src/tensor.rs | 41 |
1 files changed, 27 insertions, 14 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 6586834c..2f05094b 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -349,21 +349,34 @@ impl Tensor { } /// Returns a new tensor that is a narrowed version of the input, the dimension `dim` - /// ranges from `start` to `start + length`. - pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> { - let op = if self.track_op() { - Some(Op::Narrow(self.clone(), dim, start, length)) + /// ranges from `start` to `start + len`. + pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> { + let dims = self.dims(); + if dim >= dims.len() || start + len > dims[dim] { + Err(Error::NarrowInvalidArgs { + shape: self.shape().clone(), + dim, + start, + len, + })? + } + if start == 0 && dims[dim] == len { + Ok(self.clone()) } else { - None - }; - let tensor_ = Tensor_ { - id: TensorId::new(), - storage: self.storage.clone(), - layout: self.layout().narrow(dim, start, length)?, - op, - is_variable: false, - }; - Ok(Tensor(Arc::new(tensor_))) + let op = if self.track_op() { + Some(Op::Narrow(self.clone(), dim, start, len)) + } else { + None + }; + let tensor_ = Tensor_ { + id: TensorId::new(), + storage: self.storage.clone(), + layout: self.layout().narrow(dim, start, len)?, + op, + is_variable: false, + }; + Ok(Tensor(Arc::new(tensor_))) + } } pub fn softmax(&self, dim: usize) -> Result<Self> { |