summaryrefslogtreecommitdiff
path: root/candle-core/src/tensor.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r--candle-core/src/tensor.rs13
1 files changed, 1 insertions, 12 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index c94c0390..c14a4e39 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -773,18 +773,7 @@ impl Tensor {
/// Applies a 1D convolution over the input tensor.
pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
let (c_out, c_in_k, k_size) = kernel.dims3()?;
- let (b_size, c_in, l_in) = match *self.dims() {
- [b_size, c_in, l_in] => (Some(b_size), c_in, l_in),
- [c_in, l_in] => (None, c_in, l_in),
- _ => Err(Error::Conv1dInvalidArgs {
- inp_shape: self.shape().clone(),
- k_shape: kernel.shape().clone(),
- padding,
- stride,
- msg: "input rank is not 2 or 3",
- }
- .bt())?,
- };
+ let (b_size, c_in, l_in) = self.dims3()?;
if c_in != c_in_k {
Err(Error::Conv1dInvalidArgs {
inp_shape: self.shape().clone(),