diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-02-18 21:28:07 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-18 21:28:07 +0100 |
commit | 1fb728772d603e2ca5195eb8b123d5fc77c62fed (patch) | |
tree | 6d1f74e304e045f02fe56c022184d0acb6d0f00d | |
parent | cb86b0c82c6f0ddf2a42c677d44f31e5da41751b (diff) | |
download | candle-1fb728772d603e2ca5195eb8b123d5fc77c62fed.tar.gz candle-1fb728772d603e2ca5195eb8b123d5fc77c62fed.tar.bz2 candle-1fb728772d603e2ca5195eb8b123d5fc77c62fed.zip |
Support for groups in conv-transpose1d. (#1731)
* Groups support in conv-transpose-1d.
* Remove dangling file.
-rw-r--r-- | candle-core/src/backprop.rs | 1 | ||||
-rw-r--r-- | candle-core/src/conv.rs | 59 | ||||
-rw-r--r-- | candle-core/tests/conv_tests.rs | 2 | ||||
-rw-r--r-- | candle-nn/src/conv.rs | 16 |
4 files changed, 56 insertions(+), 22 deletions(-)
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index 26d73ea1..35619015 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -250,6 +250,7 @@ impl Tensor { out_padding, *stride, *dilation, + /* groups */ 1, )?; let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.add(&grad_arg)?; diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs index fe923087..7b3922dd 100644 --- a/candle-core/src/conv.rs +++ b/candle-core/src/conv.rs @@ -187,6 +187,29 @@ impl Tensor { } } + fn conv_transpose1d_single_group( + &self, + kernel: &Self, + params: &ParamsConvTranspose1D, + ) -> Result<Self> { + let storage = self.storage().conv_transpose1d( + self.layout(), + &kernel.storage(), + kernel.layout(), + params, + )?; + let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D { + arg, + kernel, + padding: params.padding, + output_padding: params.output_padding, + stride: params.stride, + dilation: params.dilation, + }); + let out_dims = params.out_dims(); + Ok(crate::tensor::from_storage(storage, out_dims, op, false)) + } + /// Applies a 1D transposed convolution over the input tensor. 
pub fn conv_transpose1d( &self, @@ -195,39 +218,39 @@ impl Tensor { output_padding: usize, stride: usize, dilation: usize, + groups: usize, ) -> Result<Self> { - let (b_size, c_in, l_in) = self.dims3()?; let (c_in_k, c_out, k_size) = kernel.dims3()?; + let (b_size, c_in, l_in) = self.dims3()?; if c_in != c_in_k { crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})") } + if c_in % groups != 0 { + crate::bail!("in_channel {c_in} is not divisible by the number of groups") + } let params = ParamsConvTranspose1D { b_size, l_in, k_size, c_out, - c_in, + c_in: c_in / groups, padding, output_padding, stride, dilation, }; - let storage = self.storage().conv_transpose1d( - self.layout(), - &kernel.storage(), - kernel.layout(), - ¶ms, - )?; - let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D { - arg, - kernel, - padding: params.padding, - output_padding: params.output_padding, - stride: params.stride, - dilation: params.dilation, - }); - let out_dims = params.out_dims(); - Ok(crate::tensor::from_storage(storage, out_dims, op, false)) + if groups == 1 { + self.conv_transpose1d_single_group(kernel, ¶ms) + } else { + let blocks = self.chunk(groups, 1)?; + let kernel = kernel.chunk(groups, 0)?; + let blocks = blocks + .iter() + .zip(&kernel) + .map(|(block, kernel)| block.conv_transpose1d_single_group(kernel, ¶ms)) + .collect::<Result<Vec<_>>>()?; + Tensor::cat(&blocks, 1) + } } fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> { diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs index 5bbd903d..211a1fe0 100644 --- a/candle-core/tests/conv_tests.rs +++ b/candle-core/tests/conv_tests.rs @@ -50,7 +50,7 @@ fn conv1d(dev: &Device) -> Result<()> { test_utils::to_vec1_round(&res.flatten_all()?, 4)?, [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352] ); - let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?; + let res = 
t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1, 1)?; assert_eq!(res.dims(), [1, 2, 7]); assert_eq!( test_utils::to_vec1_round(&res.flatten_all()?, 4)?, diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs index cc9273ca..03b69bbd 100644 --- a/candle-nn/src/conv.rs +++ b/candle-nn/src/conv.rs @@ -76,7 +76,7 @@ pub struct ConvTranspose1dConfig { pub output_padding: usize, pub stride: usize, pub dilation: usize, - // TODO: support groups. + pub groups: usize, } impl Default for ConvTranspose1dConfig { @@ -86,6 +86,7 @@ impl Default for ConvTranspose1dConfig { output_padding: 0, stride: 1, dilation: 1, + groups: 1, } } } @@ -127,6 +128,7 @@ impl crate::Module for ConvTranspose1d { self.config.output_padding, self.config.stride, self.config.dilation, + self.config.groups, )?; match &self.bias { None => Ok(x), @@ -346,7 +348,11 @@ pub fn conv_transpose1d( lo: -bound, up: bound, }; - let ws = vb.get_with_hints((in_channels, out_channels, kernel_size), "weight", init)?; + let ws = vb.get_with_hints( + (in_channels, out_channels / cfg.groups, kernel_size), + "weight", + init, + )?; let bs = vb.get_with_hints(out_channels, "bias", init)?; Ok(ConvTranspose1d::new(ws, Some(bs), cfg)) } @@ -363,7 +369,11 @@ pub fn conv_transpose1d_no_bias( lo: -bound, up: bound, }; - let ws = vb.get_with_hints((in_channels, out_channels, kernel_size), "weight", init)?; + let ws = vb.get_with_hints( + (in_channels, out_channels / cfg.groups, kernel_size), + "weight", + init, + )?; Ok(ConvTranspose1d::new(ws, None, cfg)) } |