diff options
Diffstat (limited to 'candle-core')
-rw-r--r-- | candle-core/src/backprop.rs | 1 |
-rw-r--r-- | candle-core/src/conv.rs | 59 |
-rw-r--r-- | candle-core/tests/conv_tests.rs | 2 |
3 files changed, 43 insertions, 19 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index 26d73ea1..35619015 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -250,6 +250,7 @@ impl Tensor { out_padding, *stride, *dilation, + /* groups */ 1, )?; let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.add(&grad_arg)?; diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs index fe923087..7b3922dd 100644 --- a/candle-core/src/conv.rs +++ b/candle-core/src/conv.rs @@ -187,6 +187,29 @@ impl Tensor { } } + fn conv_transpose1d_single_group( + &self, + kernel: &Self, + params: &ParamsConvTranspose1D, + ) -> Result<Self> { + let storage = self.storage().conv_transpose1d( + self.layout(), + &kernel.storage(), + kernel.layout(), + params, + )?; + let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D { + arg, + kernel, + padding: params.padding, + output_padding: params.output_padding, + stride: params.stride, + dilation: params.dilation, + }); + let out_dims = params.out_dims(); + Ok(crate::tensor::from_storage(storage, out_dims, op, false)) + } + /// Applies a 1D transposed convolution over the input tensor. 
pub fn conv_transpose1d( &self, @@ -195,39 +218,39 @@ impl Tensor { output_padding: usize, stride: usize, dilation: usize, + groups: usize, ) -> Result<Self> { - let (b_size, c_in, l_in) = self.dims3()?; let (c_in_k, c_out, k_size) = kernel.dims3()?; + let (b_size, c_in, l_in) = self.dims3()?; if c_in != c_in_k { crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})") } + if c_in % groups != 0 { + crate::bail!("in_channel {c_in} is not divisible by the number of groups") + } let params = ParamsConvTranspose1D { b_size, l_in, k_size, c_out, - c_in, + c_in: c_in / groups, padding, output_padding, stride, dilation, }; - let storage = self.storage().conv_transpose1d( - self.layout(), - &kernel.storage(), - kernel.layout(), - ¶ms, - )?; - let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D { - arg, - kernel, - padding: params.padding, - output_padding: params.output_padding, - stride: params.stride, - dilation: params.dilation, - }); - let out_dims = params.out_dims(); - Ok(crate::tensor::from_storage(storage, out_dims, op, false)) + if groups == 1 { + self.conv_transpose1d_single_group(kernel, ¶ms) + } else { + let blocks = self.chunk(groups, 1)?; + let kernel = kernel.chunk(groups, 0)?; + let blocks = blocks + .iter() + .zip(&kernel) + .map(|(block, kernel)| block.conv_transpose1d_single_group(kernel, ¶ms)) + .collect::<Result<Vec<_>>>()?; + Tensor::cat(&blocks, 1) + } } fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> { diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs index 5bbd903d..211a1fe0 100644 --- a/candle-core/tests/conv_tests.rs +++ b/candle-core/tests/conv_tests.rs @@ -50,7 +50,7 @@ fn conv1d(dev: &Device) -> Result<()> { test_utils::to_vec1_round(&res.flatten_all()?, 4)?, [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352] ); - let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?; + let res = 
t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1, 1)?; assert_eq!(res.dims(), [1, 2, 7]); assert_eq!( test_utils::to_vec1_round(&res.flatten_all()?, 4)?,