diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-04 17:58:44 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-04 17:58:44 +0100 |
commit | c18a856e76cad9626406c3c483a53fb5b7eeef7b (patch) | |
tree | 67c71e73d59dd5ab506d98c134492e08bd9e5e68 /candle-core/src/op.rs | |
parent | 3349c892523426a00e16dd094837f5d786754ce1 (diff) | |
download | candle-c18a856e76cad9626406c3c483a53fb5b7eeef7b.tar.gz candle-c18a856e76cad9626406c3c483a53fb5b7eeef7b.tar.bz2 candle-c18a856e76cad9626406c3c483a53fb5b7eeef7b.zip |
Add the rounding operators. (#1030)
* Add the rounding operators.
* Avoid tracking gradients for the rounding operations.
* Add some rounding tests.
Diffstat (limited to 'candle-core/src/op.rs')
-rw-r--r-- | candle-core/src/op.rs | 108 |
1 files changed, 108 insertions, 0 deletions
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 3083d2c8..b7f99f11 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -62,6 +62,9 @@ pub enum UnaryOp { Erf, Relu, Tanh, + Floor, + Ceil, + Round, } #[derive(Clone)] @@ -332,6 +335,9 @@ pub(crate) struct GeluErf; pub(crate) struct Erf; pub(crate) struct Relu; pub(crate) struct Tanh; +pub(crate) struct Floor; +pub(crate) struct Ceil; +pub(crate) struct Round; macro_rules! bin_op { ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => { @@ -660,6 +666,108 @@ impl UnaryOpT for Erf { } } +impl UnaryOpT for Ceil { + const NAME: &'static str = "ceil"; + const KERNEL: &'static str = "uceil"; + const V: Self = Ceil; + #[inline(always)] + fn bf16(v: bf16) -> bf16 { + v.ceil() + } + #[inline(always)] + fn f16(v: f16) -> f16 { + v.ceil() + } + #[inline(always)] + fn f32(v: f32) -> f32 { + v.ceil() + } + #[inline(always)] + fn f64(v: f64) -> f64 { + v.ceil() + } + #[inline(always)] + fn u8(v: u8) -> u8 { + v + } + #[inline(always)] + fn u32(v: u32) -> u32 { + v + } + #[inline(always)] + fn i64(v: i64) -> i64 { + v + } +} + +impl UnaryOpT for Floor { + const NAME: &'static str = "floor"; + const KERNEL: &'static str = "ufloor"; + const V: Self = Floor; + #[inline(always)] + fn bf16(v: bf16) -> bf16 { + v.floor() + } + #[inline(always)] + fn f16(v: f16) -> f16 { + v.floor() + } + #[inline(always)] + fn f32(v: f32) -> f32 { + v.floor() + } + #[inline(always)] + fn f64(v: f64) -> f64 { + v.floor() + } + #[inline(always)] + fn u8(v: u8) -> u8 { + v + } + #[inline(always)] + fn u32(v: u32) -> u32 { + v + } + #[inline(always)] + fn i64(v: i64) -> i64 { + v + } +} + +impl UnaryOpT for Round { + const NAME: &'static str = "round"; + const KERNEL: &'static str = "uround"; + const V: Self = Round; + #[inline(always)] + fn bf16(v: bf16) -> bf16 { + v.round() + } + #[inline(always)] + fn f16(v: f16) -> f16 { + v.round() + } + #[inline(always)] + fn f32(v: f32) -> f32 { + v.round() + } + #[inline(always)] + fn f64(v: f64) -> f64 { + v.round() + } + #[inline(always)] + fn u8(v: u8) -> u8 { + v + } + #[inline(always)] + fn u32(v: u32) -> u32 { + v + } + #[inline(always)] + fn i64(v: i64) -> i64 { + v + } +} + impl UnaryOpT for GeluErf { const NAME: &'static str = "gelu_erf"; const KERNEL: &'static str = "ugelu_erf"; |