summaryrefslogtreecommitdiff
path: root/candle-core/src
diff options
context:
space:
mode:
authorThomas Santerre <thomas@santerre.xyz>2024-04-04 16:32:47 -0400
committerGitHub <noreply@github.com>2024-04-04 22:32:47 +0200
commitc5626b827147e5029c6bd3e37352ec8ac501cfc3 (patch)
treecabc1c5d8713c52159e503b87277410d1b776d9a /candle-core/src
parente6a5b82ba6507e7e21d5a5d45241bd8f005609b7 (diff)
downloadcandle-c5626b827147e5029c6bd3e37352ec8ac501cfc3.tar.gz
candle-c5626b827147e5029c6bd3e37352ec8ac501cfc3.tar.bz2
candle-c5626b827147e5029c6bd3e37352ec8ac501cfc3.zip
Add support for "sign" on tensors (#2012)
* add the sign unary operator * remove uneeded import * remove uneeded import * undo formatting * undo formatting * remove unnecessary redefintion * allow gradient to flow through for sign and round * fix cpu ops to ensure that negzero and positive zero are handled properly * clippy fixes * Properly avoid gradient tracking. * Use a branchless version. --------- Co-authored-by: laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-core/src')
-rw-r--r--candle-core/src/backprop.rs18
-rw-r--r--candle-core/src/metal_backend/mod.rs4
-rw-r--r--candle-core/src/op.rs36
-rw-r--r--candle-core/src/tensor.rs1
4 files changed, 49 insertions, 10 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index f39eedbb..65d91849 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -112,7 +112,8 @@ impl Tensor {
}
Op::Unary(_node, UnaryOp::Ceil)
| Op::Unary(_node, UnaryOp::Floor)
- | Op::Unary(_node, UnaryOp::Round) => nodes,
+ | Op::Unary(_node, UnaryOp::Round)
+ | Op::Unary(_node, UnaryOp::Sign) => nodes,
Op::Reshape(node)
| Op::UpsampleNearest1D { arg: node, .. }
| Op::UpsampleNearest2D { arg: node, .. }
@@ -488,7 +489,6 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad)?;
}
- Op::Cmp(_args, _) => {}
Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
let node = broadcast_back(arg, node, reduced_dims)?;
let grad = broadcast_back(arg, &grad, reduced_dims)?;
@@ -578,20 +578,18 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
- Op::Reduce(_, ReduceOp::ArgMin, _) => {}
- Op::Reduce(_, ReduceOp::ArgMax, _) => {}
+ Op::Unary(_, UnaryOp::Floor)
+ | Op::Unary(_, UnaryOp::Round)
+ | Op::Reduce(_, ReduceOp::ArgMin, _)
+ | Op::Reduce(_, ReduceOp::ArgMax, _)
+ | Op::Unary(_, UnaryOp::Sign)
+ | Op::Cmp(_, _) => {}
Op::Reshape(arg) => {
let arg_grad = grad.reshape(arg.dims())?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
- Op::Unary(_, UnaryOp::Floor) => {
- Err(Error::BackwardNotSupported { op: "floor" })?
- }
- Op::Unary(_, UnaryOp::Round) => {
- Err(Error::BackwardNotSupported { op: "round" })?
- }
Op::Unary(arg, UnaryOp::Gelu) => {
let sum_grad = grads.or_insert(arg)?;
let cube = arg.powf(3.)?;
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs
index fa6973b4..0e058b45 100644
--- a/candle-core/src/metal_backend/mod.rs
+++ b/candle-core/src/metal_backend/mod.rs
@@ -497,6 +497,10 @@ impl BackendStorage for MetalStorage {
("utanh", DType::F16) => contiguous::tanh::HALF,
("utanh", DType::F32) => contiguous::tanh::FLOAT,
("utanh", DType::BF16) => contiguous::tanh::BFLOAT,
+ ("usign", DType::F16) => contiguous::sign::HALF,
+ ("usign", DType::F32) => contiguous::sign::FLOAT,
+ ("usign", DType::BF16) => contiguous::sign::BFLOAT,
+ ("usign", DType::I64) => contiguous::sign::I64,
(name, dtype) => {
crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented")
}
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 776f5182..49ba44be 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -66,6 +66,7 @@ pub enum UnaryOp {
Floor,
Ceil,
Round,
+ Sign,
}
#[derive(Clone)]
@@ -254,6 +255,7 @@ pub(crate) struct Tanh;
pub(crate) struct Floor;
pub(crate) struct Ceil;
pub(crate) struct Round;
+pub(crate) struct Sign;
macro_rules! bin_op {
($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
@@ -925,3 +927,37 @@ impl std::ops::Deref for BackpropOp {
&self.0
}
}
+
+impl UnaryOpT for Sign {
+ const NAME: &'static str = "sign";
+ const KERNEL: &'static str = "usign";
+ const V: Self = Sign;
+ #[inline(always)]
+ fn bf16(v: bf16) -> bf16 {
+ bf16::from((v > bf16::ZERO) as i8) - bf16::from((v < bf16::ZERO) as i8)
+ }
+ #[inline(always)]
+ fn f16(v: f16) -> f16 {
+ f16::from((v > f16::ZERO) as i8) - f16::from((v < f16::ZERO) as i8)
+ }
+ #[inline(always)]
+ fn f32(v: f32) -> f32 {
+ f32::from(v > 0.) - f32::from(v < 0.)
+ }
+ #[inline(always)]
+ fn f64(v: f64) -> f64 {
+ f64::from(v > 0.) - f64::from(v < 0.)
+ }
+ #[inline(always)]
+ fn u8(v: u8) -> u8 {
+ u8::min(1, v)
+ }
+ #[inline(always)]
+ fn u32(v: u32) -> u32 {
+ u32::min(1, v)
+ }
+ #[inline(always)]
+ fn i64(v: i64) -> i64 {
+ (v > 0) as i64 - (v < 0) as i64
+ }
+}
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index b53b0419..a5a9dbb1 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -510,6 +510,7 @@ impl Tensor {
unary_op!(ceil, Ceil);
unary_op!(floor, Floor);
unary_op!(round, Round);
+ unary_op!(sign, Sign);
/// Round element of the input tensor to the nearest integer.
///