diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-23 12:58:55 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-23 12:58:55 +0100 |
commit | aba1e90797e430f28eec13b14b76dd5355876f9c (patch) | |
tree | 16bcf7fb151715d3bcdbec2b5263922bd0bdd35a /candle-nn | |
parent | 4ee1cf038ada55ec477dcd6496cf2aec1902775b (diff) | |
download | candle-aba1e90797e430f28eec13b14b76dd5355876f9c.tar.gz candle-aba1e90797e430f28eec13b14b76dd5355876f9c.tar.bz2 candle-aba1e90797e430f28eec13b14b76dd5355876f9c.zip |
Add some group parameter to convolutions. (#566)
* Add some group parameter to convolutions.
* Avoid some unnecessary groups checks.
* Move the tensor convolution bits.
* Properh handling of groups.
* Bump the crate version.
* And add a changelog.
Diffstat (limited to 'candle-nn')
-rw-r--r-- | candle-nn/Cargo.toml | 2 | ||||
-rw-r--r-- | candle-nn/src/conv.rs | 18 |
2 files changed, 17 insertions, 3 deletions
diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml index b3e9c0bf..7cd1d7a2 100644 --- a/candle-nn/Cargo.toml +++ b/candle-nn/Cargo.toml @@ -11,7 +11,7 @@ readme = "README.md" [dependencies] accelerate-src = { workspace = true, optional = true } -candle = { path = "../candle-core", version = "0.1.2", package = "candle-core" } +candle = { path = "../candle-core", version = "0.1.3", package = "candle-core" } thiserror = { workspace = true } intel-mkl-src = { workspace = true, optional = true } safetensors = { workspace = true } diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs index df9818ab..204402c3 100644 --- a/candle-nn/src/conv.rs +++ b/candle-nn/src/conv.rs @@ -5,6 +5,7 @@ use candle::{Result, Tensor}; pub struct Conv1dConfig { pub padding: usize, pub stride: usize, + pub groups: usize, } impl Default for Conv1dConfig { @@ -12,6 +13,7 @@ impl Default for Conv1dConfig { Self { padding: 0, stride: 1, + groups: 1, } } } @@ -39,7 +41,12 @@ impl Conv1d { impl crate::Module for Conv1d { fn forward(&self, x: &Tensor) -> Result<Tensor> { - let x = x.conv1d(&self.weight, self.config.padding, self.config.stride)?; + let x = x.conv1d( + &self.weight, + self.config.padding, + self.config.stride, + self.config.groups, + )?; match &self.bias { None => Ok(x), Some(bias) => { @@ -55,6 +62,7 @@ impl crate::Module for Conv1d { pub struct Conv2dConfig { pub padding: usize, pub stride: usize, + pub groups: usize, } impl Default for Conv2dConfig { @@ -62,6 +70,7 @@ impl Default for Conv2dConfig { Self { padding: 0, stride: 1, + groups: 1, } } } @@ -90,7 +99,12 @@ impl Conv2d { impl crate::Module for Conv2d { fn forward(&self, x: &Tensor) -> Result<Tensor> { - let x = x.conv2d(&self.weight, self.config.padding, self.config.stride)?; + let x = x.conv2d( + &self.weight, + self.config.padding, + self.config.stride, + self.config.groups, + )?; match &self.bias { None => Ok(x), Some(bias) => { |