summaryrefslogtreecommitdiff
path: root/candle-core/src/conv.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-09 21:27:03 +0200
committerGitHub <noreply@github.com>2023-08-09 20:27:03 +0100
commitfcfdcbd3373fb2fd744a0b4f7aa97cec7e620431 (patch)
treee0f7c1e5724a131369a9413151703004fb8fe4ae /candle-core/src/conv.rs
parent653ec5abc17c6c66f77b0d11542f7e5ea74f1912 (diff)
downloadcandle-fcfdcbd3373fb2fd744a0b4f7aa97cec7e620431.tar.gz
candle-fcfdcbd3373fb2fd744a0b4f7aa97cec7e620431.tar.bz2
candle-fcfdcbd3373fb2fd744a0b4f7aa97cec7e620431.zip
Add a conv1d benchmark based on the whisper sizes. (#377)
* Add a conv1d benchmark based on the whisper sizes. * Enforce the batch-dim in conv1d.
Diffstat (limited to 'candle-core/src/conv.rs')
-rw-r--r--candle-core/src/conv.rs7
1 files changed, 2 insertions, 5 deletions
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
index 30799459..e3fea861 100644
--- a/candle-core/src/conv.rs
+++ b/candle-core/src/conv.rs
@@ -1,6 +1,6 @@
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv1D {
- pub(crate) b_size: Option<usize>,
+ pub(crate) b_size: usize,
// Maybe we should have a version without l_in as this bit depends on the input and not only on
// the weights.
pub(crate) l_in: usize,
@@ -19,10 +19,7 @@ impl ParamsConv1D {
pub(crate) fn out_dims(&self) -> Vec<usize> {
let l_out = self.l_out();
- match self.b_size {
- None => vec![self.c_out, l_out],
- Some(n) => vec![n, self.c_out, l_out],
- }
+ vec![self.b_size, self.c_out, l_out]
}
}