summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/op.rs35
-rw-r--r--candle-core/tests/tensor_tests.rs8
2 files changed, 42 insertions, 1 deletions
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index b7f99f11..e1168c2e 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -536,7 +536,6 @@ unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh);
-unary_op!(Abs, "abs", v, v.abs());
unary_op!(Neg, "neg", v, -v);
unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
@@ -666,6 +665,40 @@ impl UnaryOpT for Erf {
}
}
+impl UnaryOpT for Abs {
+ const NAME: &'static str = "abs";
+ const KERNEL: &'static str = "uabs";
+ const V: Self = Abs;
+ #[inline(always)]
+ fn bf16(v: bf16) -> bf16 {
+ v.abs()
+ }
+ #[inline(always)]
+ fn f16(v: f16) -> f16 {
+ v.abs()
+ }
+ #[inline(always)]
+ fn f32(v: f32) -> f32 {
+ v.abs()
+ }
+ #[inline(always)]
+ fn f64(v: f64) -> f64 {
+ v.abs()
+ }
+ #[inline(always)]
+ fn u8(v: u8) -> u8 {
+ v
+ }
+ #[inline(always)]
+ fn u32(v: u32) -> u32 {
+ v
+ }
+ #[inline(always)]
+ fn i64(v: i64) -> i64 {
+ v.abs()
+ }
+}
+
impl UnaryOpT for Ceil {
const NAME: &'static str = "ceil";
const KERNEL: &'static str = "uceil";
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index ae1bd058..899efcf3 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -1089,3 +1089,11 @@ fn pad_with_same() -> Result<()> {
);
Ok(())
}
+
+#[test]
+fn i64_abs() -> Result<()> {
+ let t = Tensor::new(&[-42i64, 1337], &Device::Cpu)?;
+ let t = t.abs()?;
+ assert_eq!(t.to_vec1::<i64>()?, [42, 1337]);
+ Ok(())
+}