summaryrefslogtreecommitdiff
path: root/candle-core/src/conv.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-23 12:58:55 +0100
committerGitHub <noreply@github.com>2023-08-23 12:58:55 +0100
commitaba1e90797e430f28eec13b14b76dd5355876f9c (patch)
tree16bcf7fb151715d3bcdbec2b5263922bd0bdd35a /candle-core/src/conv.rs
parent4ee1cf038ada55ec477dcd6496cf2aec1902775b (diff)
downloadcandle-aba1e90797e430f28eec13b14b76dd5355876f9c.tar.gz
candle-aba1e90797e430f28eec13b14b76dd5355876f9c.tar.bz2
candle-aba1e90797e430f28eec13b14b76dd5355876f9c.zip
Add some group parameter to convolutions. (#566)
* Add some group parameter to convolutions. * Avoid some unnecessary groups checks. * Move the tensor convolution bits. * Properh handling of groups. * Bump the crate version. * And add a changelog.
Diffstat (limited to 'candle-core/src/conv.rs')
-rw-r--r--candle-core/src/conv.rs112
1 files changed, 112 insertions, 0 deletions
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
index e3fea861..d4b7a76d 100644
--- a/candle-core/src/conv.rs
+++ b/candle-core/src/conv.rs
@@ -1,3 +1,5 @@
+use crate::{op::BackpropOp, op::Op, Error, Result, Tensor};
+
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv1D {
pub(crate) b_size: usize,
@@ -51,3 +53,113 @@ impl ParamsConv2D {
vec![self.b_size, self.c_out, self.out_h(), self.out_w()]
}
}
+
+impl Tensor {
+ fn conv1d_single_group(&self, kernel: &Self, params: &ParamsConv1D) -> Result<Self> {
+ let storage =
+ self.storage()
+ .conv1d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
+ let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D {
+ arg,
+ kernel,
+ padding: params.padding,
+ stride: params.stride,
+ });
+ let out_dims = params.out_dims();
+ Ok(crate::tensor::from_storage(storage, out_dims, op, false))
+ }
+
+ /// Applies a 1D convolution over the input tensor.
+ pub fn conv1d(
+ &self,
+ kernel: &Self,
+ padding: usize,
+ stride: usize,
+ groups: usize,
+ ) -> Result<Self> {
+ let (c_out, c_in_k, k_size) = kernel.dims3()?;
+ let (b_size, c_in, l_in) = self.dims3()?;
+ if c_in != c_in_k * groups {
+ Err(Error::Conv1dInvalidArgs {
+ inp_shape: self.shape().clone(),
+ k_shape: kernel.shape().clone(),
+ padding,
+ stride,
+ msg: "the number of in-channels on the input doesn't match the kernel size",
+ }
+ .bt())?
+ }
+
+ let params = ParamsConv1D {
+ b_size,
+ l_in,
+ c_out,
+ c_in,
+ k_size,
+ padding,
+ stride,
+ };
+ if groups == 1 {
+ self.conv1d_single_group(kernel, &params)
+ } else {
+ let blocks = self.chunk(groups, 1)?;
+ let blocks = blocks
+ .iter()
+ .map(|block| block.conv1d_single_group(kernel, &params))
+ .collect::<Result<Vec<_>>>()?;
+ Tensor::cat(&blocks, 1)
+ }
+ }
+
+ fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
+ let storage =
+ self.storage()
+ .conv2d(self.layout(), &kernel.storage(), kernel.layout(), params)?;
+ let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
+ arg,
+ kernel,
+ padding: params.padding,
+ stride: params.stride,
+ });
+ let out_dims = params.out_dims();
+ Ok(crate::tensor::from_storage(storage, out_dims, op, false))
+ }
+
+ /// Applies a 2D convolution over the input tensor.
+ pub fn conv2d(
+ &self,
+ kernel: &Self,
+ padding: usize,
+ stride: usize,
+ groups: usize,
+ ) -> Result<Self> {
+ let (b_size, c_in, i_h, i_w) = self.dims4()?;
+ let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
+ if c_in != c_in_k * groups {
+ crate::bail!(
+ "in_channel mismatch between input ({c_in}, groups {groups}) and kernel ({c_in_k})"
+ )
+ }
+ let params = ParamsConv2D {
+ b_size,
+ i_h,
+ i_w,
+ k_h,
+ k_w,
+ c_out,
+ c_in,
+ padding,
+ stride,
+ };
+ if groups == 1 {
+ self.conv2d_single_group(kernel, &params)
+ } else {
+ let blocks = self.chunk(groups, 1)?;
+ let blocks = blocks
+ .iter()
+ .map(|block| block.conv2d_single_group(kernel, &params))
+ .collect::<Result<Vec<_>>>()?;
+ Tensor::cat(&blocks, 1)
+ }
+ }
+}