summaryrefslogtreecommitdiff
path: root/candle-core/src/tensor.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r--candle-core/src/tensor.rs76
1 files changed, 28 insertions, 48 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index f72404df..42d660f4 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -628,47 +628,21 @@ impl Tensor {
}
}
- fn max_impl<D: Dims>(&self, max_dims: D, keepdim: bool) -> Result<Self> {
- let max_dims = max_dims.to_indexes(self.shape(), "max")?;
- let storage = self
- .storage()
- .reduce_op(ReduceOp::Max, self.layout(), &max_dims)?;
- let mut dims = self.dims().to_vec();
- for &max_dim in max_dims.iter() {
- dims[max_dim] = 1
- }
- let op = if self.track_op() {
- Some(Op::Reduce(self.clone(), ReduceOp::Max, dims.to_vec()))
- } else {
- None
- };
- let max = from_storage(storage, dims, op, false);
- if keepdim {
- Ok(max)
- } else {
- max.squeeze_dims(&max_dims)
- }
- }
-
- fn min_impl<D: Dims>(&self, min_dims: D, keepdim: bool) -> Result<Self> {
- let min_dims = min_dims.to_indexes(self.shape(), "min")?;
- let storage = self
- .storage()
- .reduce_op(ReduceOp::Min, self.layout(), &min_dims)?;
+ fn reduce_impl<D: Dim>(&self, dim: D, keepdim: bool, op: ReduceOp) -> Result<Self> {
+ let dim = dim.to_index(self.shape(), op.name())?;
+ let storage = self.storage().reduce_op(op, self.layout(), &[dim])?;
let mut dims = self.dims().to_vec();
- for &min_dim in min_dims.iter() {
- dims[min_dim] = 1
- }
+ dims[dim] = 1;
let op = if self.track_op() {
- Some(Op::Reduce(self.clone(), ReduceOp::Min, dims.to_vec()))
+ Some(Op::Reduce(self.clone(), op, dims.to_vec()))
} else {
None
};
- let min = from_storage(storage, dims, op, false);
+ let res = from_storage(storage, dims, op, false);
if keepdim {
- Ok(min)
+ Ok(res)
} else {
- min.squeeze_dims(&min_dims)
+ res.squeeze_dims(&[dim])
}
}
@@ -722,30 +696,36 @@ impl Tensor {
self.sum_impl(sum_dims, false)
}
- pub fn max_keepdim<D: Dims>(&self, max_dims: D) -> Result<Self> {
- self.max_impl(max_dims, true)
+ pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
+ self.reduce_impl(dim, true, ReduceOp::Max)
}
- pub fn max<D: Dims>(&self, max_dims: D) -> Result<Self> {
- self.max_impl(max_dims, false)
+ pub fn max<D: Dim>(&self, dim: D) -> Result<Self> {
+ self.reduce_impl(dim, false, ReduceOp::Max)
}
- pub fn max_all(&self) -> Result<Tensor> {
- let dims: Vec<_> = (0..self.rank()).collect();
- self.max(dims)
+ pub fn min_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
+ self.reduce_impl(dim, true, ReduceOp::Min)
}
- pub fn min_keepdim<D: Dims>(&self, min_dims: D) -> Result<Self> {
- self.min_impl(min_dims, true)
+ pub fn min<D: Dim>(&self, dim: D) -> Result<Self> {
+ self.reduce_impl(dim, false, ReduceOp::Min)
}
- pub fn min<D: Dims>(&self, min_dims: D) -> Result<Self> {
- self.min_impl(min_dims, false)
+ pub fn argmax_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
+ self.reduce_impl(dim, true, ReduceOp::ArgMax)
}
- pub fn min_all(&self) -> Result<Tensor> {
- let dims: Vec<_> = (0..self.rank()).collect();
- self.min(dims)
+ pub fn argmax<D: Dim>(&self, dim: D) -> Result<Self> {
+ self.reduce_impl(dim, false, ReduceOp::ArgMax)
+ }
+
+ pub fn argmin_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
+ self.reduce_impl(dim, true, ReduceOp::ArgMin)
+ }
+
+ pub fn argmin<D: Dim>(&self, dim: D) -> Result<Self> {
+ self.reduce_impl(dim, false, ReduceOp::ArgMin)
}
pub fn cmp(&self, rhs: &Self, op: CmpOp) -> Result<Self> {