summaryrefslogtreecommitdiff
path: root/candle-core/src/backprop.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-06 22:14:52 +0200
committerGitHub <noreply@github.com>2023-08-06 21:14:52 +0100
commit166bfd5847144abec227836e497b509625470535 (patch)
tree7b13e3dae76c0864a3cb107c98b3a88f24423af3 /candle-core/src/backprop.rs
parent1c062bf06ba504a076b329c965c625be0ec67c1d (diff)
downloadcandle-166bfd5847144abec227836e497b509625470535.tar.gz
candle-166bfd5847144abec227836e497b509625470535.tar.bz2
candle-166bfd5847144abec227836e497b509625470535.zip
Add the recip op + use it in stable-diffusion. (#331)
* Add the recip unary op. * Fix the cuda kernel. * Use the recip op in sigmoid.
Diffstat (limited to 'candle-core/src/backprop.rs')
-rw-r--r--candle-core/src/backprop.rs5
1 files changed, 5 insertions, 0 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index f5cc8191..2dff0a5a 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -291,6 +291,11 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.sub(&grad)?
}
+ Op::Unary(arg, UnaryOp::Recip) => {
+ let sum_grad = grads.or_insert(arg)?;
+ let grad = (grad / arg.sqr()?)?;
+ *sum_grad = sum_grad.sub(&grad)?
+ }
&Op::Narrow(ref arg, dim, start_idx, len) => {
let arg_dims = arg.dims();
let left_pad = if start_idx == 0 {