summaryrefslogtreecommitdiff
path: root/candle-core/src/op.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-10-04 17:58:44 +0100
committerGitHub <noreply@github.com>2023-10-04 17:58:44 +0100
commitc18a856e76cad9626406c3c483a53fb5b7eeef7b (patch)
tree67c71e73d59dd5ab506d98c134492e08bd9e5e68 /candle-core/src/op.rs
parent3349c892523426a00e16dd094837f5d786754ce1 (diff)
downloadcandle-c18a856e76cad9626406c3c483a53fb5b7eeef7b.tar.gz
candle-c18a856e76cad9626406c3c483a53fb5b7eeef7b.tar.bz2
candle-c18a856e76cad9626406c3c483a53fb5b7eeef7b.zip
Add the rounding operators. (#1030)
* Add the rounding operators. * Avoid tracking gradients for the rounding operations. * Add some rounding tests.
Diffstat (limited to 'candle-core/src/op.rs')
-rw-r--r--candle-core/src/op.rs108
1 files changed, 108 insertions, 0 deletions
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 3083d2c8..b7f99f11 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -62,6 +62,9 @@ pub enum UnaryOp {
Erf,
Relu,
Tanh,
+ Floor,
+ Ceil,
+ Round,
}
#[derive(Clone)]
@@ -332,6 +335,9 @@ pub(crate) struct GeluErf;
pub(crate) struct Erf;
pub(crate) struct Relu;
pub(crate) struct Tanh;
+pub(crate) struct Floor;
+pub(crate) struct Ceil;
+pub(crate) struct Round;
macro_rules! bin_op {
($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
@@ -660,6 +666,108 @@ impl UnaryOpT for Erf {
}
}
+impl UnaryOpT for Ceil {
+ const NAME: &'static str = "ceil";
+ const KERNEL: &'static str = "uceil";
+ const V: Self = Ceil;
+ #[inline(always)]
+ fn bf16(v: bf16) -> bf16 {
+ v.ceil()
+ }
+ #[inline(always)]
+ fn f16(v: f16) -> f16 {
+ v.ceil()
+ }
+ #[inline(always)]
+ fn f32(v: f32) -> f32 {
+ v.ceil()
+ }
+ #[inline(always)]
+ fn f64(v: f64) -> f64 {
+ v.ceil()
+ }
+ #[inline(always)]
+ fn u8(v: u8) -> u8 {
+ v
+ }
+ #[inline(always)]
+ fn u32(v: u32) -> u32 {
+ v
+ }
+ #[inline(always)]
+ fn i64(v: i64) -> i64 {
+ v
+ }
+}
+
+impl UnaryOpT for Floor {
+ const NAME: &'static str = "floor";
+ const KERNEL: &'static str = "ufloor";
+ const V: Self = Floor;
+ #[inline(always)]
+ fn bf16(v: bf16) -> bf16 {
+ v.floor()
+ }
+ #[inline(always)]
+ fn f16(v: f16) -> f16 {
+ v.floor()
+ }
+ #[inline(always)]
+ fn f32(v: f32) -> f32 {
+ v.floor()
+ }
+ #[inline(always)]
+ fn f64(v: f64) -> f64 {
+ v.floor()
+ }
+ #[inline(always)]
+ fn u8(v: u8) -> u8 {
+ v
+ }
+ #[inline(always)]
+ fn u32(v: u32) -> u32 {
+ v
+ }
+ #[inline(always)]
+ fn i64(v: i64) -> i64 {
+ v
+ }
+}
+
+impl UnaryOpT for Round {
+ const NAME: &'static str = "round";
+ const KERNEL: &'static str = "uround";
+ const V: Self = Round;
+ #[inline(always)]
+ fn bf16(v: bf16) -> bf16 {
+ v.round()
+ }
+ #[inline(always)]
+ fn f16(v: f16) -> f16 {
+ v.round()
+ }
+ #[inline(always)]
+ fn f32(v: f32) -> f32 {
+ v.round()
+ }
+ #[inline(always)]
+ fn f64(v: f64) -> f64 {
+ v.round()
+ }
+ #[inline(always)]
+ fn u8(v: u8) -> u8 {
+ v
+ }
+ #[inline(always)]
+ fn u32(v: u32) -> u32 {
+ v
+ }
+ #[inline(always)]
+ fn i64(v: i64) -> i64 {
+ v
+ }
+}
+
impl UnaryOpT for GeluErf {
const NAME: &'static str = "gelu_erf";
const KERNEL: &'static str = "ugelu_erf";