diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-23 12:58:55 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-23 12:58:55 +0100 |
commit | aba1e90797e430f28eec13b14b76dd5355876f9c (patch) | |
tree | 16bcf7fb151715d3bcdbec2b5263922bd0bdd35a /candle-core/src/tensor.rs | |
parent | 4ee1cf038ada55ec477dcd6496cf2aec1902775b (diff) | |
download | candle-aba1e90797e430f28eec13b14b76dd5355876f9c.tar.gz candle-aba1e90797e430f28eec13b14b76dd5355876f9c.tar.bz2 candle-aba1e90797e430f28eec13b14b76dd5355876f9c.zip |
Add some group parameter to convolutions. (#566)
* Add some group parameter to convolutions.
* Avoid some unnecessary groups checks.
* Move the tensor convolution bits.
* Properh handling of groups.
* Bump the crate version.
* And add a changelog.
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r-- | candle-core/src/tensor.rs | 70 |
1 files changed, 2 insertions, 68 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index a4b9795b..46f9c53f 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -124,7 +124,7 @@ macro_rules! broadcast_binary_op { } /// Creates a fresh tensor structure based on a storage and a shape, this uses contiguous strides. -fn from_storage<S: Into<Shape>>( +pub(crate) fn from_storage<S: Into<Shape>>( storage: Storage, shape: S, op: BackpropOp, @@ -787,72 +787,6 @@ impl Tensor { self.cmp(rhs, CmpOp::Le) } - /// Applies a 1D convolution over the input tensor. - pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> { - let (c_out, c_in_k, k_size) = kernel.dims3()?; - let (b_size, c_in, l_in) = self.dims3()?; - if c_in != c_in_k { - Err(Error::Conv1dInvalidArgs { - inp_shape: self.shape().clone(), - k_shape: kernel.shape().clone(), - padding, - stride, - msg: "the number of in-channels on the input doesn't match the kernel size", - } - .bt())? - } - let params = crate::conv::ParamsConv1D { - b_size, - l_in, - c_out, - c_in, - k_size, - padding, - stride, - }; - let storage = - self.storage() - .conv1d(self.layout(), &kernel.storage(), kernel.layout(), ¶ms)?; - let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D { - arg, - kernel, - padding, - stride, - }); - let out_dims = params.out_dims(); - Ok(from_storage(storage, out_dims, op, false)) - } - - pub fn conv2d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> { - let (b_size, c_in, i_h, i_w) = self.dims4()?; - let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?; - if c_in != c_in_k { - crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})") - } - let params = crate::conv::ParamsConv2D { - b_size, - i_h, - i_w, - k_h, - k_w, - c_out, - c_in, - padding, - stride, - }; - let storage = - self.storage() - .conv2d(self.layout(), &kernel.storage(), kernel.layout(), ¶ms)?; - let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D { - arg, - kernel, - padding, - stride, - }); - let out_dims = params.out_dims(); - Ok(from_storage(storage, out_dims, op, false)) - } - pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> { let (n, c, _h, _w) = self.dims4()?; let op = BackpropOp::new1(self, Op::UpsampleNearest2D); @@ -1920,7 +1854,7 @@ impl Tensor { } } - fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> { + pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> { self.storage.read().unwrap() } |