summaryrefslogtreecommitdiff
path: root/candle-core
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core')
-rw-r--r--candle-core/src/backprop.rs1
-rw-r--r--candle-core/src/conv.rs59
-rw-r--r--candle-core/tests/conv_tests.rs2
3 files changed, 43 insertions, 19 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 26d73ea1..35619015 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -250,6 +250,7 @@ impl Tensor {
out_padding,
*stride,
*dilation,
+ /* groups */ 1,
)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
index fe923087..7b3922dd 100644
--- a/candle-core/src/conv.rs
+++ b/candle-core/src/conv.rs
@@ -187,6 +187,29 @@ impl Tensor {
}
}
+ fn conv_transpose1d_single_group(
+ &self,
+ kernel: &Self,
+ params: &ParamsConvTranspose1D,
+ ) -> Result<Self> {
+ let storage = self.storage().conv_transpose1d(
+ self.layout(),
+ &kernel.storage(),
+ kernel.layout(),
+ params,
+ )?;
+ let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
+ arg,
+ kernel,
+ padding: params.padding,
+ output_padding: params.output_padding,
+ stride: params.stride,
+ dilation: params.dilation,
+ });
+ let out_dims = params.out_dims();
+ Ok(crate::tensor::from_storage(storage, out_dims, op, false))
+ }
+
/// Applies a 1D transposed convolution over the input tensor.
pub fn conv_transpose1d(
&self,
@@ -195,39 +218,39 @@ impl Tensor {
output_padding: usize,
stride: usize,
dilation: usize,
+ groups: usize,
) -> Result<Self> {
- let (b_size, c_in, l_in) = self.dims3()?;
let (c_in_k, c_out, k_size) = kernel.dims3()?;
+ let (b_size, c_in, l_in) = self.dims3()?;
if c_in != c_in_k {
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
}
+ if c_in % groups != 0 {
+ crate::bail!("in_channel {c_in} is not divisible by the number of groups")
+ }
let params = ParamsConvTranspose1D {
b_size,
l_in,
k_size,
c_out,
- c_in,
+ c_in: c_in / groups,
padding,
output_padding,
stride,
dilation,
};
- let storage = self.storage().conv_transpose1d(
- self.layout(),
- &kernel.storage(),
- kernel.layout(),
- &params,
- )?;
- let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
- arg,
- kernel,
- padding: params.padding,
- output_padding: params.output_padding,
- stride: params.stride,
- dilation: params.dilation,
- });
- let out_dims = params.out_dims();
- Ok(crate::tensor::from_storage(storage, out_dims, op, false))
+ if groups == 1 {
+ self.conv_transpose1d_single_group(kernel, &params)
+ } else {
+ let blocks = self.chunk(groups, 1)?;
+ let kernel = kernel.chunk(groups, 0)?;
+ let blocks = blocks
+ .iter()
+ .zip(&kernel)
+ .map(|(block, kernel)| block.conv_transpose1d_single_group(kernel, &params))
+ .collect::<Result<Vec<_>>>()?;
+ Tensor::cat(&blocks, 1)
+ }
}
fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs
index 5bbd903d..211a1fe0 100644
--- a/candle-core/tests/conv_tests.rs
+++ b/candle-core/tests/conv_tests.rs
@@ -50,7 +50,7 @@ fn conv1d(dev: &Device) -> Result<()> {
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
);
- let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
+ let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 7]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,