diff options
author | Matt <Rocketknight1@users.noreply.github.com> | 2023-08-10 00:19:20 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-10 00:19:20 +0100 |
commit | 0dc1e5f387f91ff86cc8a4c09d5668e8baaab1b3 (patch) | |
tree | 922a5f1387c42b6101c98749211d5763529453d1 /candle-core/src/tensor.rs | |
parent | 0cef3998fde542b9721215b77a80676a434b437f (diff) | |
parent | 25ec2d9f6bf36ff51c04f54f6c243828f6f4a8da (diff) | |
download | candle-0dc1e5f387f91ff86cc8a4c09d5668e8baaab1b3.tar.gz candle-0dc1e5f387f91ff86cc8a4c09d5668e8baaab1b3.tar.bz2 candle-0dc1e5f387f91ff86cc8a4c09d5668e8baaab1b3.zip |
Merge branch 'main' into readme_fixes
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r-- | candle-core/src/tensor.rs | 13 |
1 files changed, 1 insertions, 12 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index c94c0390..c14a4e39 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -773,18 +773,7 @@ impl Tensor { /// Applies a 1D convolution over the input tensor. pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> { let (c_out, c_in_k, k_size) = kernel.dims3()?; - let (b_size, c_in, l_in) = match *self.dims() { - [b_size, c_in, l_in] => (Some(b_size), c_in, l_in), - [c_in, l_in] => (None, c_in, l_in), - _ => Err(Error::Conv1dInvalidArgs { - inp_shape: self.shape().clone(), - k_shape: kernel.shape().clone(), - padding, - stride, - msg: "input rank is not 2 or 3", - } - .bt())?, - }; + let (b_size, c_in, l_in) = self.dims3()?; if c_in != c_in_k { Err(Error::Conv1dInvalidArgs { inp_shape: self.shape().clone(), |