summaryrefslogtreecommitdiff
path: root/candle-nn/src/conv.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-nn/src/conv.rs')
-rw-r--r--candle-nn/src/conv.rs51
1 files changed, 50 insertions, 1 deletions
diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs
index dbf23aa5..f985cfd6 100644
--- a/candle-nn/src/conv.rs
+++ b/candle-nn/src/conv.rs
@@ -80,7 +80,6 @@ impl Default for Conv2dConfig {
}
}
-#[allow(dead_code)]
#[derive(Debug)]
pub struct Conv2d {
weight: Tensor,
@@ -122,6 +121,56 @@ impl crate::Module for Conv2d {
}
}
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub struct ConvTranspose2dConfig {
+ pub padding: usize,
+ pub output_padding: usize,
+ pub stride: usize,
+ pub dilation: usize,
+ // TODO: support groups.
+}
+
+#[derive(Debug)]
+pub struct ConvTranspose2d {
+ weight: Tensor,
+ bias: Option<Tensor>,
+ config: ConvTranspose2dConfig,
+}
+
+impl ConvTranspose2d {
+ pub fn new(weight: Tensor, bias: Option<Tensor>, config: ConvTranspose2dConfig) -> Self {
+ Self {
+ weight,
+ bias,
+ config,
+ }
+ }
+
+ pub fn config(&self) -> &ConvTranspose2dConfig {
+ &self.config
+ }
+}
+
+impl crate::Module for ConvTranspose2d {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let x = x.conv_transpose2d(
+ &self.weight,
+ self.config.padding,
+ self.config.output_padding,
+ self.config.stride,
+ self.config.dilation,
+ )?;
+ match &self.bias {
+ None => Ok(x),
+ Some(bias) => {
+ let b = bias.dims1()?;
+ let bias = bias.reshape((1, b, 1, 1))?;
+ Ok(x.broadcast_add(&bias)?)
+ }
+ }
+ }
+}
+
pub fn conv1d(
in_channels: usize,
out_channels: usize,