author     Kirpal Grewal <45569241+KGrewal1@users.noreply.github.com>  2024-03-23 06:05:55 +0000
committer  GitHub <noreply@github.com>  2024-03-23 07:05:55 +0100
commit     cc856db9ce2541e09731165f88cdd7aae37f558e (patch)
tree       5ffa3b722f84504ba2208c439c2b9253f3ffa1eb /candle-core/src/backprop.rs
parent     fc1fe5e45b046771589126c355fdfb4d3bb49fe4 (diff)
Backwards for ConvTranspose2D (#1910)
* add documentation for backprop
* add backwards for ConvTranspose2D
* add python test code
Diffstat (limited to 'candle-core/src/backprop.rs')
-rw-r--r--  candle-core/src/backprop.rs  38
1 file changed, 35 insertions(+), 3 deletions(-)
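
The commit message mentions python test code; the Rust sketch below is an
editorial illustration (not part of this commit) of how the new backward pass
can be exercised end to end. The tensor shapes are arbitrary, and the
conv_transpose2d argument order (kernel, padding, output_padding, stride,
dilation) is assumed from the candle-core API.

    use candle_core::{Device, Result, Tensor, Var};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // N=1, C_in=2, 4x4 input; transposed-conv kernels are laid out as
        // (C_in, C_out, kH, kW), so this one maps 2 channels to 3.
        let x = Var::from_tensor(&Tensor::randn(0f32, 1f32, (1, 2, 4, 4), &dev)?)?;
        let w = Var::from_tensor(&Tensor::randn(0f32, 1f32, (2, 3, 3, 3), &dev)?)?;
        // padding = 0, output_padding = 0, stride = 1, dilation = 1.
        let y = x.conv_transpose2d(&w, 0, 0, 1, 1)?;
        let loss = y.sqr()?.sum_all()?;
        // Before this commit, backward() failed here with
        // Error::BackwardNotSupported { op: "conv-transpose2d" }.
        let grads = loss.backward()?;
        let grad_x = grads.get(&x).expect("no gradient for the input");
        let grad_w = grads.get(&w).expect("no gradient for the kernel");
        println!("dL/dx: {:?}, dL/dw: {:?}", grad_x.shape(), grad_w.shape());
        Ok(())
    }

The grads value returned by backward() is the GradStore documented in the
second hunk below.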
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 2a1db58a..f39eedbb 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -1,3 +1,4 @@
+//! Methods for backpropagation of gradients.
use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
use crate::{Error, Result, Tensor, TensorId};
use std::collections::HashMap;
@@ -310,9 +311,32 @@ impl Tensor {
Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
op: "conv-transpose1d",
})?,
- Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
- op: "conv-transpose2d",
- })?,
+ Op::ConvTranspose2D {
+ arg,
+ kernel,
+ padding,
+ stride,
+ dilation,
+ output_padding: _output_padding,
+ } => {
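+ // The gradient w.r.t. the input of a transposed convolution is a
+ // standard convolution of the output gradient with the same kernel.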
+ let grad_arg = grad.conv2d(kernel, *padding, *dilation, *stride, 1)?;
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.add(&grad_arg)?;
+
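+ // For the gradient w.r.t. the kernel, swap the batch and channel dims
+ // of both the output gradient and the input, correlate them with
+ // conv2d, then swap the dims back.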
+ let grad_kernel = grad
+ .transpose(0, 1)?
+ .conv2d(&arg.transpose(0, 1)?, *padding, *stride, *dilation, 1)?
+ .transpose(0, 1)?;
+ let sum_grad = grads.or_insert(kernel)?;
+ let (_, _, k0, k1) = kernel.dims4()?;
+ let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
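+ // The correlation above can come out spatially larger than the kernel;
+ // if so, crop the gradient back to the kernel's spatial dimensions.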
+ let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
+ grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
+ } else {
+ grad_kernel
+ };
+ *sum_grad = sum_grad.add(&grad_kernel)?;
+ }
Op::AvgPool2D {
arg,
kernel_size,
@@ -690,30 +714,38 @@ impl Tensor {
}
}
+/// A store for gradients, associating a tensor id with the corresponding gradient tensor; used for backpropagation.
#[derive(Debug)]
pub struct GradStore(HashMap<TensorId, Tensor>);
impl GradStore {
+ /// Create a new gradient store
fn new() -> Self {
GradStore(HashMap::new())
}
+ /// Get the gradient tensor corresponding to the given tensor id
pub fn get_id(&self, id: TensorId) -> Option<&Tensor> {
self.0.get(&id)
}
+ /// Get the gradient tensor associated with the given tensor
pub fn get(&self, tensor: &Tensor) -> Option<&Tensor> {
self.0.get(&tensor.id())
}
+ /// Remove the gradient tensor associated with the given tensor, returning it if it exists
pub fn remove(&mut self, tensor: &Tensor) -> Option<Tensor> {
self.0.remove(&tensor.id())
}
+ /// Insert a gradient tensor associated with the given tensor, returning the previous gradient tensor if it existed
pub fn insert(&mut self, tensor: &Tensor, grad: Tensor) -> Option<Tensor> {
self.0.insert(tensor.id(), grad)
}
+ /// Get the gradient tensor associated with the given tensor, or, if it does not exist,
+ /// insert a tensor of zeros with the same shape and dtype as the given tensor, and return it
fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> {
use std::collections::hash_map::Entry;
let grad = match self.0.entry(tensor.id()) {