summaryrefslogtreecommitdiff
path: root/candle-nn
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-23 12:58:55 +0100
committerGitHub <noreply@github.com>2023-08-23 12:58:55 +0100
commitaba1e90797e430f28eec13b14b76dd5355876f9c (patch)
tree16bcf7fb151715d3bcdbec2b5263922bd0bdd35a /candle-nn
parent4ee1cf038ada55ec477dcd6496cf2aec1902775b (diff)
downloadcandle-aba1e90797e430f28eec13b14b76dd5355876f9c.tar.gz
candle-aba1e90797e430f28eec13b14b76dd5355876f9c.tar.bz2
candle-aba1e90797e430f28eec13b14b76dd5355876f9c.zip
Add some group parameter to convolutions. (#566)
* Add some group parameter to convolutions. * Avoid some unnecessary groups checks. * Move the tensor convolution bits. * Properh handling of groups. * Bump the crate version. * And add a changelog.
Diffstat (limited to 'candle-nn')
-rw-r--r--candle-nn/Cargo.toml2
-rw-r--r--candle-nn/src/conv.rs18
2 files changed, 17 insertions, 3 deletions
diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml
index b3e9c0bf..7cd1d7a2 100644
--- a/candle-nn/Cargo.toml
+++ b/candle-nn/Cargo.toml
@@ -11,7 +11,7 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.1.2", package = "candle-core" }
+candle = { path = "../candle-core", version = "0.1.3", package = "candle-core" }
thiserror = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
safetensors = { workspace = true }
diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs
index df9818ab..204402c3 100644
--- a/candle-nn/src/conv.rs
+++ b/candle-nn/src/conv.rs
@@ -5,6 +5,7 @@ use candle::{Result, Tensor};
pub struct Conv1dConfig {
pub padding: usize,
pub stride: usize,
+ pub groups: usize,
}
impl Default for Conv1dConfig {
@@ -12,6 +13,7 @@ impl Default for Conv1dConfig {
Self {
padding: 0,
stride: 1,
+ groups: 1,
}
}
}
@@ -39,7 +41,12 @@ impl Conv1d {
impl crate::Module for Conv1d {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let x = x.conv1d(&self.weight, self.config.padding, self.config.stride)?;
+ let x = x.conv1d(
+ &self.weight,
+ self.config.padding,
+ self.config.stride,
+ self.config.groups,
+ )?;
match &self.bias {
None => Ok(x),
Some(bias) => {
@@ -55,6 +62,7 @@ impl crate::Module for Conv1d {
pub struct Conv2dConfig {
pub padding: usize,
pub stride: usize,
+ pub groups: usize,
}
impl Default for Conv2dConfig {
@@ -62,6 +70,7 @@ impl Default for Conv2dConfig {
Self {
padding: 0,
stride: 1,
+ groups: 1,
}
}
}
@@ -90,7 +99,12 @@ impl Conv2d {
impl crate::Module for Conv2d {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let x = x.conv2d(&self.weight, self.config.padding, self.config.stride)?;
+ let x = x.conv2d(
+ &self.weight,
+ self.config.padding,
+ self.config.stride,
+ self.config.groups,
+ )?;
match &self.bias {
None => Ok(x),
Some(bias) => {