summaryrefslogtreecommitdiff
path: root/candle-core/src
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-21 06:19:10 +0100
committerGitHub <noreply@github.com>2023-09-21 06:19:10 +0100
commit7b26e513f15a0c7cd55ccfe48525bda1079427ce (patch)
tree6dabc58fb21552ddaf061c88406e82a4d3427ac3 /candle-core/src
parentab1d40ea97b387f0dd05f77db37c840a4d624a08 (diff)
downloadcandle-7b26e513f15a0c7cd55ccfe48525bda1079427ce.tar.gz
candle-7b26e513f15a0c7cd55ccfe48525bda1079427ce.tar.bz2
candle-7b26e513f15a0c7cd55ccfe48525bda1079427ce.zip
Add the erf function. (#917)
Diffstat (limited to 'candle-core/src')
-rw-r--r--candle-core/src/backprop.rs1
-rw-r--r--candle-core/src/op.rs36
-rw-r--r--candle-core/src/tensor.rs1
3 files changed, 38 insertions, 0 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 3e2ae1ed..a2548198 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -442,6 +442,7 @@ impl Tensor {
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
+ Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
Op::Unary(_, UnaryOp::GeluErf) => {
Err(Error::BackwardNotSupported { op: "gelu-erf" })?
}
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 26dc6609..4882a205 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -59,6 +59,7 @@ pub enum UnaryOp {
Sqrt,
Gelu,
GeluErf,
+ Erf,
Relu,
Tanh,
}
@@ -327,6 +328,7 @@ pub(crate) struct Sqr;
pub(crate) struct Sqrt;
pub(crate) struct Gelu;
pub(crate) struct GeluErf;
+pub(crate) struct Erf;
pub(crate) struct Relu;
pub(crate) struct Tanh;
@@ -623,6 +625,40 @@ impl UnaryOpT for Gelu {
}
}
+impl UnaryOpT for Erf {
+ const NAME: &'static str = "erf";
+ const KERNEL: &'static str = "uerf";
+ const V: Self = Erf;
+ #[inline(always)]
+ fn bf16(v: bf16) -> bf16 {
+ bf16::from_f64(Self::f64(v.to_f64()))
+ }
+ #[inline(always)]
+ fn f16(v: f16) -> f16 {
+ f16::from_f64(Self::f64(v.to_f64()))
+ }
+ #[inline(always)]
+ fn f32(v: f32) -> f32 {
+ Self::f64(v as f64) as f32
+ }
+ #[inline(always)]
+ fn f64(v: f64) -> f64 {
+ crate::cpu::erf::erf(v)
+ }
+ #[inline(always)]
+ fn u8(_: u8) -> u8 {
+ 0
+ }
+ #[inline(always)]
+ fn u32(_: u32) -> u32 {
+ 0
+ }
+ #[inline(always)]
+ fn i64(_: i64) -> i64 {
+ 0
+ }
+}
+
impl UnaryOpT for GeluErf {
const NAME: &'static str = "gelu_erf";
const KERNEL: &'static str = "ugelu_erf";
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index eafd7002..9dccf2b5 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -490,6 +490,7 @@ impl Tensor {
unary_op!(sqrt, Sqrt);
unary_op!(gelu, Gelu);
unary_op!(gelu_erf, GeluErf);
+ unary_op!(erf, Erf);
unary_op!(relu, Relu);
/// Retrieves the single scalar value hold in the tensor. If the tensor contains multiple