//! 1D and 2D Convolutions //! use crate::{op::BackpropOp, op::Op, Error, Result, Tensor}; #[derive(Debug, Clone, PartialEq, Eq)] pub struct ParamsConv1D { pub(crate) b_size: usize, // Maybe we should have a version without l_in as this bit depends on the input and not only on // the weights. pub(crate) l_in: usize, pub(crate) c_out: usize, pub(crate) c_in: usize, pub(crate) k_size: usize, pub(crate) padding: usize, pub(crate) stride: usize, pub(crate) dilation: usize, } impl ParamsConv1D { pub(crate) fn l_out(&self) -> usize { (self.l_in + 2 * self.padding - self.dilation * (self.k_size - 1) - 1) / self.stride + 1 } pub(crate) fn out_dims(&self) -> Vec { let l_out = self.l_out(); vec![self.b_size, self.c_out, l_out] } } #[derive(Debug, Clone, PartialEq, Eq)] pub struct ParamsConvTranspose1D { pub(crate) b_size: usize, pub(crate) l_in: usize, pub(crate) c_out: usize, pub(crate) c_in: usize, pub(crate) k_size: usize, pub(crate) padding: usize, pub(crate) output_padding: usize, pub(crate) stride: usize, pub(crate) dilation: usize, } impl ParamsConvTranspose1D { pub(crate) fn l_out(&self) -> usize { (self.l_in - 1) * self.stride - 2 * self.padding + self.dilation * (self.k_size - 1) + self.output_padding + 1 } pub(crate) fn out_dims(&self) -> Vec { let l_out = self.l_out(); vec![self.b_size, self.c_out, l_out] } } #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum CudnnFwdAlgo { ImplicitGemm, ImplicitPrecompGemm, Gemm, Direct, Fft, FftTiling, Winograd, WinogradNonFused, Count, } #[derive(Debug, Clone, PartialEq, Eq)] pub struct ParamsConv2D { pub(crate) b_size: usize, pub(crate) i_h: usize, pub(crate) i_w: usize, pub(crate) k_h: usize, pub(crate) k_w: usize, pub(crate) c_out: usize, pub(crate) c_in: usize, pub(crate) padding: usize, pub(crate) stride: usize, pub(crate) dilation: usize, pub cudnn_fwd_algo: Option, } impl ParamsConv2D { pub(crate) fn out_h(&self) -> usize { (self.i_h + 2 * self.padding - self.dilation * (self.k_h - 1) - 1) / self.stride + 1 } pub(crate) fn out_w(&self) -> usize { (self.i_w + 2 * self.padding - self.dilation * (self.k_w - 1) - 1) / self.stride + 1 } pub(crate) fn out_dims(&self) -> Vec { vec![self.b_size, self.c_out, self.out_h(), self.out_w()] } } #[derive(Debug, Clone, PartialEq, Eq)] pub struct ParamsConvTranspose2D { pub(crate) b_size: usize, pub(crate) i_h: usize, pub(crate) i_w: usize, pub(crate) k_h: usize, pub(crate) k_w: usize, pub(crate) c_out: usize, pub(crate) c_in: usize, pub(crate) padding: usize, pub(crate) output_padding: usize, pub(crate) stride: usize, pub(crate) dilation: usize, } impl ParamsConvTranspose2D { pub(crate) fn out_h(&self) -> usize { (self.i_h - 1) * self.stride + self.dilation * (self.k_h - 1) + self.output_padding + 1 - 2 * self.padding } pub(crate) fn out_w(&self) -> usize { (self.i_w - 1) * self.stride + self.dilation * (self.k_w - 1) + self.output_padding + 1 - 2 * self.padding } pub(crate) fn out_dims(&self) -> Vec { vec![self.b_size, self.c_out, self.out_h(), self.out_w()] } } impl Tensor { fn conv1d_single_group(&self, kernel: &Self, params: &ParamsConv1D) -> Result { let storage = self.storage() .conv1d(self.layout(), &kernel.storage(), kernel.layout(), params)?; let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D { arg, kernel, padding: params.padding, stride: params.stride, dilation: params.dilation, }); let out_dims = params.out_dims(); Ok(crate::tensor::from_storage(storage, out_dims, op, false)) } /// Applies a 1D convolution over the input tensor. pub fn conv1d( &self, kernel: &Self, padding: usize, stride: usize, dilation: usize, groups: usize, ) -> Result { let (c_out, c_in_k, k_size) = kernel.dims3()?; let (b_size, c_in, l_in) = self.dims3()?; if c_in != c_in_k * groups { Err(Error::Conv1dInvalidArgs { inp_shape: self.shape().clone(), k_shape: kernel.shape().clone(), padding, stride, msg: "the number of in-channels on the input doesn't match the kernel size", } .bt())? } let params = ParamsConv1D { b_size, l_in, c_out: c_out / groups, c_in: c_in / groups, k_size, padding, stride, dilation, }; if groups == 1 { self.conv1d_single_group(kernel, ¶ms) } else { let blocks = self.chunk(groups, 1)?; let kernel = kernel.chunk(groups, 0)?; let blocks = blocks .iter() .zip(&kernel) .map(|(block, kernel)| block.conv1d_single_group(kernel, ¶ms)) .collect::>>()?; Tensor::cat(&blocks, 1) } } fn conv_transpose1d_single_group( &self, kernel: &Self, params: &ParamsConvTranspose1D, ) -> Result { let storage = self.storage().conv_transpose1d( self.layout(), &kernel.storage(), kernel.layout(), params, )?; let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D { arg, kernel, padding: params.padding, output_padding: params.output_padding, stride: params.stride, dilation: params.dilation, }); let out_dims = params.out_dims(); Ok(crate::tensor::from_storage(storage, out_dims, op, false)) } /// Applies a 1D transposed convolution over the input tensor. pub fn conv_transpose1d( &self, kernel: &Self, padding: usize, output_padding: usize, stride: usize, dilation: usize, groups: usize, ) -> Result { let (c_in_k, c_out, k_size) = kernel.dims3()?; let (b_size, c_in, l_in) = self.dims3()?; if c_in != c_in_k { crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})") } if c_in % groups != 0 { crate::bail!("in_channel {c_in} is not divisible by the number of groups") } let params = ParamsConvTranspose1D { b_size, l_in, k_size, c_out, c_in: c_in / groups, padding, output_padding, stride, dilation, }; if groups == 1 { self.conv_transpose1d_single_group(kernel, ¶ms) } else { let blocks = self.chunk(groups, 1)?; let kernel = kernel.chunk(groups, 0)?; let blocks = blocks .iter() .zip(&kernel) .map(|(block, kernel)| block.conv_transpose1d_single_group(kernel, ¶ms)) .collect::>>()?; Tensor::cat(&blocks, 1) } } fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result { let storage = self.storage() .conv2d(self.layout(), &kernel.storage(), kernel.layout(), params)?; let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D { arg, kernel, padding: params.padding, stride: params.stride, dilation: params.dilation, }); let out_dims = params.out_dims(); Ok(crate::tensor::from_storage(storage, out_dims, op, false)) } /// Applies a 2D convolution over the input tensor. pub fn conv2d( &self, kernel: &Self, padding: usize, stride: usize, dilation: usize, groups: usize, ) -> Result { let (b_size, c_in, i_h, i_w) = self.dims4()?; let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?; if c_in != c_in_k * groups { crate::bail!( "in_channel mismatch between input ({c_in}, groups {groups}) and kernel ({c_in_k})" ) } let params = ParamsConv2D { b_size, i_h, i_w, k_h, k_w, c_out: c_out / groups, c_in: c_in / groups, padding, stride, dilation, cudnn_fwd_algo: None, }; if groups == 1 { self.conv2d_single_group(kernel, ¶ms) } else { let blocks = self.chunk(groups, 1)?; let kernel = kernel.chunk(groups, 0)?; let blocks = blocks .iter() .zip(&kernel) .map(|(block, kernel)| block.conv2d_single_group(kernel, ¶ms)) .collect::>>()?; Tensor::cat(&blocks, 1) } } /// Applies a 2D transposed convolution over the input tensor. pub fn conv_transpose2d( &self, kernel: &Self, padding: usize, output_padding: usize, stride: usize, dilation: usize, ) -> Result { let (b_size, c_in, i_h, i_w) = self.dims4()?; let (c_in_k, c_out, k_h, k_w) = kernel.dims4()?; if c_in != c_in_k { crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})") } let params = ParamsConvTranspose2D { b_size, i_h, i_w, k_h, k_w, c_out, c_in, padding, output_padding, stride, dilation, }; let storage = self.storage().conv_transpose2d( self.layout(), &kernel.storage(), kernel.layout(), ¶ms, )?; let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose2D { arg, kernel, padding: params.padding, output_padding: params.output_padding, stride: params.stride, dilation: params.dilation, }); let out_dims = params.out_dims(); Ok(crate::tensor::from_storage(storage, out_dims, op, false)) } }