summary | refs | log | tree | commit | diff
path: root/candle-core/src/backprop.rs
diff options
context:
space:
mode:
author: Alexey Gerasev <alexey.gerasev@gmail.com> 2024-07-16 19:41:16 +0700
committer: GitHub <noreply@github.com> 2024-07-16 14:41:16 +0200
commit: 6a4741bbf9a0a311e312cf331a4f18fce139feaf (patch)
tree: 4418f72ef9d9223e3349b42963990b15430c58a5 /candle-core/src/backprop.rs
parent: 30cdd769f9404035235830e602ae01d50f782fb5 (diff)
download: candle-6a4741bbf9a0a311e312cf331a4f18fce139feaf.tar.gz
candle-6a4741bbf9a0a311e312cf331a4f18fce139feaf.tar.bz2
candle-6a4741bbf9a0a311e312cf331a4f18fce139feaf.zip
Fix Elu gradient NaN on large input (#2328)
* Fix Elu gradient NaN on large input
* Reuse previously computed exp in Elu
Diffstat (limited to 'candle-core/src/backprop.rs')
-rw-r--r-- candle-core/src/backprop.rs | 3
1 file changed, 2 insertions, 1 deletion
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index a9d5a6a6..d6293aa4 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -634,7 +634,8 @@ impl Tensor {
let zeros = arg.zeros_like()?;
let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
- let negative_exp_mask = ((negative_mask * arg.exp())? * *alpha)?;
+ // node == alpha * (e^x - 1) for x <= 0, reuse it
+ let negative_exp_mask = (negative_mask * (*node + *alpha))?;
let combined_mask = (positive_mask + negative_exp_mask)?;
*sum_grad = sum_grad.add(&(grad * combined_mask)?)?
}