summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/backend.rs5
-rw-r--r--candle-core/src/backprop.rs15
-rw-r--r--candle-core/src/cpu_backend.rs61
-rw-r--r--candle-core/src/cuda_backend.rs20
-rw-r--r--candle-core/src/dummy_cuda_backend.rs7
-rw-r--r--candle-core/src/op.rs29
-rw-r--r--candle-core/src/storage.rs37
-rw-r--r--candle-core/src/tensor.rs45
8 files changed, 178 insertions, 41 deletions
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index 018279b3..307868dd 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -1,3 +1,4 @@
+use crate::op::{CmpOp, ReduceOp};
use crate::{CpuStorage, DType, Layout, Result, Shape};
pub(crate) trait BackendStorage: Sized {
@@ -16,7 +17,9 @@ pub(crate) trait BackendStorage: Sized {
fn elu(&self, _: &Layout, _: f64) -> Result<Self>;
- fn reduce_op(&self, _: crate::op::ReduceOp, _: &Layout, _: &[usize]) -> Result<Self>;
+ fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self>;
+
+ fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self>;
fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()>;
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 3de11d35..4d968e7f 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -1,4 +1,5 @@
-use crate::{op::Op, Error, Result, Tensor, TensorId};
+use crate::op::{Op, ReduceOp};
+use crate::{Error, Result, Tensor, TensorId};
use std::collections::HashMap;
impl Tensor {
@@ -66,9 +67,8 @@ impl Tensor {
}
Op::Reshape(node)
| Op::Broadcast(node)
- | Op::Sum(node, _)
- | Op::Max(node, _)
- | Op::Min(node, _)
+ | Op::Cmp(node, _)
+ | Op::Reduce(node, _, _)
| Op::ToDType(node)
| Op::ToDevice(node)
| Op::Transpose(node, _, _)
@@ -201,14 +201,15 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.broadcast_add(&arg_grad)?
}
- Op::Sum(arg, _sum_dims) => {
+ Op::Reduce(arg, ReduceOp::Sum, _) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.broadcast_add(&grad)?
}
- Op::Max(_args, _sum_dims) => {
+ Op::Cmp(_args, _) => return Err(Error::BackwardNotSupported { op: "cmp" }),
+ Op::Reduce(_args, ReduceOp::Max, _) => {
return Err(Error::BackwardNotSupported { op: "max" })
}
- Op::Min(_args, _sum_dims) => {
+ Op::Reduce(_args, ReduceOp::Min, _) => {
return Err(Error::BackwardNotSupported { op: "min" })
}
Op::ToDType(arg) => {
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 925ca112..b12e0702 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1,5 +1,5 @@
use crate::backend::{BackendDevice, BackendStorage};
-use crate::op::{BinaryOp, ReduceOp, UnaryOp};
+use crate::op::{BinaryOp, CmpOp, ReduceOp, UnaryOp};
use crate::{DType, Error, Layout, Result, Shape, WithDType};
use half::{bf16, f16};
@@ -62,6 +62,57 @@ trait Map2 {
}
}
+trait Map2U8 {
+ const OP: &'static str;
+ fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
+
+ fn map(
+ &self,
+ v1: &CpuStorage,
+ l1: &Layout,
+ v2: &CpuStorage,
+ l2: &Layout,
+ ) -> Result<CpuStorage> {
+ match (v1, v2) {
+ (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+ (C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+ (C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+ (C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+ (C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+ (C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+ _ => Err(Error::DTypeMismatchBinaryOp {
+ lhs: v1.dtype(),
+ rhs: v2.dtype(),
+ op: Self::OP,
+ }
+ .bt()),
+ }
+ }
+}
+
+struct Cmp(CmpOp);
+impl Map2U8 for Cmp {
+ const OP: &'static str = "cmp";
+ #[inline(always)]
+ fn f<T: WithDType>(
+ &self,
+ lhs: &[T],
+ lhs_l: &Layout,
+ rhs: &[T],
+ rhs_l: &Layout,
+ ) -> Result<Vec<u8>> {
+ let dst = match self.0 {
+ CmpOp::Eq => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x == y)),
+ CmpOp::Ne => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x != y)),
+ CmpOp::Lt => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x < y)),
+ CmpOp::Le => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x <= y)),
+ CmpOp::Gt => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x > y)),
+ CmpOp::Ge => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x >= y)),
+ };
+ Ok(dst)
+ }
+}
+
struct WCond<'a>(&'a [u32], &'a Layout);
impl<'a> Map2 for WCond<'a> {
@@ -269,13 +320,13 @@ fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
}
// This function maps over two strided index sequences.
-fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
+fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
lhs_l: &Layout,
rhs_l: &Layout,
lhs: &[T],
rhs: &[T],
mut f: F,
-) -> Vec<T> {
+) -> Vec<U> {
match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
.iter()
@@ -1064,6 +1115,10 @@ impl BackendStorage for CpuStorage {
.map(self, layout)
}
+ fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
+ Cmp(op).map(self, lhs_l, rhs, rhs_l)
+ }
+
fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
// [self] stores data in a contiguous way starting at offset 0.
match self {
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index b74137f3..9e47c133 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1,4 +1,5 @@
use crate::backend::{BackendDevice, BackendStorage};
+use crate::op::{CmpOp, ReduceOp};
use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
use candle_kernels as kernels;
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
@@ -515,7 +516,7 @@ impl<'a> Map1 for Sum<'a> {
}
}
-struct FastReduce<'a>(&'a [usize], crate::op::ReduceOp);
+struct FastReduce<'a>(&'a [usize], ReduceOp);
impl<'a> Map1 for FastReduce<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
@@ -558,9 +559,9 @@ impl<'a> Map1 for FastReduce<'a> {
.w()?;
let src = &src.slice(layout.start_offset()..);
let name = match self.1 {
- crate::op::ReduceOp::Sum => "fast_sum",
- crate::op::ReduceOp::Min => "fast_min",
- crate::op::ReduceOp::Max => "fast_max",
+ ReduceOp::Sum => "fast_sum",
+ ReduceOp::Min => "fast_min",
+ ReduceOp::Max => "fast_max",
};
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::REDUCE)?;
// SAFETY: filled in by the follow up kernel.
@@ -961,17 +962,16 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}
- fn reduce_op(
- &self,
- op: crate::op::ReduceOp,
- layout: &Layout,
- sum_dims: &[usize],
- ) -> Result<Self> {
+ fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
let device = self.device().clone();
let slice = FastReduce(sum_dims, op).map(&self.slice, &device, layout)?;
Ok(Self { slice, device })
}
+ fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
+ Err(CudaError::InternalError("TODO: implement cmp").into())
+ }
+
fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
Err(CudaError::InternalError("TODO: implement divide_by_sum_over_dim").into())
}
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index f7cf8ab8..942e82ed 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -1,4 +1,5 @@
#![allow(dead_code)]
+use crate::op::{CmpOp, ReduceOp};
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};
#[derive(Debug, Clone)]
@@ -40,7 +41,11 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
- fn reduce_op(&self, _: crate::op::ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
+ fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
+ Err(Error::NotCompiledWithCudaSupport)
+ }
+
+ fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index c5ff8179..ece6969c 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -2,12 +2,31 @@ use crate::Tensor;
use half::{bf16, f16};
use num_traits::float::Float;
+#[derive(Clone, Copy, PartialEq, Eq)]
+pub enum CmpOp {
+ Eq,
+ Ne,
+ Le,
+ Ge,
+ Lt,
+ Gt,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum ReduceOp {
+ Sum,
+ Min,
+ Max,
+}
+
#[derive(Clone)]
pub(crate) enum Op {
Add(Tensor, Tensor),
Mul(Tensor, Tensor),
Sub(Tensor, Tensor),
Div(Tensor, Tensor),
+ Cmp(Tensor, CmpOp),
+ Reduce(Tensor, ReduceOp, Vec<usize>),
Matmul(Tensor, Tensor),
Embedding(Tensor, Tensor),
WhereCond(Tensor, Tensor, Tensor),
@@ -28,9 +47,6 @@ pub(crate) enum Op {
mul: f64,
add: f64,
},
- Sum(Tensor, Vec<usize>),
- Max(Tensor, Vec<usize>),
- Min(Tensor, Vec<usize>),
ToDType(Tensor),
Broadcast(Tensor),
Exp(Tensor),
@@ -356,10 +372,3 @@ impl UnaryOp for Relu {
v
}
}
-
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
-pub enum ReduceOp {
- Sum,
- Min,
- Max,
-}
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index e689905e..fb72322c 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -1,5 +1,6 @@
use crate::backend::BackendStorage;
-use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
+use crate::op::{self, CmpOp, ReduceOp};
+use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
// We do not want to implement Clone on Storage as cloning may fail because of
// out of memory. Instead try_clone should be used.
@@ -80,12 +81,38 @@ impl Storage {
}
}
- pub(crate) fn reduce_op(
+ pub(crate) fn cmp(
&self,
- op: crate::op::ReduceOp,
- layout: &Layout,
- s: &[usize],
+ op: CmpOp,
+ rhs: &Self,
+ lhs_layout: &Layout,
+ rhs_layout: &Layout,
) -> Result<Self> {
+ self.same_device(rhs, "cmp")?;
+ self.same_dtype(rhs, "cmp")?;
+ match (self, rhs) {
+ (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
+ let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
+ Ok(Self::Cpu(storage))
+ }
+ (Self::Cuda(lhs), Self::Cuda(rhs)) => {
+ let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
+ Ok(Self::Cuda(storage))
+ }
+ (lhs, rhs) => {
+ // Should not happen because of the same device check above but we're defensive
+ // anyway.
+ Err(Error::DeviceMismatchBinaryOp {
+ lhs: lhs.device().location(),
+ rhs: rhs.device().location(),
+ op: "cmp",
+ }
+ .bt())
+ }
+ }
+ }
+
+ pub(crate) fn reduce_op(&self, op: ReduceOp, layout: &Layout, s: &[usize]) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
let storage = storage.reduce_op(op, layout, s)?;
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 276a522e..d6c3e9cb 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1,5 +1,5 @@
use crate::backend::{BackendDevice, BackendStorage};
-use crate::op::{Op, ReduceOp};
+use crate::op::{CmpOp, Op, ReduceOp};
use crate::shape::{Dim, Dims};
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock};
@@ -634,7 +634,7 @@ impl Tensor {
.storage()
.reduce_op(ReduceOp::Max, self.layout(), &max_dims)?;
let op = if self.track_op() {
- Some(Op::Max(self.clone(), max_dims.to_vec()))
+ Some(Op::Reduce(self.clone(), ReduceOp::Max, max_dims.to_vec()))
} else {
None
};
@@ -656,7 +656,7 @@ impl Tensor {
.storage()
.reduce_op(ReduceOp::Min, self.layout(), &min_dims)?;
let op = if self.track_op() {
- Some(Op::Min(self.clone(), min_dims.to_vec()))
+ Some(Op::Reduce(self.clone(), ReduceOp::Min, min_dims.to_vec()))
} else {
None
};
@@ -678,7 +678,7 @@ impl Tensor {
.storage()
.reduce_op(ReduceOp::Sum, self.layout(), &sum_dims)?;
let op = if self.track_op() {
- Some(Op::Sum(self.clone(), sum_dims.to_vec()))
+ Some(Op::Reduce(self.clone(), ReduceOp::Sum, sum_dims.to_vec()))
} else {
None
};
@@ -748,6 +748,43 @@ impl Tensor {
self.min(dims)
}
+ pub fn cmp(&self, rhs: &Self, op: CmpOp) -> Result<Self> {
+ let shape = self.same_shape_binary_op(rhs, "cmp")?;
+ let storage = self
+ .storage()
+ .cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
+ let op = if self.track_op() {
+ Some(Op::Cmp(self.clone(), op))
+ } else {
+ None
+ };
+ Ok(from_storage(storage, shape.dims(), op, false))
+ }
+
+ pub fn eq(&self, rhs: &Self) -> Result<Self> {
+ self.cmp(rhs, CmpOp::Eq)
+ }
+
+ pub fn ne(&self, rhs: &Self) -> Result<Self> {
+ self.cmp(rhs, CmpOp::Ne)
+ }
+
+ pub fn lt(&self, rhs: &Self) -> Result<Self> {
+ self.cmp(rhs, CmpOp::Lt)
+ }
+
+ pub fn gt(&self, rhs: &Self) -> Result<Self> {
+ self.cmp(rhs, CmpOp::Gt)
+ }
+
+ pub fn ge(&self, rhs: &Self) -> Result<Self> {
+ self.cmp(rhs, CmpOp::Ge)
+ }
+
+ pub fn le(&self, rhs: &Self) -> Result<Self> {
+ self.cmp(rhs, CmpOp::Le)
+ }
+
/// Applies a 1D convolution over the input tensor.
pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
let (c_out, c_in_k, k_size) = kernel.shape().r3()?;