diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-06 22:14:52 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-06 21:14:52 +0100 |
commit | 166bfd5847144abec227836e497b509625470535 (patch) | |
tree | 7b13e3dae76c0864a3cb107c98b3a88f24423af3 /candle-core/src/backprop.rs | |
parent | 1c062bf06ba504a076b329c965c625be0ec67c1d (diff) | |
download | candle-166bfd5847144abec227836e497b509625470535.tar.gz candle-166bfd5847144abec227836e497b509625470535.tar.bz2 candle-166bfd5847144abec227836e497b509625470535.zip |
Add the recip op + use it in stable-diffusion. (#331)
* Add the recip unary op.
* Fix the cuda kernel.
* Use the recip op in sigmoid.
Diffstat (limited to 'candle-core/src/backprop.rs')
-rw-r--r-- | candle-core/src/backprop.rs | 5 |
1 files changed, 5 insertions, 0 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index f5cc8191..2dff0a5a 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -291,6 +291,11 @@ impl Tensor { let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.sub(&grad)? } + Op::Unary(arg, UnaryOp::Recip) => { + let sum_grad = grads.or_insert(arg)?; + let grad = (grad / arg.sqr()?)?; + *sum_grad = sum_grad.sub(&grad)? + } &Op::Narrow(ref arg, dim, start_idx, len) => { let arg_dims = arg.dims(); let left_pad = if start_idx == 0 { |