summary refs log tree commit diff
path: root/candle-core/src
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src')
-rw-r--r--  candle-core/src/conv.rs          7
-rw-r--r--  candle-core/src/cpu_backend.rs   4
-rw-r--r--  candle-core/src/cuda_backend.rs  2
-rw-r--r--  candle-core/src/tensor.rs       13
4 files changed, 6 insertions, 20 deletions
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
index 30799459..e3fea861 100644
--- a/candle-core/src/conv.rs
+++ b/candle-core/src/conv.rs
@@ -1,6 +1,6 @@
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv1D {
- pub(crate) b_size: Option<usize>,
+ pub(crate) b_size: usize,
// Maybe we should have a version without l_in as this bit depends on the input and not only on
// the weights.
pub(crate) l_in: usize,
@@ -19,10 +19,7 @@ impl ParamsConv1D {
pub(crate) fn out_dims(&self) -> Vec<usize> {
let l_out = self.l_out();
- match self.b_size {
- None => vec![self.c_out, l_out],
- Some(n) => vec![n, self.c_out, l_out],
- }
+ vec![self.b_size, self.c_out, l_out]
}
}
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 54f3f65b..238a9a69 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1037,10 +1037,10 @@ impl<'a> Map2 for Conv1D<'a> {
let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
let l_out = p.l_out();
- let dst_elems = p.c_out * l_out * p.b_size.unwrap_or(1);
+ let dst_elems = p.c_out * l_out * p.b_size;
let mut dst = vec![T::zero(); dst_elems];
// The output shape is [b_size, c_out, l_out]
- for b_idx in 0..p.b_size.unwrap_or(1) {
+ for b_idx in 0..p.b_size {
let inp_idx = b_idx * inp_s0;
let dst_idx = b_idx * p.c_out * l_out;
for dst_c_idx in 0..p.c_out {
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index e51cc05d..a7f63353 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -904,7 +904,7 @@ impl<'a> Map2 for Conv1D<'a> {
let dims = shape.dims();
let el = shape.elem_count();
let l_out = p.l_out();
- let dst_el = p.c_out * l_out * p.b_size.unwrap_or(1);
+ let dst_el = p.c_out * l_out * p.b_size;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("conv1d"), kernels::CONV)?;
// SAFETY: Set later by running the kernel.
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index c94c0390..c14a4e39 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -773,18 +773,7 @@ impl Tensor {
/// Applies a 1D convolution over the input tensor.
pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
let (c_out, c_in_k, k_size) = kernel.dims3()?;
- let (b_size, c_in, l_in) = match *self.dims() {
- [b_size, c_in, l_in] => (Some(b_size), c_in, l_in),
- [c_in, l_in] => (None, c_in, l_in),
- _ => Err(Error::Conv1dInvalidArgs {
- inp_shape: self.shape().clone(),
- k_shape: kernel.shape().clone(),
- padding,
- stride,
- msg: "input rank is not 2 or 3",
- }
- .bt())?,
- };
+ let (b_size, c_in, l_in) = self.dims3()?;
if c_in != c_in_k {
Err(Error::Conv1dInvalidArgs {
inp_shape: self.shape().clone(),