summaryrefslogtreecommitdiff
path: root/candle-nn/src/conv.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-nn/src/conv.rs')
-rw-r--r--candle-nn/src/conv.rs18
1 files changed, 16 insertions, 2 deletions
diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs
index df9818ab..204402c3 100644
--- a/candle-nn/src/conv.rs
+++ b/candle-nn/src/conv.rs
@@ -5,6 +5,7 @@ use candle::{Result, Tensor};
pub struct Conv1dConfig {
pub padding: usize,
pub stride: usize,
+ pub groups: usize,
}
impl Default for Conv1dConfig {
@@ -12,6 +13,7 @@ impl Default for Conv1dConfig {
Self {
padding: 0,
stride: 1,
+ groups: 1,
}
}
}
@@ -39,7 +41,12 @@ impl Conv1d {
impl crate::Module for Conv1d {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let x = x.conv1d(&self.weight, self.config.padding, self.config.stride)?;
+ let x = x.conv1d(
+ &self.weight,
+ self.config.padding,
+ self.config.stride,
+ self.config.groups,
+ )?;
match &self.bias {
None => Ok(x),
Some(bias) => {
@@ -55,6 +62,7 @@ impl crate::Module for Conv1d {
pub struct Conv2dConfig {
pub padding: usize,
pub stride: usize,
+ pub groups: usize,
}
impl Default for Conv2dConfig {
@@ -62,6 +70,7 @@ impl Default for Conv2dConfig {
Self {
padding: 0,
stride: 1,
+ groups: 1,
}
}
}
@@ -90,7 +99,12 @@ impl Conv2d {
impl crate::Module for Conv2d {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let x = x.conv2d(&self.weight, self.config.padding, self.config.stride)?;
+ let x = x.conv2d(
+ &self.weight,
+ self.config.padding,
+ self.config.stride,
+ self.config.groups,
+ )?;
match &self.bias {
None => Ok(x),
Some(bias) => {