diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-23 18:02:58 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-23 18:02:58 +0100 |
commit | 431051cc327c22733ecf20c80bcbf1760444a751 (patch) | |
tree | cc92098ed3cb065b8c173c8335eb459ee7589236 /candle-core/src/conv.rs | |
parent | eedd85ffa76bbeafc2871d5a8782d821d11f24fa (diff) | |
download | candle-431051cc327c22733ecf20c80bcbf1760444a751.tar.gz candle-431051cc327c22733ecf20c80bcbf1760444a751.tar.bz2 candle-431051cc327c22733ecf20c80bcbf1760444a751.zip |
Add Efficientnet (#572)
* EfficientNet.
* Complete the efficientnet implementation.
* Improve group handling.
* Get the efficientnet to work.
Diffstat (limited to 'candle-core/src/conv.rs')
-rw-r--r-- | candle-core/src/conv.rs | 16 |
1 files changed, 10 insertions, 6 deletions
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs index d4b7a76d..77d4c5cd 100644 --- a/candle-core/src/conv.rs +++ b/candle-core/src/conv.rs @@ -93,8 +93,8 @@ impl Tensor { let params = ParamsConv1D { b_size, l_in, - c_out, - c_in, + c_out: c_out / groups, + c_in: c_in / groups, k_size, padding, stride, @@ -103,9 +103,11 @@ impl Tensor { self.conv1d_single_group(kernel, ¶ms) } else { let blocks = self.chunk(groups, 1)?; + let kernel = kernel.chunk(groups, 0)?; let blocks = blocks .iter() - .map(|block| block.conv1d_single_group(kernel, ¶ms)) + .zip(&kernel) + .map(|(block, kernel)| block.conv1d_single_group(kernel, ¶ms)) .collect::<Result<Vec<_>>>()?; Tensor::cat(&blocks, 1) } @@ -146,8 +148,8 @@ impl Tensor { i_w, k_h, k_w, - c_out, - c_in, + c_out: c_out / groups, + c_in: c_in / groups, padding, stride, }; @@ -155,9 +157,11 @@ impl Tensor { self.conv2d_single_group(kernel, ¶ms) } else { let blocks = self.chunk(groups, 1)?; + let kernel = kernel.chunk(groups, 0)?; let blocks = blocks .iter() - .map(|block| block.conv2d_single_group(kernel, ¶ms)) + .zip(&kernel) + .map(|(block, kernel)| block.conv2d_single_group(kernel, ¶ms)) .collect::<Result<Vec<_>>>()?; Tensor::cat(&blocks, 1) } |