summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- candle-core/src/backprop.rs 10
-rw-r--r-- candle-core/tests/grad_tests.rs 13
2 files changed, 22 insertions, 1 deletion
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index dfad5f62..7488d939 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -471,7 +471,15 @@ impl Tensor {
Op::Unary(_, UnaryOp::Round) => {
Err(Error::BackwardNotSupported { op: "round" })?
}
- Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
+ Op::Unary(arg, UnaryOp::Gelu) => {
+ let sum_grad = grads.or_insert(arg)?;
+ let cube = arg.powf(3.)?;
+ let tanh = (0.0356774 * &cube + (0.797885 * arg)?)?.tanh()?;
+ let gelu_grad = (((0.5 * &tanh)?
+ + (0.0535161 * cube + (0.398942 * arg)?)? * (1. - tanh.powf(2.)?))?
+ + 0.5)?;
+ *sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
+ }
Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
Op::Unary(_, UnaryOp::GeluErf) => {
Err(Error::BackwardNotSupported { op: "gelu-erf" })?
diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs
index 2a70cfc4..bcfe639f 100644
--- a/candle-core/tests/grad_tests.rs
+++ b/candle-core/tests/grad_tests.rs
@@ -192,6 +192,19 @@ fn unary_grad(device: &Device) -> Result<()> {
test_utils::to_vec1_round(grad_x, 2)?,
[0.01, 0.42, 0.0, 0.98],
);
+
+ // Expected values were compared against PyTorch nn.GELU(approximate='tanh').
+ let y = x.gelu()?;
+ let grads = y.backward()?;
+ let grad_x = grads.get(&x).context("no grad for x")?;
+ assert_eq!(
+ test_utils::to_vec1_round(&y, 4)?,
+ [2.9964, 0.8412, 3.9999, 0.0839]
+ );
+ assert_eq!(
+ test_utils::to_vec1_round(grad_x, 4)?,
+ [1.0116, 1.0830, 1.0003, 0.6188],
+ );
Ok(())
}