author     Laurent Mazare <laurent.mazare@gmail.com>  2023-07-21 12:41:08 +0200
committer  GitHub <noreply@github.com>                2023-07-21 11:41:08 +0100
commit     410654525f36e95aebb52462c3ec9bb25826523c
tree       abaab4f031a70777ddda24d125160c7b6099d809
parent     c60831aad4f266a320f9854f4ebb3d2d4ab8bc66
Refactor the reduce ops in order to introduce argmin/argmax. (#212)
* Refactor the reduce ops in order to introduce argmin/argmax.
* Clippy fixes.
* Use the newly introduced argmax.
* Fix the strided case.
* Handle the non-contiguous case.
-rw-r--r--  candle-core/src/backprop.rs                          6
-rw-r--r--  candle-core/src/cpu_backend.rs                     221
-rw-r--r--  candle-core/src/cuda_backend.rs                      2
-rw-r--r--  candle-core/src/error.rs                             3
-rw-r--r--  candle-core/src/op.rs                               14
-rw-r--r--  candle-core/src/tensor.rs                           76
-rw-r--r--  candle-examples/examples/simple-training/main.rs    29
7 files changed, 241 insertions(+), 110 deletions(-)
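
At the API surface, the patch adds argmax/argmin (plus _keepdim variants) returning U32 index tensors, and narrows max/min from multi-dimension Dims to a single Dim. A minimal usage sketch, not part of the patch (it assumes the core crate is imported as candle_core; the in-tree examples alias it as candle):

    use candle_core::{Device, Tensor, D};

    fn main() -> candle_core::Result<()> {
        let t = Tensor::new(&[[1u32, 5, 3], [9, 2, 4]], &Device::Cpu)?;
        // argmax reduces a single dimension and yields U32 indices.
        let idx = t.argmax(D::Minus1)?;
        assert_eq!(idx.to_vec1::<u32>()?, [1, 0]);
        // The _keepdim variants keep the reduced dimension with size 1.
        assert_eq!(t.argmin_keepdim(D::Minus1)?.dims(), &[2, 1]);
        Ok(())
    }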
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 4afaf23b..62cbc488 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -304,6 +304,12 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
+ Op::Reduce(_, ReduceOp::ArgMin, _) => {
+ Err(Error::BackwardNotSupported { op: "argmin" })?
+ }
+ Op::Reduce(_, ReduceOp::ArgMax, _) => {
+ Err(Error::BackwardNotSupported { op: "argmax" })?
+ }
Op::Softmax(_arg, _) => Err(Error::BackwardNotSupported { op: "softmax" })?,
Op::Reshape(arg) => {
let arg_grad = grad.reshape(arg.dims())?;
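
Argmin/argmax are piecewise constant, so no useful gradient exists; these arms make backward() fail loudly rather than silently propagate zeros. A hedged sketch of the observable behaviour (it assumes Var::new accepts an array literal the way Tensor::new does; not code from this patch):

    use candle_core::{Device, Var};

    fn check() -> candle_core::Result<()> {
        let v = Var::new(&[1f32, 3., 2.], &Device::Cpu)?;
        let idx = v.argmax(0)?;
        // Err(Error::BackwardNotSupported { op: "argmax" })
        assert!(idx.backward().is_err());
        Ok(())
    }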
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index b7060f50..7901a7da 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -33,6 +33,26 @@ trait Map1 {
}
}
+trait Map1Any {
+ fn f<T: WithDType, W: Fn(Vec<T>) -> CpuStorage>(
+ &self,
+ vs: &[T],
+ layout: &Layout,
+ wrap: W,
+ ) -> Result<CpuStorage>;
+
+ fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
+ match vs {
+ CpuStorage::U8(vs) => Ok(self.f(vs, layout, CpuStorage::U8)?),
+ CpuStorage::U32(vs) => Ok(self.f(vs, layout, CpuStorage::U32)?),
+ CpuStorage::BF16(vs) => Ok(self.f(vs, layout, CpuStorage::BF16)?),
+ CpuStorage::F16(vs) => Ok(self.f(vs, layout, CpuStorage::F16)?),
+ CpuStorage::F32(vs) => Ok(self.f(vs, layout, CpuStorage::F32)?),
+ CpuStorage::F64(vs) => Ok(self.f(vs, layout, CpuStorage::F64)?),
+ }
+ }
+}
+
type C = CpuStorage;
trait Map2 {
const OP: &'static str;
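
Map1 fixes the output element type to the input type, which suits min/max but not argmin/argmax, whose result is always U32 whatever the input dtype. Map1Any relaxes this by handing the implementation a wrap constructor it can use to rebuild same-dtype storage, or ignore entirely. A standalone sketch of the idea (illustrative names, not candle code):

    enum Storage { F32(Vec<f32>), U32(Vec<u32>) }

    // `wrap` rebuilds storage of the input dtype; an index-returning op
    // bypasses it and always emits U32. Assumes vs is non-empty, just as
    // ReduceIndex checks elem_count before folding.
    fn max_or_argmax(vs: &[f32], wrap: impl Fn(Vec<f32>) -> Storage, return_index: bool) -> Storage {
        let (mut arg, mut best) = (0usize, vs[0]);
        for (i, &v) in vs.iter().enumerate() {
            if best < v { arg = i; best = v; }
        }
        if return_index { Storage::U32(vec![arg as u32]) } else { wrap(vec![best]) }
    }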
@@ -144,11 +164,118 @@ impl<'a> Map2 for WCond<'a> {
}
}
+struct ReduceIndex {
+ reduce_dim_index: usize,
+ use_min: bool,
+ return_index: bool,
+}
+
+impl ReduceIndex {
+ // The value gets replaced if f(s[current_acc], s[i]) returns true.
+ #[inline(always)]
+ fn fold_impl<T, U, F, G>(&self, src: &[T], src_l: &Layout, f: F, g: G) -> Result<Vec<U>>
+ where
+ T: Clone + Copy,
+ U: Clone + Copy,
+ F: Fn(T, T) -> bool,
+ G: Fn(T, usize) -> U,
+ {
+ let reduce_dim_size = src_l.dims()[self.reduce_dim_index];
+ let reduce_dim_stride = src_l.stride()[self.reduce_dim_index];
+ let dst_len = src_l.shape().elem_count() / reduce_dim_size;
+ let mut dst: Vec<U> = Vec::with_capacity(dst_len);
+ let dst_to_set = dst.spare_capacity_mut();
+ let dst_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(dst_to_set) };
+ match src_l.contiguous_offsets() {
+ Some((o1, o2)) => {
+ let src = &src[o1..o2];
+ if reduce_dim_stride == 1 {
+ for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() {
+ let start_src_i = start_src_i * reduce_dim_size;
+ let src = &src[start_src_i..start_src_i + reduce_dim_size];
+ let mut acc = 0;
+ let mut val = src[0];
+ for (src_i, &s) in src.iter().enumerate() {
+ if f(val, s) {
+ acc = src_i;
+ val = s
+ }
+ }
+ *dst_v = g(val, acc)
+ }
+ } else {
+ for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() {
+ let (p, q) = (
+ start_src_i / reduce_dim_stride,
+ start_src_i % reduce_dim_stride,
+ );
+ // start_src_i = p * reduce_dim_stride + q
+ let start_src_i = p * reduce_dim_stride * reduce_dim_size + q;
+ let src = &src[start_src_i..];
+ let mut acc = 0;
+ let mut val = src[0];
+ for src_i in 0..reduce_dim_size {
+ let s = src[src_i * reduce_dim_stride];
+ if f(val, s) {
+ acc = src_i;
+ val = s
+ }
+ }
+ *dst_v = g(val, acc)
+ }
+ }
+ }
+ None => {
+ let l = src_l.narrow(self.reduce_dim_index, 0, 1)?;
+ for (unstr_index, src_index) in l.strided_index().enumerate() {
+ let src = &src[src_index..];
+ let mut acc = 0;
+ let mut val = src[0];
+ for src_i in 0..reduce_dim_size {
+ let s = src[src_i * reduce_dim_stride];
+ if f(val, s) {
+ acc = src_i;
+ val = s
+ }
+ }
+                    dst_to_set[unstr_index] = g(val, acc)
+ }
+ }
+ }
+ unsafe { dst.set_len(dst_len) };
+ Ok(dst)
+ }
+}
+
+impl Map1Any for ReduceIndex {
+ #[inline(always)]
+ fn f<T: WithDType, W: Fn(Vec<T>) -> CpuStorage>(
+ &self,
+ src: &[T],
+ src_l: &Layout,
+ wrap: W,
+ ) -> Result<CpuStorage> {
+ if src_l.shape().elem_count() == 0 {
+ Err(Error::EmptyTensor { op: "reduce" }.bt())?
+ }
+ let dst = match (self.return_index, self.use_min) {
+ (false, true) => wrap(self.fold_impl(src, src_l, |x, y| x > y, |v, _i| v)?),
+ (false, false) => wrap(self.fold_impl(src, src_l, |x, y| x < y, |v, _i| v)?),
+ (true, true) => {
+ CpuStorage::U32(self.fold_impl(src, src_l, |x, y| x > y, |_v, i| i as u32)?)
+ }
+ (true, false) => {
+ CpuStorage::U32(self.fold_impl(src, src_l, |x, y| x < y, |_v, i| i as u32)?)
+ }
+ };
+ Ok(dst)
+ }
+}
+
struct Reduce<'a> {
dst_shape: &'a Shape,
reduce_dims: &'a [usize],
reduce_dims_and_stride: Vec<(usize, usize)>,
- op: ReduceOp,
}
impl<'a> Reduce<'a> {
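
fold_impl distinguishes three layouts: contiguous input with the reduced dimension innermost (stride 1, so each output folds a contiguous run), contiguous input with a strided reduced dimension, and fully non-contiguous input (driven by strided_index() over the layout narrowed to one slice of the reduced dim). In the strided case each destination index i is decomposed as p = i / stride, q = i % stride and remapped to start = p * stride * size + q, the first source offset of its reduction group. A standalone check of that arithmetic:

    // shape [2, 3], row-major strides [3, 1], reducing dim 0
    // (reduce_dim_size = 2, reduce_dim_stride = 3)
    let (size, stride) = (2usize, 3usize);
    for dst_i in 0..3 {
        let (p, q) = (dst_i / stride, dst_i % stride);
        let start = p * stride * size + q;
        let offsets: Vec<usize> = (0..size).map(|k| start + k * stride).collect();
        // dst 0 <- [0, 3], dst 1 <- [1, 4], dst 2 <- [2, 5]:
        // each output folds one column of the 2x3 matrix.
        println!("dst {dst_i} <- src offsets {offsets:?}");
    }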
@@ -217,25 +344,7 @@ impl<'a> Reduce<'a> {
impl<'a> Map1 for Reduce<'a> {
#[inline(always)]
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
- match self.op {
- ReduceOp::Min => {
- let s = if src_l.shape().elem_count() != 0 {
- src[src_l.start_offset()]
- } else {
- Err(Error::EmptyTensor { op: "min" }.bt())?
- };
- self.fold_impl(src, src_l, s, |x, y| if x < y { x } else { y })
- }
- ReduceOp::Max => {
- let s = if src_l.shape().elem_count() != 0 {
- src[src_l.start_offset()]
- } else {
- Err(Error::EmptyTensor { op: "max" }.bt())?
- };
- self.fold_impl(src, src_l, s, |x, y| if x > y { x } else { y })
- }
- ReduceOp::Sum => self.fold_impl(src, src_l, T::zero(), |x, y| x + y),
- }
+ self.fold_impl(src, src_l, T::zero(), |x, y| x + y)
}
}
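
With min and max gone, Reduce keeps only the sum fold, which is why the per-op EmptyTensor checks disappear here: a sum has an identity element (T::zero()), so an empty reduction is simply zero, whereas min/max had to seed the fold with an element read from the source. A minimal model of the difference:

    let xs: [f32; 0] = [];
    let sum = xs.iter().copied().fold(0.0f32, |x, y| x + y); // fine: 0.0
    let max = xs.iter().copied().reduce(f32::max);           // None: no seed exists
    assert_eq!(sum, 0.0);
    assert!(max.is_none());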
@@ -1144,27 +1253,59 @@ impl BackendStorage for CpuStorage {
}
fn reduce_op(&self, op: ReduceOp, layout: &Layout, reduce_dims: &[usize]) -> Result<Self> {
- let src_dims = layout.dims();
- let mut dst_dims = src_dims.to_vec();
- for &dim in reduce_dims.iter() {
- dst_dims[dim] = 1;
- }
- let dst_shape = Shape::from(dst_dims);
- let mut reduce_dims = reduce_dims.to_vec();
- // Sort the reduce_dims as they have to be processed from left to right when converting the
- // indexes.
- reduce_dims.sort();
- let reduce_dims_and_stride: Vec<_> = reduce_dims
- .iter()
- .map(|&d| (src_dims[d], src_dims[d + 1..].iter().product::<usize>()))
- .collect();
- Reduce {
- dst_shape: &dst_shape,
- reduce_dims: &reduce_dims,
- reduce_dims_and_stride,
- op,
+ match op {
+ ReduceOp::Sum => {
+ let src_dims = layout.dims();
+ let mut dst_dims = src_dims.to_vec();
+ for &dim in reduce_dims.iter() {
+ dst_dims[dim] = 1;
+ }
+ let dst_shape = Shape::from(dst_dims);
+ let mut reduce_dims = reduce_dims.to_vec();
+ // Sort the reduce_dims as they have to be processed from left to right when converting the
+ // indexes.
+ reduce_dims.sort();
+ let reduce_dims_and_stride: Vec<_> = reduce_dims
+ .iter()
+ .map(|&d| (src_dims[d], src_dims[d + 1..].iter().product::<usize>()))
+ .collect();
+ Reduce {
+ dst_shape: &dst_shape,
+ reduce_dims: &reduce_dims,
+ reduce_dims_and_stride,
+ }
+ .map(self, layout)
+ }
+ ReduceOp::Min | ReduceOp::ArgMin | ReduceOp::Max | ReduceOp::ArgMax => {
+ let reduce_dim_index = match reduce_dims {
+ [reduce_dim_index] => *reduce_dim_index,
+ _ => {
+ let op = match op {
+ ReduceOp::Min => "min",
+ ReduceOp::ArgMin => "argmin",
+ ReduceOp::Max => "max",
+ ReduceOp::ArgMax => "argmax",
+ _ => unreachable!(),
+ };
+ let dims = reduce_dims.to_vec();
+ Err(Error::OnlySingleDimension { op, dims })?
+ }
+ };
+ let (use_min, return_index) = match op {
+ ReduceOp::Min => (true, false),
+ ReduceOp::ArgMin => (true, true),
+ ReduceOp::Max => (false, false),
+ ReduceOp::ArgMax => (false, true),
+ _ => unreachable!(),
+ };
+ ReduceIndex {
+ reduce_dim_index,
+ use_min,
+ return_index,
+ }
+ .map(self, layout)
+ }
}
- .map(self, layout)
}
fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
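
The `[reduce_dim_index]` slice pattern enforces the new single-dimension contract at the backend boundary; the public API already guarantees it (argmax and friends take one Dim), so OnlySingleDimension mainly guards internal callers. The pattern in isolation:

    fn single_dim(dims: &[usize]) -> Result<usize, String> {
        match dims {
            [d] => Ok(*d), // matches exactly one element
            _ => Err(format!("expected a single dimension, got {dims:?}")),
        }
    }

    // single_dim(&[2]) == Ok(2); single_dim(&[0, 1]) is an error.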
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index e40f5f71..cdbfd0c6 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -562,6 +562,8 @@ impl<'a> Map1 for FastReduce<'a> {
ReduceOp::Sum => "fast_sum",
ReduceOp::Min => "fast_min",
ReduceOp::Max => "fast_max",
+ ReduceOp::ArgMin => "fast_argmin",
+ ReduceOp::ArgMax => "fast_argmax",
};
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::REDUCE)?;
// SAFETY: filled in by the follow up kernel.
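
Note that the diffstat lists no candle-kernels changes, so the fast_argmin/fast_argmax CUDA kernels referenced here must already exist in kernels::REDUCE for get_or_load_func to resolve them at run time; the CPU counterpart of that work is the ReduceIndex path above.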
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index acbe28d3..23f2642d 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -79,6 +79,9 @@ pub enum Error {
nth_shape: Shape,
},
+ #[error("{op} can only be performed on a single dimension")]
+ OnlySingleDimension { op: &'static str, dims: Vec<usize> },
+
#[error("empty tensor for {op}")]
EmptyTensor { op: &'static str },
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 4686e57e..226cff41 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -17,6 +17,20 @@ pub enum ReduceOp {
Sum,
Min,
Max,
+ ArgMin,
+ ArgMax,
+}
+
+impl ReduceOp {
+ pub(crate) fn name(&self) -> &'static str {
+ match self {
+ Self::ArgMax => "argmax",
+ Self::ArgMin => "argmin",
+ Self::Min => "min",
+ Self::Max => "max",
+ Self::Sum => "sum",
+ }
+ }
}
// These ops return the same type as their input type.
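
name() threads a stable label for each op into error reporting; reduce_impl in tensor.rs below resolves the dimension via dim.to_index(self.shape(), op.name())?, so a bad dim on argmax is reported as "argmax" rather than under a generic label. A trivial, crate-internal check of the mapping:

    assert_eq!(ReduceOp::ArgMax.name(), "argmax");
    assert_eq!(ReduceOp::Sum.name(), "sum");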
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index f72404df..42d660f4 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -628,47 +628,21 @@ impl Tensor {
}
}
- fn max_impl<D: Dims>(&self, max_dims: D, keepdim: bool) -> Result<Self> {
- let max_dims = max_dims.to_indexes(self.shape(), "max")?;
- let storage = self
- .storage()
- .reduce_op(ReduceOp::Max, self.layout(), &max_dims)?;
- let mut dims = self.dims().to_vec();
- for &max_dim in max_dims.iter() {
- dims[max_dim] = 1
- }
- let op = if self.track_op() {
- Some(Op::Reduce(self.clone(), ReduceOp::Max, dims.to_vec()))
- } else {
- None
- };
- let max = from_storage(storage, dims, op, false);
- if keepdim {
- Ok(max)
- } else {
- max.squeeze_dims(&max_dims)
- }
- }
-
- fn min_impl<D: Dims>(&self, min_dims: D, keepdim: bool) -> Result<Self> {
- let min_dims = min_dims.to_indexes(self.shape(), "min")?;
- let storage = self
- .storage()
- .reduce_op(ReduceOp::Min, self.layout(), &min_dims)?;
+ fn reduce_impl<D: Dim>(&self, dim: D, keepdim: bool, op: ReduceOp) -> Result<Self> {
+ let dim = dim.to_index(self.shape(), op.name())?;
+ let storage = self.storage().reduce_op(op, self.layout(), &[dim])?;
let mut dims = self.dims().to_vec();
- for &min_dim in min_dims.iter() {
- dims[min_dim] = 1
- }
+ dims[dim] = 1;
let op = if self.track_op() {
- Some(Op::Reduce(self.clone(), ReduceOp::Min, dims.to_vec()))
+ Some(Op::Reduce(self.clone(), op, dims.to_vec()))
} else {
None
};
- let min = from_storage(storage, dims, op, false);
+ let res = from_storage(storage, dims, op, false);
if keepdim {
- Ok(min)
+ Ok(res)
} else {
- min.squeeze_dims(&min_dims)
+ res.squeeze_dims(&[dim])
}
}
@@ -722,30 +696,36 @@ impl Tensor {
self.sum_impl(sum_dims, false)
}
- pub fn max_keepdim<D: Dims>(&self, max_dims: D) -> Result<Self> {
- self.max_impl(max_dims, true)
+ pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
+ self.reduce_impl(dim, true, ReduceOp::Max)
}
- pub fn max<D: Dims>(&self, max_dims: D) -> Result<Self> {
- self.max_impl(max_dims, false)
+ pub fn max<D: Dim>(&self, dim: D) -> Result<Self> {
+ self.reduce_impl(dim, false, ReduceOp::Max)
}
- pub fn max_all(&self) -> Result<Tensor> {
- let dims: Vec<_> = (0..self.rank()).collect();
- self.max(dims)
+ pub fn min_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
+ self.reduce_impl(dim, true, ReduceOp::Min)
}
- pub fn min_keepdim<D: Dims>(&self, min_dims: D) -> Result<Self> {
- self.min_impl(min_dims, true)
+ pub fn min<D: Dim>(&self, dim: D) -> Result<Self> {
+ self.reduce_impl(dim, false, ReduceOp::Min)
}
- pub fn min<D: Dims>(&self, min_dims: D) -> Result<Self> {
- self.min_impl(min_dims, false)
+ pub fn argmax_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
+ self.reduce_impl(dim, true, ReduceOp::ArgMax)
}
- pub fn min_all(&self) -> Result<Tensor> {
- let dims: Vec<_> = (0..self.rank()).collect();
- self.min(dims)
+ pub fn argmax<D: Dim>(&self, dim: D) -> Result<Self> {
+ self.reduce_impl(dim, false, ReduceOp::ArgMax)
+ }
+
+ pub fn argmin_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
+ self.reduce_impl(dim, true, ReduceOp::ArgMin)
+ }
+
+ pub fn argmin<D: Dim>(&self, dim: D) -> Result<Self> {
+ self.reduce_impl(dim, false, ReduceOp::ArgMin)
}
pub fn cmp(&self, rhs: &Self, op: CmpOp) -> Result<Self> {
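
Note the breaking API shift in this hunk: max/min now take a single Dim instead of Dims, and the whole-tensor helpers max_all/min_all are removed. A hedged sketch of recovering a global maximum after this change (it assumes reshape accepts a plain usize for a 1-D shape and that Tensor::elem_count exists, as elsewhere in candle):

    let flat = t.reshape(t.elem_count())?; // collapse to 1-D
    let global_max = flat.max(0)?;         // scalar-shaped result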
diff --git a/candle-examples/examples/simple-training/main.rs b/candle-examples/examples/simple-training/main.rs
index 767266f6..ea2dc0cd 100644
--- a/candle-examples/examples/simple-training/main.rs
+++ b/candle-examples/examples/simple-training/main.rs
@@ -42,7 +42,7 @@ pub fn main() -> Result<()> {
let bs = Var::zeros(LABELS, DType::F32, &dev)?;
let sgd = candle_nn::SGD::new(&[&ws, &bs], 1.0);
let test_images = m.test_images;
- let test_labels = m.test_labels.to_vec1::<u8>()?;
+ let test_labels = m.test_labels.to_dtype(DType::U32)?;
for epoch in 1..200 {
let logits = train_images.matmul(&ws)?.broadcast_add(&bs)?;
let log_sm = log_softmax(&logits, D::Minus1)?;
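
The label conversion above matters because argmax always returns U32 indices and eq compares tensors of matching dtype, so the u8 test labels are promoted to U32 once, outside the training loop. A compact standalone check of the pattern (CPU device and the candle_core import assumed):

    use candle_core::{DType, Device, Tensor, D};

    fn accuracy_demo() -> candle_core::Result<f32> {
        let logits = Tensor::new(&[[0.1f32, 0.9], [0.8, 0.2]], &Device::Cpu)?;
        let labels = Tensor::new(&[1u8, 0], &Device::Cpu)?.to_dtype(DType::U32)?;
        let sum_ok = logits
            .argmax(D::Minus1)? // U32 predictions
            .eq(&labels)?
            .to_dtype(DType::F32)?
            .sum_all()?
            .to_scalar::<f32>()?;
        Ok(sum_ok / 2.0) // 1.0 here: both rows predicted correctly
    }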
@@ -52,28 +52,13 @@ pub fn main() -> Result<()> {
sgd.backward_step(&loss)?;
let test_logits = test_images.matmul(&ws)?.broadcast_add(&bs)?;
- /* TODO: Add argmax so that the following can be computed within candle.
- let test_accuracy = test_logits
- .argmax(Some(-1), false)
- .eq_tensor(&test_labels)
- .to_kind(Kind::Float)
- .mean(Kind::Float)
- .double_value(&[]);
- */
- let test_logits = test_logits.to_vec2::<f32>()?;
let sum_ok = test_logits
- .iter()
- .zip(test_labels.iter())
- .map(|(logits, label)| {
- let arg_max = logits
- .iter()
- .enumerate()
- .max_by(|(_, v1), (_, v2)| v1.total_cmp(v2))
- .map(|(idx, _)| idx);
- f64::from(arg_max == Some(*label as usize))
- })
- .sum::<f64>();
- let test_accuracy = sum_ok / test_labels.len() as f64;
+ .argmax(D::Minus1)?
+ .eq(&test_labels)?
+ .to_dtype(DType::F32)?
+ .sum_all()?
+ .to_scalar::<f32>()?;
+ let test_accuracy = sum_ok / test_labels.shape().r1()? as f32;
println!(
"{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
loss.to_scalar::<f32>()?,