diff options
Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/backprop.rs | 1 | ||||
-rw-r--r-- | candle-core/src/op.rs | 36 | ||||
-rw-r--r-- | candle-core/src/tensor.rs | 1 |
3 files changed, 38 insertions, 0 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index 3e2ae1ed..a2548198 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -442,6 +442,7 @@ impl Tensor { *sum_grad = sum_grad.add(&arg_grad)? } Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?, + Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?, Op::Unary(_, UnaryOp::GeluErf) => { Err(Error::BackwardNotSupported { op: "gelu-erf" })? } diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 26dc6609..4882a205 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -59,6 +59,7 @@ pub enum UnaryOp { Sqrt, Gelu, GeluErf, + Erf, Relu, Tanh, } @@ -327,6 +328,7 @@ pub(crate) struct Sqr; pub(crate) struct Sqrt; pub(crate) struct Gelu; pub(crate) struct GeluErf; +pub(crate) struct Erf; pub(crate) struct Relu; pub(crate) struct Tanh; @@ -623,6 +625,40 @@ impl UnaryOpT for Gelu { } } +impl UnaryOpT for Erf { + const NAME: &'static str = "erf"; + const KERNEL: &'static str = "uerf"; + const V: Self = Erf; + #[inline(always)] + fn bf16(v: bf16) -> bf16 { + bf16::from_f64(Self::f64(v.to_f64())) + } + #[inline(always)] + fn f16(v: f16) -> f16 { + f16::from_f64(Self::f64(v.to_f64())) + } + #[inline(always)] + fn f32(v: f32) -> f32 { + Self::f64(v as f64) as f32 + } + #[inline(always)] + fn f64(v: f64) -> f64 { + crate::cpu::erf::erf(v) + } + #[inline(always)] + fn u8(_: u8) -> u8 { + 0 + } + #[inline(always)] + fn u32(_: u32) -> u32 { + 0 + } + #[inline(always)] + fn i64(_: i64) -> i64 { + 0 + } +} + impl UnaryOpT for GeluErf { const NAME: &'static str = "gelu_erf"; const KERNEL: &'static str = "ugelu_erf"; diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index eafd7002..9dccf2b5 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -490,6 +490,7 @@ impl Tensor { unary_op!(sqrt, Sqrt); unary_op!(gelu, Gelu); unary_op!(gelu_erf, GeluErf); + unary_op!(erf, Erf); unary_op!(relu, Relu); /// Retrieves the single scalar value hold in the tensor. If the tensor contains multiple |