diff options
Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/conv.rs | 112 |
-rw-r--r-- | candle-core/src/tensor.rs | 70 |
2 files changed, 114 insertions(+), 68 deletions(-)
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs index e3fea861..d4b7a76d 100644 --- a/candle-core/src/conv.rs +++ b/candle-core/src/conv.rs @@ -1,3 +1,5 @@ +use crate::{op::BackpropOp, op::Op, Error, Result, Tensor}; + #[derive(Debug, Clone, PartialEq, Eq)] pub struct ParamsConv1D { pub(crate) b_size: usize, @@ -51,3 +53,113 @@ impl ParamsConv2D { vec![self.b_size, self.c_out, self.out_h(), self.out_w()] } } + +impl Tensor { + fn conv1d_single_group(&self, kernel: &Self, params: &ParamsConv1D) -> Result<Self> { + let storage = + self.storage() + .conv1d(self.layout(), &kernel.storage(), kernel.layout(), params)?; + let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D { + arg, + kernel, + padding: params.padding, + stride: params.stride, + }); + let out_dims = params.out_dims(); + Ok(crate::tensor::from_storage(storage, out_dims, op, false)) + } + + /// Applies a 1D convolution over the input tensor. + pub fn conv1d( + &self, + kernel: &Self, + padding: usize, + stride: usize, + groups: usize, + ) -> Result<Self> { + let (c_out, c_in_k, k_size) = kernel.dims3()?; + let (b_size, c_in, l_in) = self.dims3()?; + if c_in != c_in_k * groups { + Err(Error::Conv1dInvalidArgs { + inp_shape: self.shape().clone(), + k_shape: kernel.shape().clone(), + padding, + stride, + msg: "the number of in-channels on the input doesn't match the kernel size", + } + .bt())? 
+ } + + let params = ParamsConv1D { + b_size, + l_in, + c_out, + c_in, + k_size, + padding, + stride, + }; + if groups == 1 { + self.conv1d_single_group(kernel, ¶ms) + } else { + let blocks = self.chunk(groups, 1)?; + let blocks = blocks + .iter() + .map(|block| block.conv1d_single_group(kernel, ¶ms)) + .collect::<Result<Vec<_>>>()?; + Tensor::cat(&blocks, 1) + } + } + + fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> { + let storage = + self.storage() + .conv2d(self.layout(), &kernel.storage(), kernel.layout(), params)?; + let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D { + arg, + kernel, + padding: params.padding, + stride: params.stride, + }); + let out_dims = params.out_dims(); + Ok(crate::tensor::from_storage(storage, out_dims, op, false)) + } + + /// Applies a 2D convolution over the input tensor. + pub fn conv2d( + &self, + kernel: &Self, + padding: usize, + stride: usize, + groups: usize, + ) -> Result<Self> { + let (b_size, c_in, i_h, i_w) = self.dims4()?; + let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?; + if c_in != c_in_k * groups { + crate::bail!( + "in_channel mismatch between input ({c_in}, groups {groups}) and kernel ({c_in_k})" + ) + } + let params = ParamsConv2D { + b_size, + i_h, + i_w, + k_h, + k_w, + c_out, + c_in, + padding, + stride, + }; + if groups == 1 { + self.conv2d_single_group(kernel, ¶ms) + } else { + let blocks = self.chunk(groups, 1)?; + let blocks = blocks + .iter() + .map(|block| block.conv2d_single_group(kernel, ¶ms)) + .collect::<Result<Vec<_>>>()?; + Tensor::cat(&blocks, 1) + } + } +} diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index a4b9795b..46f9c53f 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -124,7 +124,7 @@ macro_rules! broadcast_binary_op { } /// Creates a fresh tensor structure based on a storage and a shape, this uses contiguous strides. 
-fn from_storage<S: Into<Shape>>( +pub(crate) fn from_storage<S: Into<Shape>>( storage: Storage, shape: S, op: BackpropOp, @@ -787,72 +787,6 @@ impl Tensor { self.cmp(rhs, CmpOp::Le) } - /// Applies a 1D convolution over the input tensor. - pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> { - let (c_out, c_in_k, k_size) = kernel.dims3()?; - let (b_size, c_in, l_in) = self.dims3()?; - if c_in != c_in_k { - Err(Error::Conv1dInvalidArgs { - inp_shape: self.shape().clone(), - k_shape: kernel.shape().clone(), - padding, - stride, - msg: "the number of in-channels on the input doesn't match the kernel size", - } - .bt())? - } - let params = crate::conv::ParamsConv1D { - b_size, - l_in, - c_out, - c_in, - k_size, - padding, - stride, - }; - let storage = - self.storage() - .conv1d(self.layout(), &kernel.storage(), kernel.layout(), ¶ms)?; - let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D { - arg, - kernel, - padding, - stride, - }); - let out_dims = params.out_dims(); - Ok(from_storage(storage, out_dims, op, false)) - } - - pub fn conv2d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> { - let (b_size, c_in, i_h, i_w) = self.dims4()?; - let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?; - if c_in != c_in_k { - crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})") - } - let params = crate::conv::ParamsConv2D { - b_size, - i_h, - i_w, - k_h, - k_w, - c_out, - c_in, - padding, - stride, - }; - let storage = - self.storage() - .conv2d(self.layout(), &kernel.storage(), kernel.layout(), ¶ms)?; - let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D { - arg, - kernel, - padding, - stride, - }); - let out_dims = params.out_dims(); - Ok(from_storage(storage, out_dims, op, false)) - } - pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> { let (n, c, _h, _w) = self.dims4()?; let op = BackpropOp::new1(self, Op::UpsampleNearest2D); @@ -1920,7 
+1854,7 @@ impl Tensor { } } - fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> { + pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> { self.storage.read().unwrap() }