diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-02-18 21:28:07 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-18 21:28:07 +0100 |
commit | 1fb728772d603e2ca5195eb8b123d5fc77c62fed (patch) | |
tree | 6d1f74e304e045f02fe56c022184d0acb6d0f00d /candle-nn | |
parent | cb86b0c82c6f0ddf2a42c677d44f31e5da41751b (diff) | |
download | candle-1fb728772d603e2ca5195eb8b123d5fc77c62fed.tar.gz candle-1fb728772d603e2ca5195eb8b123d5fc77c62fed.tar.bz2 candle-1fb728772d603e2ca5195eb8b123d5fc77c62fed.zip |
Support for groups in conv-transpose1d. (#1731)
* Groups support in conv-transpose-1d.
* Remove dangling file.
Diffstat (limited to 'candle-nn')
-rw-r--r-- | candle-nn/src/conv.rs | 16 |
1 files changed, 13 insertions, 3 deletions
diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs index cc9273ca..03b69bbd 100644 --- a/candle-nn/src/conv.rs +++ b/candle-nn/src/conv.rs @@ -76,7 +76,7 @@ pub struct ConvTranspose1dConfig { pub output_padding: usize, pub stride: usize, pub dilation: usize, - // TODO: support groups. + pub groups: usize, } impl Default for ConvTranspose1dConfig { @@ -86,6 +86,7 @@ impl Default for ConvTranspose1dConfig { output_padding: 0, stride: 1, dilation: 1, + groups: 1, } } } @@ -127,6 +128,7 @@ impl crate::Module for ConvTranspose1d { self.config.output_padding, self.config.stride, self.config.dilation, + self.config.groups, )?; match &self.bias { None => Ok(x), @@ -346,7 +348,11 @@ pub fn conv_transpose1d( lo: -bound, up: bound, }; - let ws = vb.get_with_hints((in_channels, out_channels, kernel_size), "weight", init)?; + let ws = vb.get_with_hints( + (in_channels, out_channels / cfg.groups, kernel_size), + "weight", + init, + )?; let bs = vb.get_with_hints(out_channels, "bias", init)?; Ok(ConvTranspose1d::new(ws, Some(bs), cfg)) } @@ -363,7 +369,11 @@ pub fn conv_transpose1d_no_bias( lo: -bound, up: bound, }; - let ws = vb.get_with_hints((in_channels, out_channels, kernel_size), "weight", init)?; + let ws = vb.get_with_hints( + (in_channels, out_channels / cfg.groups, kernel_size), + "weight", + init, + )?; Ok(ConvTranspose1d::new(ws, None, cfg)) } |