diff options
Diffstat (limited to 'candle-nn/src/conv.rs')
-rw-r--r-- | candle-nn/src/conv.rs | 18 |
1 files changed, 16 insertions, 2 deletions
diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs index df9818ab..204402c3 100644 --- a/candle-nn/src/conv.rs +++ b/candle-nn/src/conv.rs @@ -5,6 +5,7 @@ use candle::{Result, Tensor}; pub struct Conv1dConfig { pub padding: usize, pub stride: usize, + pub groups: usize, } impl Default for Conv1dConfig { @@ -12,6 +13,7 @@ impl Default for Conv1dConfig { Self { padding: 0, stride: 1, + groups: 1, } } } @@ -39,7 +41,12 @@ impl Conv1d { impl crate::Module for Conv1d { fn forward(&self, x: &Tensor) -> Result<Tensor> { - let x = x.conv1d(&self.weight, self.config.padding, self.config.stride)?; + let x = x.conv1d( + &self.weight, + self.config.padding, + self.config.stride, + self.config.groups, + )?; match &self.bias { None => Ok(x), Some(bias) => { @@ -55,6 +62,7 @@ impl crate::Module for Conv1d { pub struct Conv2dConfig { pub padding: usize, pub stride: usize, + pub groups: usize, } impl Default for Conv2dConfig { @@ -62,6 +70,7 @@ impl Default for Conv2dConfig { Self { padding: 0, stride: 1, + groups: 1, } } } @@ -90,7 +99,12 @@ impl Conv2d { impl crate::Module for Conv2d { fn forward(&self, x: &Tensor) -> Result<Tensor> { - let x = x.conv2d(&self.weight, self.config.padding, self.config.stride)?; + let x = x.conv2d( + &self.weight, + self.config.padding, + self.config.stride, + self.config.groups, + )?; match &self.bias { None => Ok(x), Some(bias) => { |