author    Laurent Mazare <laurent.mazare@gmail.com>    2023-11-03 11:18:25 +0100
committer GitHub <noreply@github.com>    2023-11-03 11:18:25 +0100
commit    3b0d1e7d03469ed17d1eec931bd76c857b99ff3a (patch)
tree      da615832f4692ae6a25294b6dff0129c872ad5f2 /candle-nn
parent    be4555c5a5ea43c12cf2db68546590e649486c2b (diff)
Transposed conv1d in candle-nn. (#1252)
Diffstat (limited to 'candle-nn')
-rw-r--r--  candle-nn/src/conv.rs  |  94
1 file changed, 94 insertions(+), 0 deletions(-)
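
Before the diff itself, a minimal sketch of how the new constructor could be called once this change is in place. The sketch is not part of the commit: the workspace-style crate alias `candle` for candle-core, the `candle_nn::conv` module path, the VarMap/VarBuilder setup, and the channel counts, kernel size and input shape are all assumptions made for illustration.

use candle::{DType, Device, Result, Tensor};
use candle_nn::conv::{conv_transpose1d, ConvTranspose1dConfig};
use candle_nn::{Module, VarBuilder, VarMap};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Fresh, trainable variables; a VarBuilder over safetensors would work the same way.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    // Hypothetical sizes: 4 input channels, 2 output channels, kernel of 3, stride of 2.
    let cfg = ConvTranspose1dConfig { stride: 2, ..Default::default() };
    let layer = conv_transpose1d(4, 2, 3, cfg, vb)?;
    // conv_transpose1d expects a (batch, channels, length) input.
    let xs = Tensor::zeros((1, 4, 5), DType::F32, &dev)?;
    let ys = layer.forward(&xs)?;
    println!("{:?}", ys.dims()); // expected [1, 2, 11] with these settings
    Ok(())
}

The expected output length follows the standard transposed-convolution relation; see the note after the diff.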
diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs
index 7c0bf841..b1168405 100644
--- a/candle-nn/src/conv.rs
+++ b/candle-nn/src/conv.rs
@@ -71,6 +71,67 @@ impl crate::Module for Conv1d {
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub struct ConvTranspose1dConfig {
+    pub padding: usize,
+    pub output_padding: usize,
+    pub stride: usize,
+    pub dilation: usize,
+    // TODO: support groups.
+}
+
+impl Default for ConvTranspose1dConfig {
+    fn default() -> Self {
+        Self {
+            padding: 0,
+            output_padding: 0,
+            stride: 1,
+            dilation: 1,
+        }
+    }
+}
+
+#[derive(Clone, Debug)]
+pub struct ConvTranspose1d {
+    weight: Tensor,
+    bias: Option<Tensor>,
+    config: ConvTranspose1dConfig,
+}
+
+impl ConvTranspose1d {
+    pub fn new(weight: Tensor, bias: Option<Tensor>, config: ConvTranspose1dConfig) -> Self {
+        Self {
+            weight,
+            bias,
+            config,
+        }
+    }
+
+    pub fn config(&self) -> &ConvTranspose1dConfig {
+        &self.config
+    }
+}
+
+impl crate::Module for ConvTranspose1d {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let x = x.conv_transpose1d(
+            &self.weight,
+            self.config.padding,
+            self.config.output_padding,
+            self.config.stride,
+            self.config.dilation,
+        )?;
+        match &self.bias {
+            None => Ok(x),
+            Some(bias) => {
+                let b = bias.dims1()?;
+                let bias = bias.reshape((1, b, 1))?;
+                Ok(x.broadcast_add(&bias)?)
+            }
+        }
+    }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Conv2dConfig {
    pub padding: usize,
    pub stride: usize,
@@ -241,6 +302,39 @@ pub fn conv1d(
    Ok(Conv1d::new(ws, Some(bs), cfg))
}
+pub fn conv_transpose1d(
+    in_channels: usize,
+    out_channels: usize,
+    kernel_size: usize,
+    cfg: ConvTranspose1dConfig,
+    vb: crate::VarBuilder,
+) -> Result<ConvTranspose1d> {
+    let bound = 1. / (out_channels as f64 * kernel_size as f64).sqrt();
+    let init = crate::Init::Uniform {
+        lo: -bound,
+        up: bound,
+    };
+    let ws = vb.get_with_hints((in_channels, out_channels, kernel_size), "weight", init)?;
+    let bs = vb.get_with_hints(out_channels, "bias", init)?;
+    Ok(ConvTranspose1d::new(ws, Some(bs), cfg))
+}
+
+pub fn conv_transpose1d_no_bias(
+    in_channels: usize,
+    out_channels: usize,
+    kernel_size: usize,
+    cfg: ConvTranspose1dConfig,
+    vb: crate::VarBuilder,
+) -> Result<ConvTranspose1d> {
+    let bound = 1. / (out_channels as f64 * kernel_size as f64).sqrt();
+    let init = crate::Init::Uniform {
+        lo: -bound,
+        up: bound,
+    };
+    let ws = vb.get_with_hints((in_channels, out_channels, kernel_size), "weight", init)?;
+    Ok(ConvTranspose1d::new(ws, None, cfg))
+}
+
pub fn conv2d(
    in_channels: usize,
    out_channels: usize,
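
A note on shapes, not part of the commit: the weight created by the builders above is laid out as (in_channels, out_channels, kernel_size), and the output length of forward should follow the usual transposed-convolution relation (assuming candle matches the common PyTorch-style convention), sketched below.

// Sketch of the standard transposed-convolution output-length relation
// (assumption: candle's conv_transpose1d follows this common convention).
fn conv_transpose1d_out_len(
    l_in: usize,
    kernel_size: usize,
    padding: usize,
    output_padding: usize,
    stride: usize,
    dilation: usize,
) -> usize {
    (l_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1
}

With the ConvTranspose1dConfig defaults (padding 0, output_padding 0, stride 1, dilation 1) this reduces to l_in + kernel_size - 1; with the stride of 2 used in the sketch above, an input of length 5 and a kernel of 3 give 11.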