diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-11-03 11:18:25 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-03 11:18:25 +0100 |
commit | 3b0d1e7d03469ed17d1eec931bd76c857b99ff3a (patch) | |
tree | da615832f4692ae6a25294b6dff0129c872ad5f2 /candle-nn | |
parent | be4555c5a5ea43c12cf2db68546590e649486c2b (diff) | |
download | candle-3b0d1e7d03469ed17d1eec931bd76c857b99ff3a.tar.gz candle-3b0d1e7d03469ed17d1eec931bd76c857b99ff3a.tar.bz2 candle-3b0d1e7d03469ed17d1eec931bd76c857b99ff3a.zip |
Transposed conv1d in candle-nn. (#1252)
Diffstat (limited to 'candle-nn')
-rw-r--r-- | candle-nn/src/conv.rs | 94 |
1 files changed, 94 insertions, 0 deletions
diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs index 7c0bf841..b1168405 100644 --- a/candle-nn/src/conv.rs +++ b/candle-nn/src/conv.rs @@ -71,6 +71,67 @@ impl crate::Module for Conv1d { } #[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ConvTranspose1dConfig { + pub padding: usize, + pub output_padding: usize, + pub stride: usize, + pub dilation: usize, + // TODO: support groups. +} + +impl Default for ConvTranspose1dConfig { + fn default() -> Self { + Self { + padding: 0, + output_padding: 0, + stride: 1, + dilation: 1, + } + } +} + +#[derive(Clone, Debug)] +pub struct ConvTranspose1d { + weight: Tensor, + bias: Option<Tensor>, + config: ConvTranspose1dConfig, +} + +impl ConvTranspose1d { + pub fn new(weight: Tensor, bias: Option<Tensor>, config: ConvTranspose1dConfig) -> Self { + Self { + weight, + bias, + config, + } + } + + pub fn config(&self) -> &ConvTranspose1dConfig { + &self.config + } +} + +impl crate::Module for ConvTranspose1d { + fn forward(&self, x: &Tensor) -> Result<Tensor> { + let x = x.conv_transpose1d( + &self.weight, + self.config.padding, + self.config.output_padding, + self.config.stride, + self.config.dilation, + )?; + match &self.bias { + None => Ok(x), + Some(bias) => { + let b = bias.dims1()?; + let bias = bias.reshape((1, b, 1, 1))?; + Ok(x.broadcast_add(&bias)?) + } + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct Conv2dConfig { pub padding: usize, pub stride: usize, @@ -241,6 +302,39 @@ pub fn conv1d( Ok(Conv1d::new(ws, Some(bs), cfg)) } +pub fn conv_transpose1d( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + cfg: ConvTranspose1dConfig, + vb: crate::VarBuilder, +) -> Result<ConvTranspose1d> { + let bound = 1. / (out_channels as f64 * kernel_size as f64).sqrt(); + let init = crate::Init::Uniform { + lo: -bound, + up: bound, + }; + let ws = vb.get_with_hints((in_channels, out_channels, kernel_size), "weight", init)?; + let bs = vb.get_with_hints(out_channels, "bias", init)?; + Ok(ConvTranspose1d::new(ws, Some(bs), cfg)) +} + +pub fn conv_transpose1d_no_bias( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + cfg: ConvTranspose1dConfig, + vb: crate::VarBuilder, +) -> Result<ConvTranspose1d> { + let bound = 1. / (out_channels as f64 * kernel_size as f64).sqrt(); + let init = crate::Init::Uniform { + lo: -bound, + up: bound, + }; + let ws = vb.get_with_hints((in_channels, out_channels, kernel_size), "weight", init)?; + Ok(ConvTranspose1d::new(ws, None, cfg)) +} + pub fn conv2d( in_channels: usize, out_channels: usize, |