diff options
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r-- | candle-core/src/tensor.rs | 76 |
1 files changed, 28 insertions, 48 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index f72404df..42d660f4 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -628,47 +628,21 @@ impl Tensor { } } - fn max_impl<D: Dims>(&self, max_dims: D, keepdim: bool) -> Result<Self> { - let max_dims = max_dims.to_indexes(self.shape(), "max")?; - let storage = self - .storage() - .reduce_op(ReduceOp::Max, self.layout(), &max_dims)?; - let mut dims = self.dims().to_vec(); - for &max_dim in max_dims.iter() { - dims[max_dim] = 1 - } - let op = if self.track_op() { - Some(Op::Reduce(self.clone(), ReduceOp::Max, dims.to_vec())) - } else { - None - }; - let max = from_storage(storage, dims, op, false); - if keepdim { - Ok(max) - } else { - max.squeeze_dims(&max_dims) - } - } - - fn min_impl<D: Dims>(&self, min_dims: D, keepdim: bool) -> Result<Self> { - let min_dims = min_dims.to_indexes(self.shape(), "min")?; - let storage = self - .storage() - .reduce_op(ReduceOp::Min, self.layout(), &min_dims)?; + fn reduce_impl<D: Dim>(&self, dim: D, keepdim: bool, op: ReduceOp) -> Result<Self> { + let dim = dim.to_index(self.shape(), op.name())?; + let storage = self.storage().reduce_op(op, self.layout(), &[dim])?; let mut dims = self.dims().to_vec(); - for &min_dim in min_dims.iter() { - dims[min_dim] = 1 - } + dims[dim] = 1; let op = if self.track_op() { - Some(Op::Reduce(self.clone(), ReduceOp::Min, dims.to_vec())) + Some(Op::Reduce(self.clone(), op, dims.to_vec())) } else { None }; - let min = from_storage(storage, dims, op, false); + let res = from_storage(storage, dims, op, false); if keepdim { - Ok(min) + Ok(res) } else { - min.squeeze_dims(&min_dims) + res.squeeze_dims(&[dim]) } } @@ -722,30 +696,36 @@ impl Tensor { self.sum_impl(sum_dims, false) } - pub fn max_keepdim<D: Dims>(&self, max_dims: D) -> Result<Self> { - self.max_impl(max_dims, true) + pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> { + self.reduce_impl(dim, true, ReduceOp::Max) } - pub fn max<D: Dims>(&self, max_dims: D) -> Result<Self> { - self.max_impl(max_dims, false) + pub fn max<D: Dim>(&self, dim: D) -> Result<Self> { + self.reduce_impl(dim, false, ReduceOp::Max) } - pub fn max_all(&self) -> Result<Tensor> { - let dims: Vec<_> = (0..self.rank()).collect(); - self.max(dims) + pub fn min_keepdim<D: Dim>(&self, dim: D) -> Result<Self> { + self.reduce_impl(dim, true, ReduceOp::Min) } - pub fn min_keepdim<D: Dims>(&self, min_dims: D) -> Result<Self> { - self.min_impl(min_dims, true) + pub fn min<D: Dim>(&self, dim: D) -> Result<Self> { + self.reduce_impl(dim, false, ReduceOp::Min) } - pub fn min<D: Dims>(&self, min_dims: D) -> Result<Self> { - self.min_impl(min_dims, false) + pub fn argmax_keepdim<D: Dim>(&self, dim: D) -> Result<Self> { + self.reduce_impl(dim, true, ReduceOp::ArgMax) } - pub fn min_all(&self) -> Result<Tensor> { - let dims: Vec<_> = (0..self.rank()).collect(); - self.min(dims) + pub fn argmax<D: Dim>(&self, dim: D) -> Result<Self> { + self.reduce_impl(dim, false, ReduceOp::ArgMax) + } + + pub fn argmin_keepdim<D: Dim>(&self, dim: D) -> Result<Self> { + self.reduce_impl(dim, true, ReduceOp::ArgMin) + } + + pub fn argmin<D: Dim>(&self, dim: D) -> Result<Self> { + self.reduce_impl(dim, false, ReduceOp::ArgMin) } pub fn cmp(&self, rhs: &Self, op: CmpOp) -> Result<Self> { |