summaryrefslogtreecommitdiff
path: root/candle-examples/examples/whisper/main.rs
diff options
context:
space:
mode:
authorlaurent <laurent.mazare@gmail.com>2023-07-05 20:22:43 +0100
committerlaurent <laurent.mazare@gmail.com>2023-07-05 20:22:43 +0100
commit2c3d871b2e7490cb3740674647a03b0dcc8f67b6 (patch)
tree7c8d867001fbecec127ad9581056c0fd2f67f2a3 /candle-examples/examples/whisper/main.rs
parentb7388bbf718f9301b7e41e222654217f18e4c1e1 (diff)
downloadcandle-2c3d871b2e7490cb3740674647a03b0dcc8f67b6.tar.gz
candle-2c3d871b2e7490cb3740674647a03b0dcc8f67b6.tar.bz2
candle-2c3d871b2e7490cb3740674647a03b0dcc8f67b6.zip
Add a simpler way to specify the dim index for some ops.
Diffstat (limited to 'candle-examples/examples/whisper/main.rs')
-rw-r--r--candle-examples/examples/whisper/main.rs2
1 files changed, 1 insertions, 1 deletions
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 6ea3e536..fad3e91c 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -109,7 +109,7 @@ impl Decode {
};
tokens.push(next_token);
let prob = logits
- .softmax(logits.rank() - 1)?
+ .softmax(candle::D::Minus1)?
.get(next_token as usize)?
.to_scalar::<f32>()? as f64;
if next_token == EOT_TOKEN || tokens.len() > model.config.n_text_ctx {