summaryrefslogtreecommitdiff
path: root/candle-core/src/cuda_backend.rs
diff options
context:
space:
mode:
authorCiarĂ¡n Curley <ciaran.curley.11.11@gmail.com>2023-08-09 21:45:24 +0100
committerGitHub <noreply@github.com>2023-08-09 21:45:24 +0100
commit25ec2d9f6bf36ff51c04f54f6c243828f6f4a8da (patch)
treeb42df9c6c9270271091a0c3ec96331409d5ec672 /candle-core/src/cuda_backend.rs
parentda26e2832cfe12776748ff4239857fc54a2f5c82 (diff)
downloadcandle-25ec2d9f6bf36ff51c04f54f6c243828f6f4a8da.tar.gz
candle-25ec2d9f6bf36ff51c04f54f6c243828f6f4a8da.tar.bz2
candle-25ec2d9f6bf36ff51c04f54f6c243828f6f4a8da.zip
fix: remove incorrect unwrap (#379)
Diffstat (limited to 'candle-core/src/cuda_backend.rs')
-rw-r--r--candle-core/src/cuda_backend.rs2
1 files changed, 1 insertions, 1 deletions
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index e51cc05d..a7f63353 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -904,7 +904,7 @@ impl<'a> Map2 for Conv1D<'a> {
let dims = shape.dims();
let el = shape.elem_count();
let l_out = p.l_out();
- let dst_el = p.c_out * l_out * p.b_size.unwrap_or(1);
+ let dst_el = p.c_out * l_out * p.b_size;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("conv1d"), kernels::CONV)?;
// SAFETY: Set later by running the kernel.