diff options
Diffstat (limited to 'candle-core/src/backprop.rs')
-rw-r--r-- | candle-core/src/backprop.rs | 10 |
1 files changed, 9 insertions, 1 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index dfad5f62..7488d939 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -471,7 +471,15 @@ impl Tensor { Op::Unary(_, UnaryOp::Round) => { Err(Error::BackwardNotSupported { op: "round" })? } - Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?, + Op::Unary(arg, UnaryOp::Gelu) => { + let sum_grad = grads.or_insert(arg)?; + let cube = arg.powf(3.)?; + let tanh = (0.0356774 * &cube + (0.797885 * arg)?)?.tanh()?; + let gelu_grad = (((0.5 * &tanh)? + + (0.0535161 * cube + (0.398942 * arg)?)? * (1. - tanh.powf(2.)?))? + + 0.5)?; + *sum_grad = sum_grad.add(&(&grad * gelu_grad)?)? + } Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?, Op::Unary(_, UnaryOp::GeluErf) => { Err(Error::BackwardNotSupported { op: "gelu-erf" })? |