summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-05 13:56:41 +0100
committerGitHub <noreply@github.com>2023-07-05 13:56:41 +0100
commit4e80319147510e7fe3a89724a2e66f98f1e4b974 (patch)
tree308ac12d16a70ca7e025652efdf3dffe8a118c1e
parente4fb8c45cc7a30de4aaf365ebc1221a53a4db157 (diff)
parentbae6d07b7eee229b5707d6f65dd53eb40af17e83 (diff)
downloadcandle-4e80319147510e7fe3a89724a2e66f98f1e4b974.tar.gz
candle-4e80319147510e7fe3a89724a2e66f98f1e4b974.tar.bz2
candle-4e80319147510e7fe3a89724a2e66f98f1e4b974.zip
Merge pull request #77 from LaurentMazare/whisper-fix-emb
[whisper] Fix the position embeddings size.
-rw-r--r--candle-examples/examples/whisper/model.rs4
1 files changed, 3 insertions, 1 deletions
diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs
index 53ee6a90..bf322c51 100644
--- a/candle-examples/examples/whisper/model.rs
+++ b/candle-examples/examples/whisper/model.rs
@@ -458,7 +458,9 @@ impl AudioEncoder {
let x = self.conv1.forward(x)?.gelu()?;
let x = self.conv2.forward(&x)?.gelu()?;
let x = x.transpose(1, 2)?;
- let mut x = x.broadcast_add(&self.positional_embedding)?;
+ let (_bsize, seq_len, _hidden) = x.shape().r3()?;
+ let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
+ let mut x = x.broadcast_add(&positional_embedding)?;
for block in self.blocks.iter() {
x = block.forward(&x, None, None)?
}