diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-23 12:58:55 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-23 12:58:55 +0100 |
commit | aba1e90797e430f28eec13b14b76dd5355876f9c (patch) | |
tree | 16bcf7fb151715d3bcdbec2b5263922bd0bdd35a /candle-core/src/conv.rs | |
parent | 4ee1cf038ada55ec477dcd6496cf2aec1902775b (diff) | |
download | candle-aba1e90797e430f28eec13b14b76dd5355876f9c.tar.gz candle-aba1e90797e430f28eec13b14b76dd5355876f9c.tar.bz2 candle-aba1e90797e430f28eec13b14b76dd5355876f9c.zip |
Add some group parameter to convolutions. (#566)
* Add some group parameter to convolutions.
* Avoid some unnecessary groups checks.
* Move the tensor convolution bits.
* Proper handling of groups.
* Bump the crate version.
* And add a changelog.
Diffstat (limited to 'candle-core/src/conv.rs')
-rw-r--r-- | candle-core/src/conv.rs | 112 |
1 file changed, 112 insertions, 0 deletions
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs index e3fea861..d4b7a76d 100644 --- a/candle-core/src/conv.rs +++ b/candle-core/src/conv.rs @@ -1,3 +1,5 @@ +use crate::{op::BackpropOp, op::Op, Error, Result, Tensor}; + #[derive(Debug, Clone, PartialEq, Eq)] pub struct ParamsConv1D { pub(crate) b_size: usize, @@ -51,3 +53,113 @@ impl ParamsConv2D { vec![self.b_size, self.c_out, self.out_h(), self.out_w()] } } + +impl Tensor { + fn conv1d_single_group(&self, kernel: &Self, params: &ParamsConv1D) -> Result<Self> { + let storage = + self.storage() + .conv1d(self.layout(), &kernel.storage(), kernel.layout(), params)?; + let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D { + arg, + kernel, + padding: params.padding, + stride: params.stride, + }); + let out_dims = params.out_dims(); + Ok(crate::tensor::from_storage(storage, out_dims, op, false)) + } + + /// Applies a 1D convolution over the input tensor. + pub fn conv1d( + &self, + kernel: &Self, + padding: usize, + stride: usize, + groups: usize, + ) -> Result<Self> { + let (c_out, c_in_k, k_size) = kernel.dims3()?; + let (b_size, c_in, l_in) = self.dims3()?; + if c_in != c_in_k * groups { + Err(Error::Conv1dInvalidArgs { + inp_shape: self.shape().clone(), + k_shape: kernel.shape().clone(), + padding, + stride, + msg: "the number of in-channels on the input doesn't match the kernel size", + } + .bt())? 
+ } + + let params = ParamsConv1D { + b_size, + l_in, + c_out, + c_in, + k_size, + padding, + stride, + }; + if groups == 1 { + self.conv1d_single_group(kernel, ¶ms) + } else { + let blocks = self.chunk(groups, 1)?; + let blocks = blocks + .iter() + .map(|block| block.conv1d_single_group(kernel, ¶ms)) + .collect::<Result<Vec<_>>>()?; + Tensor::cat(&blocks, 1) + } + } + + fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> { + let storage = + self.storage() + .conv2d(self.layout(), &kernel.storage(), kernel.layout(), params)?; + let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D { + arg, + kernel, + padding: params.padding, + stride: params.stride, + }); + let out_dims = params.out_dims(); + Ok(crate::tensor::from_storage(storage, out_dims, op, false)) + } + + /// Applies a 2D convolution over the input tensor. + pub fn conv2d( + &self, + kernel: &Self, + padding: usize, + stride: usize, + groups: usize, + ) -> Result<Self> { + let (b_size, c_in, i_h, i_w) = self.dims4()?; + let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?; + if c_in != c_in_k * groups { + crate::bail!( + "in_channel mismatch between input ({c_in}, groups {groups}) and kernel ({c_in_k})" + ) + } + let params = ParamsConv2D { + b_size, + i_h, + i_w, + k_h, + k_w, + c_out, + c_in, + padding, + stride, + }; + if groups == 1 { + self.conv2d_single_group(kernel, ¶ms) + } else { + let blocks = self.chunk(groups, 1)?; + let blocks = blocks + .iter() + .map(|block| block.conv2d_single_group(kernel, ¶ms)) + .collect::<Result<Vec<_>>>()?; + Tensor::cat(&blocks, 1) + } + } +} |