diff options
Diffstat (limited to 'candle-core')
-rw-r--r-- | candle-core/src/conv.rs | 16 |
1 files changed, 10 insertions, 6 deletions
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs index d4b7a76d..77d4c5cd 100644 --- a/candle-core/src/conv.rs +++ b/candle-core/src/conv.rs @@ -93,8 +93,8 @@ impl Tensor { let params = ParamsConv1D { b_size, l_in, - c_out, - c_in, + c_out: c_out / groups, + c_in: c_in / groups, k_size, padding, stride, @@ -103,9 +103,11 @@ impl Tensor { self.conv1d_single_group(kernel, ¶ms) } else { let blocks = self.chunk(groups, 1)?; + let kernel = kernel.chunk(groups, 0)?; let blocks = blocks .iter() - .map(|block| block.conv1d_single_group(kernel, ¶ms)) + .zip(&kernel) + .map(|(block, kernel)| block.conv1d_single_group(kernel, ¶ms)) .collect::<Result<Vec<_>>>()?; Tensor::cat(&blocks, 1) } @@ -146,8 +148,8 @@ impl Tensor { i_w, k_h, k_w, - c_out, - c_in, + c_out: c_out / groups, + c_in: c_in / groups, padding, stride, }; @@ -155,9 +157,11 @@ impl Tensor { self.conv2d_single_group(kernel, ¶ms) } else { let blocks = self.chunk(groups, 1)?; + let kernel = kernel.chunk(groups, 0)?; let blocks = blocks .iter() - .map(|block| block.conv2d_single_group(kernel, ¶ms)) + .zip(&kernel) + .map(|(block, kernel)| block.conv2d_single_group(kernel, ¶ms)) .collect::<Result<Vec<_>>>()?; Tensor::cat(&blocks, 1) } |