diff options
author | Matt <Rocketknight1@users.noreply.github.com> | 2023-08-10 00:19:20 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-10 00:19:20 +0100 |
commit | 0dc1e5f387f91ff86cc8a4c09d5668e8baaab1b3 (patch) | |
tree | 922a5f1387c42b6101c98749211d5763529453d1 /candle-core/src | |
parent | 0cef3998fde542b9721215b77a80676a434b437f (diff) | |
parent | 25ec2d9f6bf36ff51c04f54f6c243828f6f4a8da (diff) | |
download | candle-0dc1e5f387f91ff86cc8a4c09d5668e8baaab1b3.tar.gz candle-0dc1e5f387f91ff86cc8a4c09d5668e8baaab1b3.tar.bz2 candle-0dc1e5f387f91ff86cc8a4c09d5668e8baaab1b3.zip |
Merge branch 'main' into readme_fixes
Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/conv.rs | 7 | ||||
-rw-r--r-- | candle-core/src/cpu_backend.rs | 4 | ||||
-rw-r--r-- | candle-core/src/cuda_backend.rs | 2 | ||||
-rw-r--r-- | candle-core/src/tensor.rs | 13 |
4 files changed, 6 insertions, 20 deletions
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs index 30799459..e3fea861 100644 --- a/candle-core/src/conv.rs +++ b/candle-core/src/conv.rs @@ -1,6 +1,6 @@ #[derive(Debug, Clone, PartialEq, Eq)] pub struct ParamsConv1D { - pub(crate) b_size: Option<usize>, + pub(crate) b_size: usize, // Maybe we should have a version without l_in as this bit depends on the input and not only on // the weights. pub(crate) l_in: usize, @@ -19,10 +19,7 @@ impl ParamsConv1D { pub(crate) fn out_dims(&self) -> Vec<usize> { let l_out = self.l_out(); - match self.b_size { - None => vec![self.c_out, l_out], - Some(n) => vec![n, self.c_out, l_out], - } + vec![self.b_size, self.c_out, l_out] } } diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 54f3f65b..238a9a69 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -1037,10 +1037,10 @@ impl<'a> Map2 for Conv1D<'a> { let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?; let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?; let l_out = p.l_out(); - let dst_elems = p.c_out * l_out * p.b_size.unwrap_or(1); + let dst_elems = p.c_out * l_out * p.b_size; let mut dst = vec![T::zero(); dst_elems]; // The output shape is [b_size, c_out, l_out] - for b_idx in 0..p.b_size.unwrap_or(1) { + for b_idx in 0..p.b_size { let inp_idx = b_idx * inp_s0; let dst_idx = b_idx * p.c_out * l_out; for dst_c_idx in 0..p.c_out { diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index e51cc05d..a7f63353 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -904,7 +904,7 @@ impl<'a> Map2 for Conv1D<'a> { let dims = shape.dims(); let el = shape.elem_count(); let l_out = p.l_out(); - let dst_el = p.c_out * l_out * p.b_size.unwrap_or(1); + let dst_el = p.c_out * l_out * p.b_size; let cfg = LaunchConfig::for_num_elems(dst_el as u32); let func = dev.get_or_load_func(&kernel_name::<T>("conv1d"), kernels::CONV)?; // SAFETY: Set later by running the kernel. diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index c94c0390..c14a4e39 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -773,18 +773,7 @@ impl Tensor { /// Applies a 1D convolution over the input tensor. pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> { let (c_out, c_in_k, k_size) = kernel.dims3()?; - let (b_size, c_in, l_in) = match *self.dims() { - [b_size, c_in, l_in] => (Some(b_size), c_in, l_in), - [c_in, l_in] => (None, c_in, l_in), - _ => Err(Error::Conv1dInvalidArgs { - inp_shape: self.shape().clone(), - k_shape: kernel.shape().clone(), - padding, - stride, - msg: "input rank is not 2 or 3", - } - .bt())?, - }; + let (b_size, c_in, l_in) = self.dims3()?; if c_in != c_in_k { Err(Error::Conv1dInvalidArgs { inp_shape: self.shape().clone(), |