summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-examples/examples/quantized/main.rs2
-rw-r--r--candle-examples/examples/whisper/main.rs10
2 files changed, 9 insertions, 3 deletions
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index b9b2ec9a..f0ae8973 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -417,7 +417,7 @@ fn main() -> anyhow::Result<()> {
}
let dt = start_gen.elapsed();
println!(
- "\n\n{} tokens generated ({} token/s)\n",
+ "\n\n{} tokens generated ({:.2} token/s)\n",
token_generated,
token_generated as f64 / dt.as_secs_f64(),
);
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 5c58c002..9f8810a7 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -243,10 +243,15 @@ pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
#[derive(Clone, Copy, Debug, ValueEnum)]
enum WhichModel {
Tiny,
+ #[value(name = "tiny.en")]
TinyEn,
Base,
+ #[value(name = "base.en")]
BaseEn,
+ Small,
+ #[value(name = "small.en")]
SmallEn,
+ #[value(name = "medium.en")]
MediumEn,
LargeV2,
}
@@ -254,7 +259,7 @@ enum WhichModel {
impl WhichModel {
fn is_multilingual(&self) -> bool {
match self {
- Self::Tiny | Self::Base | Self::LargeV2 => true,
+ Self::Tiny | Self::Base | Self::Small | Self::LargeV2 => true,
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn => false,
}
}
@@ -264,6 +269,7 @@ impl WhichModel {
Self::TinyEn => ("openai/whisper-tiny.en", "refs/pr/15"),
Self::Base => ("openai/whisper-base", "refs/pr/22"),
Self::BaseEn => ("openai/whisper-base.en", "refs/pr/13"),
+ Self::Small => ("openai/whisper-small", "main"),
Self::SmallEn => ("openai/whisper-small.en", "refs/pr/10"),
Self::MediumEn => ("openai/whisper-medium.en", "refs/pr/11"),
Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"),
@@ -287,7 +293,7 @@ struct Args {
revision: Option<String>,
/// The model to be used, can be tiny, small, medium.
- #[arg(long, default_value = "tiny-en")]
+ #[arg(long, default_value = "tiny.en")]
model: WhichModel,
/// The input to be processed, in wav format, will default to `jfk.wav`. Alternatively