diff options
author | Thomas Santerre <thomas@santerre.xyz> | 2024-04-04 16:32:47 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-04 22:32:47 +0200 |
commit | c5626b827147e5029c6bd3e37352ec8ac501cfc3 (patch) | |
tree | cabc1c5d8713c52159e503b87277410d1b776d9a /candle-core/src | |
parent | e6a5b82ba6507e7e21d5a5d45241bd8f005609b7 (diff) | |
download | candle-c5626b827147e5029c6bd3e37352ec8ac501cfc3.tar.gz candle-c5626b827147e5029c6bd3e37352ec8ac501cfc3.tar.bz2 candle-c5626b827147e5029c6bd3e37352ec8ac501cfc3.zip |
Add support for "sign" on tensors (#2012)
* add the sign unary operator
* remove unneeded import
* remove unneeded import
* undo formatting
* undo formatting
* remove unnecessary redefinition
* allow gradient to flow through for sign and round
* fix cpu ops to ensure that negzero and positive zero are handled properly
* clippy fixes
* Properly avoid gradient tracking.
* Use a branchless version.
---------
Co-authored-by: laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/backprop.rs | 18 | ||||
-rw-r--r-- | candle-core/src/metal_backend/mod.rs | 4 | ||||
-rw-r--r-- | candle-core/src/op.rs | 36 | ||||
-rw-r--r-- | candle-core/src/tensor.rs | 1 |
4 files changed, 49 insertions, 10 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index f39eedbb..65d91849 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -112,7 +112,8 @@ impl Tensor { } Op::Unary(_node, UnaryOp::Ceil) | Op::Unary(_node, UnaryOp::Floor) - | Op::Unary(_node, UnaryOp::Round) => nodes, + | Op::Unary(_node, UnaryOp::Round) + | Op::Unary(_node, UnaryOp::Sign) => nodes, Op::Reshape(node) | Op::UpsampleNearest1D { arg: node, .. } | Op::UpsampleNearest2D { arg: node, .. } @@ -488,7 +489,6 @@ impl Tensor { let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.add(&grad)?; } - Op::Cmp(_args, _) => {} Op::Reduce(arg, ReduceOp::Max, reduced_dims) => { let node = broadcast_back(arg, node, reduced_dims)?; let grad = broadcast_back(arg, &grad, reduced_dims)?; @@ -578,20 +578,18 @@ impl Tensor { let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.add(&arg_grad)? } - Op::Reduce(_, ReduceOp::ArgMin, _) => {} - Op::Reduce(_, ReduceOp::ArgMax, _) => {} + Op::Unary(_, UnaryOp::Floor) + | Op::Unary(_, UnaryOp::Round) + | Op::Reduce(_, ReduceOp::ArgMin, _) + | Op::Reduce(_, ReduceOp::ArgMax, _) + | Op::Unary(_, UnaryOp::Sign) + | Op::Cmp(_, _) => {} Op::Reshape(arg) => { let arg_grad = grad.reshape(arg.dims())?; let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.add(&arg_grad)? } Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?, - Op::Unary(_, UnaryOp::Floor) => { - Err(Error::BackwardNotSupported { op: "floor" })? - } - Op::Unary(_, UnaryOp::Round) => { - Err(Error::BackwardNotSupported { op: "round" })? 
- } Op::Unary(arg, UnaryOp::Gelu) => { let sum_grad = grads.or_insert(arg)?; let cube = arg.powf(3.)?; diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index fa6973b4..0e058b45 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -497,6 +497,10 @@ impl BackendStorage for MetalStorage { ("utanh", DType::F16) => contiguous::tanh::HALF, ("utanh", DType::F32) => contiguous::tanh::FLOAT, ("utanh", DType::BF16) => contiguous::tanh::BFLOAT, + ("usign", DType::F16) => contiguous::sign::HALF, + ("usign", DType::F32) => contiguous::sign::FLOAT, + ("usign", DType::BF16) => contiguous::sign::BFLOAT, + ("usign", DType::I64) => contiguous::sign::I64, (name, dtype) => { crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented") } diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 776f5182..49ba44be 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -66,6 +66,7 @@ pub enum UnaryOp { Floor, Ceil, Round, + Sign, } #[derive(Clone)] @@ -254,6 +255,7 @@ pub(crate) struct Tanh; pub(crate) struct Floor; pub(crate) struct Ceil; pub(crate) struct Round; +pub(crate) struct Sign; macro_rules! bin_op { ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => { @@ -925,3 +927,37 @@ impl std::ops::Deref for BackpropOp { &self.0 } } + +impl UnaryOpT for Sign { + const NAME: &'static str = "sign"; + const KERNEL: &'static str = "usign"; + const V: Self = Sign; + #[inline(always)] + fn bf16(v: bf16) -> bf16 { + bf16::from((v > bf16::ZERO) as i8) - bf16::from((v < bf16::ZERO) as i8) + } + #[inline(always)] + fn f16(v: f16) -> f16 { + f16::from((v > f16::ZERO) as i8) - f16::from((v < f16::ZERO) as i8) + } + #[inline(always)] + fn f32(v: f32) -> f32 { + f32::from(v > 0.) - f32::from(v < 0.) + } + #[inline(always)] + fn f64(v: f64) -> f64 { + f64::from(v > 0.) - f64::from(v < 0.) 
+ } + #[inline(always)] + fn u8(v: u8) -> u8 { + u8::min(1, v) + } + #[inline(always)] + fn u32(v: u32) -> u32 { + u32::min(1, v) + } + #[inline(always)] + fn i64(v: i64) -> i64 { + (v > 0) as i64 - (v < 0) as i64 + } +} diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index b53b0419..a5a9dbb1 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -510,6 +510,7 @@ impl Tensor { unary_op!(ceil, Ceil); unary_op!(floor, Floor); unary_op!(round, Round); + unary_op!(sign, Sign); /// Round element of the input tensor to the nearest integer. /// |