summaryrefslogtreecommitdiff
path: root/candle-core/src/conv.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-23 18:02:58 +0100
committerGitHub <noreply@github.com>2023-08-23 18:02:58 +0100
commit431051cc327c22733ecf20c80bcbf1760444a751 (patch)
treecc92098ed3cb065b8c173c8335eb459ee7589236 /candle-core/src/conv.rs
parenteedd85ffa76bbeafc2871d5a8782d821d11f24fa (diff)
downloadcandle-431051cc327c22733ecf20c80bcbf1760444a751.tar.gz
candle-431051cc327c22733ecf20c80bcbf1760444a751.tar.bz2
candle-431051cc327c22733ecf20c80bcbf1760444a751.zip
Add Efficientnet (#572)
* EfficientNet. * Complete the efficientnet implementation. * Improve group handling. * Get the efficientnet to work.
Diffstat (limited to 'candle-core/src/conv.rs')
-rw-r--r--candle-core/src/conv.rs16
1 files changed, 10 insertions, 6 deletions
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
index d4b7a76d..77d4c5cd 100644
--- a/candle-core/src/conv.rs
+++ b/candle-core/src/conv.rs
@@ -93,8 +93,8 @@ impl Tensor {
let params = ParamsConv1D {
b_size,
l_in,
- c_out,
- c_in,
+ c_out: c_out / groups,
+ c_in: c_in / groups,
k_size,
padding,
stride,
@@ -103,9 +103,11 @@ impl Tensor {
self.conv1d_single_group(kernel, &params)
} else {
let blocks = self.chunk(groups, 1)?;
+ let kernel = kernel.chunk(groups, 0)?;
let blocks = blocks
.iter()
- .map(|block| block.conv1d_single_group(kernel, &params))
+ .zip(&kernel)
+ .map(|(block, kernel)| block.conv1d_single_group(kernel, &params))
.collect::<Result<Vec<_>>>()?;
Tensor::cat(&blocks, 1)
}
@@ -146,8 +148,8 @@ impl Tensor {
i_w,
k_h,
k_w,
- c_out,
- c_in,
+ c_out: c_out / groups,
+ c_in: c_in / groups,
padding,
stride,
};
@@ -155,9 +157,11 @@ impl Tensor {
self.conv2d_single_group(kernel, &params)
} else {
let blocks = self.chunk(groups, 1)?;
+ let kernel = kernel.chunk(groups, 0)?;
let blocks = blocks
.iter()
- .map(|block| block.conv2d_single_group(kernel, &params))
+ .zip(&kernel)
+ .map(|(block, kernel)| block.conv2d_single_group(kernel, &params))
.collect::<Result<Vec<_>>>()?;
Tensor::cat(&blocks, 1)
}