diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-21 06:19:10 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-21 06:19:10 +0100 |
commit | 7b26e513f15a0c7cd55ccfe48525bda1079427ce (patch) | |
tree | 6dabc58fb21552ddaf061c88406e82a4d3427ac3 /candle-core/src | |
parent | ab1d40ea97b387f0dd05f77db37c840a4d624a08 (diff) | |
download | candle-7b26e513f15a0c7cd55ccfe48525bda1079427ce.tar.gz candle-7b26e513f15a0c7cd55ccfe48525bda1079427ce.tar.bz2 candle-7b26e513f15a0c7cd55ccfe48525bda1079427ce.zip |
Add the erf function. (#917)
Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/backprop.rs | 1 | ||||
-rw-r--r-- | candle-core/src/op.rs | 36 | ||||
-rw-r--r-- | candle-core/src/tensor.rs | 1 |
3 files changed, 38 insertions, 0 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index 3e2ae1ed..a2548198 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -442,6 +442,7 @@ impl Tensor { *sum_grad = sum_grad.add(&arg_grad)? } Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?, + Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?, Op::Unary(_, UnaryOp::GeluErf) => { Err(Error::BackwardNotSupported { op: "gelu-erf" })? } diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 26dc6609..4882a205 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -59,6 +59,7 @@ pub enum UnaryOp { Sqrt, Gelu, GeluErf, + Erf, Relu, Tanh, } @@ -327,6 +328,7 @@ pub(crate) struct Sqr; pub(crate) struct Sqrt; pub(crate) struct Gelu; pub(crate) struct GeluErf; +pub(crate) struct Erf; pub(crate) struct Relu; pub(crate) struct Tanh; @@ -623,6 +625,40 @@ impl UnaryOpT for Gelu { } } +impl UnaryOpT for Erf { + const NAME: &'static str = "erf"; + const KERNEL: &'static str = "uerf"; + const V: Self = Erf; + #[inline(always)] + fn bf16(v: bf16) -> bf16 { + bf16::from_f64(Self::f64(v.to_f64())) + } + #[inline(always)] + fn f16(v: f16) -> f16 { + f16::from_f64(Self::f64(v.to_f64())) + } + #[inline(always)] + fn f32(v: f32) -> f32 { + Self::f64(v as f64) as f32 + } + #[inline(always)] + fn f64(v: f64) -> f64 { + crate::cpu::erf::erf(v) + } + #[inline(always)] + fn u8(_: u8) -> u8 { + 0 + } + #[inline(always)] + fn u32(_: u32) -> u32 { + 0 + } + #[inline(always)] + fn i64(_: i64) -> i64 { + 0 + } +} + impl UnaryOpT for GeluErf { const NAME: &'static str = "gelu_erf"; const KERNEL: &'static str = "ugelu_erf"; diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index eafd7002..9dccf2b5 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -490,6 +490,7 @@ impl Tensor { unary_op!(sqrt, Sqrt); unary_op!(gelu, Gelu); unary_op!(gelu_erf, GeluErf); + unary_op!(erf, Erf); unary_op!(relu, Relu); /// Retrieves the single scalar value hold in the tensor. If the tensor contains multiple |