From 1fb728772d603e2ca5195eb8b123d5fc77c62fed Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 18 Feb 2024 21:28:07 +0100 Subject: Support for groups in conv-transpose1d. (#1731) * Groups support in conv-transpose-1d. * Remove dangling file. --- candle-nn/src/conv.rs | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) (limited to 'candle-nn') diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs index cc9273ca..03b69bbd 100644 --- a/candle-nn/src/conv.rs +++ b/candle-nn/src/conv.rs @@ -76,7 +76,7 @@ pub struct ConvTranspose1dConfig { pub output_padding: usize, pub stride: usize, pub dilation: usize, - // TODO: support groups. + pub groups: usize, } impl Default for ConvTranspose1dConfig { @@ -86,6 +86,7 @@ impl Default for ConvTranspose1dConfig { output_padding: 0, stride: 1, dilation: 1, + groups: 1, } } } @@ -127,6 +128,7 @@ impl crate::Module for ConvTranspose1d { self.config.output_padding, self.config.stride, self.config.dilation, + self.config.groups, )?; match &self.bias { None => Ok(x), @@ -346,7 +348,11 @@ pub fn conv_transpose1d( lo: -bound, up: bound, }; - let ws = vb.get_with_hints((in_channels, out_channels, kernel_size), "weight", init)?; + let ws = vb.get_with_hints( + (in_channels, out_channels / cfg.groups, kernel_size), + "weight", + init, + )?; let bs = vb.get_with_hints(out_channels, "bias", init)?; Ok(ConvTranspose1d::new(ws, Some(bs), cfg)) } @@ -363,7 +369,11 @@ pub fn conv_transpose1d_no_bias( lo: -bound, up: bound, }; - let ws = vb.get_with_hints((in_channels, out_channels, kernel_size), "weight", init)?; + let ws = vb.get_with_hints( + (in_channels, out_channels / cfg.groups, kernel_size), + "weight", + init, + )?; Ok(ConvTranspose1d::new(ws, None, cfg)) } -- cgit v1.2.3