summaryrefslogtreecommitdiff
path: root/candle-examples/examples/whisper/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/whisper/main.rs')
-rw-r--r--candle-examples/examples/whisper/main.rs28
1 files changed, 25 insertions, 3 deletions
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index dfe7a27f..d5f91053 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -11,7 +11,7 @@ extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
use candle::{DType, Device, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
-use clap::Parser;
+use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::{distributions::Distribution, SeedableRng};
use tokenizers::Tokenizer;
@@ -216,6 +216,23 @@ impl Decoder {
}
}
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum WhichModel {
+ Tiny,
+ Small,
+ Medium,
+}
+
+impl WhichModel {
+ fn model_and_revision(&self) -> (&'static str, &'static str) {
+ match self {
+ Self::Tiny => ("openai/whisper-tiny.en", "refs/pr/15"),
+ Self::Small => ("openai/whisper-small.en", "refs/pr/10"),
+ Self::Medium => ("openai/whisper-medium.en", "refs/pr/11"),
+ }
+ }
+}
+
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
@@ -231,6 +248,10 @@ struct Args {
#[arg(long)]
revision: Option<String>,
+ /// The model to be used, can be tiny, small, medium.
+ #[arg(long, default_value = "tiny")]
+ model: WhichModel,
+
/// The input to be processed, in wav format, will default to `jfk.wav`. Alternatively
/// this can be set to sample:jfk, sample:gb1, ... to fetch a sample from the following
/// repo: https://huggingface.co/datasets/Narsil/candle_demo/
@@ -260,9 +281,10 @@ fn main() -> Result<()> {
None
};
let device = candle_examples::device(args.cpu)?;
- let default_model = "openai/whisper-tiny.en".to_string();
+ let (default_model, default_revision) = args.model.model_and_revision();
+ let default_model = default_model.to_string();
+ let default_revision = default_revision.to_string();
let path = std::path::PathBuf::from(default_model.clone());
- let default_revision = "refs/pr/15".to_string();
let (model_id, revision) = match (args.model_id, args.revision) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_string()),