| field | value |
|---|---|
| author | Kirpal Grewal <45569241+KGrewal1@users.noreply.github.com>, 2024-03-23 06:05:55 +0000 |
| committer | GitHub <noreply@github.com>, 2024-03-23 07:05:55 +0100 |
| commit | cc856db9ce2541e09731165f88cdd7aae37f558e (patch) |
| tree | 5ffa3b722f84504ba2208c439c2b9253f3ffa1eb /candle-core/src/backprop.rs |
| parent | fc1fe5e45b046771589126c355fdfb4d3bb49fe4 (diff) |
Backwards for ConvTranspose2D (#1910)
* add documentation for backprop
* add a backward pass for ConvTranspose2D (see the usage sketch below)
* add Python test code to verify the implementation
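With this change, calling `backward()` on a graph that contains a transposed convolution should yield gradients instead of a `BackwardNotSupported` error. The sketch below shows one way to exercise that; it is a minimal example assuming the candle-core API, where the `(kernel, padding, output_padding, stride, dilation)` argument order and the `(c_in, c_out, k_h, k_w)` kernel layout are assumptions, not taken from this diff.

```rust
use candle_core::{Device, Result, Tensor, Var};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Input laid out as (batch, c_in, h, w); the kernel is assumed to use the
    // (c_in, c_out, k_h, k_w) layout for a transposed convolution.
    let x = Var::from_tensor(&Tensor::randn(0f32, 1f32, (1, 2, 4, 4), &dev)?)?;
    let w = Var::from_tensor(&Tensor::randn(0f32, 1f32, (2, 3, 3, 3), &dev)?)?;

    // Assumed argument order: (kernel, padding, output_padding, stride, dilation).
    let y = x.conv_transpose2d(&w, 0, 0, 1, 1)?;
    let loss = y.sqr()?.sum_all()?;

    // After this commit, backward() handles the ConvTranspose2D op instead of erroring.
    let grads = loss.backward()?;
    let grad_x = grads.get(&x).expect("missing gradient for the input");
    let grad_w = grads.get(&w).expect("missing gradient for the kernel");
    println!("d loss / d x: {:?}, d loss / d w: {:?}", grad_x.shape(), grad_w.shape());
    Ok(())
}
```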
Diffstat (limited to 'candle-core/src/backprop.rs')
| mode | path | lines changed |
|---|---|---|
| -rw-r--r-- | candle-core/src/backprop.rs | 38 |

1 file changed, 35 insertions(+), 3 deletions(-)
```diff
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 2a1db58a..f39eedbb 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -1,3 +1,4 @@
+/// Methods for backpropagation of gradients.
 use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
 use crate::{Error, Result, Tensor, TensorId};
 use std::collections::HashMap;
@@ -310,9 +311,32 @@ impl Tensor {
                     Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
                         op: "conv-transpose1d",
                     })?,
-                    Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
-                        op: "conv-transpose2d",
-                    })?,
+                    Op::ConvTranspose2D {
+                        arg,
+                        kernel,
+                        padding,
+                        stride,
+                        dilation,
+                        output_padding: _output_padding,
+                    } => {
+                        let grad_arg = grad.conv2d(kernel, *padding, *dilation, *stride, 1)?;
+                        let sum_grad = grads.or_insert(arg)?;
+                        *sum_grad = sum_grad.add(&grad_arg)?;
+
+                        let grad_kernel = grad
+                            .transpose(0, 1)?
+                            .conv2d(&arg.transpose(0, 1)?, *padding, *stride, *dilation, 1)?
+                            .transpose(0, 1)?;
+                        let sum_grad = grads.or_insert(kernel)?;
+                        let (_, _, k0, k1) = kernel.dims4()?;
+                        let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
+                        let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
+                            grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
+                        } else {
+                            grad_kernel
+                        };
+                        *sum_grad = sum_grad.add(&grad_kernel)?;
+                    }
                     Op::AvgPool2D {
                         arg,
                         kernel_size,
@@ -690,30 +714,38 @@ impl Tensor {
     }
 }
 
+/// A store for gradients, associating a tensor id to the corresponding gradient tensor, used for back propagation.
 #[derive(Debug)]
 pub struct GradStore(HashMap<TensorId, Tensor>);
 
 impl GradStore {
+    /// Create a new gradient store
     fn new() -> Self {
         GradStore(HashMap::new())
     }
 
+    /// Get the gradient tensor corresponding to the given tensor id
     pub fn get_id(&self, id: TensorId) -> Option<&Tensor> {
         self.0.get(&id)
     }
 
+    /// Get the gradient tensor associated with the given tensor
     pub fn get(&self, tensor: &Tensor) -> Option<&Tensor> {
         self.0.get(&tensor.id())
     }
 
+    /// Remove the gradient tensor associated with the given tensor, returning it if it exists
     pub fn remove(&mut self, tensor: &Tensor) -> Option<Tensor> {
         self.0.remove(&tensor.id())
     }
 
+    /// Insert a gradient tensor associated with the given tensor, returning the previous gradient tensor if it existed
     pub fn insert(&mut self, tensor: &Tensor, grad: Tensor) -> Option<Tensor> {
         self.0.insert(tensor.id(), grad)
     }
 
+    /// Get the gradient tensor associated with the given tensor, or, if it does not exist,
+    /// insert a tensor of zeroes, with the same shape and type as the given tensors and return it
    fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> {
        use std::collections::hash_map::Entry;
        let grad = match self.0.entry(tensor.id()) {
```
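The test mentioned in the commit message is not part of this diff. As a rough stand-in, a finite-difference check along the lines below could validate the analytic input gradient produced by the new `ConvTranspose2D` arm; this is only a sketch, the `loss` helper is hypothetical, and the `conv_transpose2d` argument order is again an assumption about the candle-core API.

```rust
use candle_core::{Device, Result, Tensor, Var};

// Hypothetical helper: a scalar loss built on the op whose backward pass this patch adds.
// Assumed argument order: (kernel, padding, output_padding, stride, dilation).
fn loss(x: &Tensor, w: &Tensor) -> Result<Tensor> {
    x.conv_transpose2d(w, 0, 0, 1, 1)?.sum_all()
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let x = Var::from_tensor(&Tensor::randn(0f32, 1f32, (1, 2, 4, 4), &dev)?)?;
    let w = Var::from_tensor(&Tensor::randn(0f32, 1f32, (2, 3, 3, 3), &dev)?)?;

    // Analytic gradient through the new ConvTranspose2D backward arm.
    let grads = loss(x.as_tensor(), w.as_tensor())?.backward()?;
    let gx = grads
        .get(&x)
        .expect("missing gradient for x")
        .flatten_all()?
        .to_vec1::<f32>()?;

    // Central finite difference on the first input element.
    let eps = 1e-3f32;
    let shape = x.shape().clone();
    let mut xv = x.as_tensor().flatten_all()?.to_vec1::<f32>()?;
    xv[0] += eps;
    let x_plus = Tensor::from_vec(xv.clone(), shape.clone(), &dev)?;
    xv[0] -= 2.0 * eps;
    let x_minus = Tensor::from_vec(xv, shape, &dev)?;
    let numeric = (loss(&x_plus, w.as_tensor())?.to_scalar::<f32>()?
        - loss(&x_minus, w.as_tensor())?.to_scalar::<f32>()?)
        / (2.0 * eps);

    // The two values should agree up to finite-difference error.
    println!("analytic: {:.4}, numeric: {:.4}", gx[0], numeric);
    Ok(())
}
```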