diff options
author | Alexey Gerasev <alexey.gerasev@gmail.com> | 2024-07-16 19:41:16 +0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-07-16 14:41:16 +0200 |
commit | 6a4741bbf9a0a311e312cf331a4f18fce139feaf (patch) | |
tree | 4418f72ef9d9223e3349b42963990b15430c58a5 /candle-core/src/backprop.rs | |
parent | 30cdd769f9404035235830e602ae01d50f782fb5 (diff) | |
download | candle-6a4741bbf9a0a311e312cf331a4f18fce139feaf.tar.gz candle-6a4741bbf9a0a311e312cf331a4f18fce139feaf.tar.bz2 candle-6a4741bbf9a0a311e312cf331a4f18fce139feaf.zip |
Fix Elu gradient NaN on large input (#2328)
* Fix Elu gradient NaN on large input
* Reuse previously computed exp in Elu
Diffstat (limited to 'candle-core/src/backprop.rs')
-rw-r--r-- | candle-core/src/backprop.rs | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index a9d5a6a6..d6293aa4 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -634,7 +634,8 @@ impl Tensor { let zeros = arg.zeros_like()?; let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?; let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?; - let negative_exp_mask = ((negative_mask * arg.exp())? * *alpha)?; + // node == alpha * (e^x - 1) for x <= 0, reuse it + let negative_exp_mask = (negative_mask * (*node + *alpha))?; let combined_mask = (positive_mask + negative_exp_mask)?; *sum_grad = sum_grad.add(&(grad * combined_mask)?)? } |