Diffstat (limited to 'candle-examples/examples/whisper/main.rs')
-rw-r--r--  candle-examples/examples/whisper/main.rs | 22 +++++++++++++++++-----
1 file changed, 17 insertions(+), 5 deletions(-)
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index c9e9ccc6..dfe7a27f 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -1,4 +1,3 @@
-#![allow(dead_code)]
 // https://github.com/openai/whisper/blob/main/whisper/model.py
 // TODO:
 // - kv-cache support?
@@ -31,9 +30,6 @@ const HOP_LENGTH: usize = 160;
 const CHUNK_LENGTH: usize = 30;
 const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
 const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input
-const N_SAMPLES_PER_TOKEN: usize = HOP_LENGTH * 2; // the initial convolutions has stride 2
-const FRAMES_PER_SECOND: usize = SAMPLE_RATE / HOP_LENGTH; // 10ms per audio frame
-const TOKENS_PER_SECOND: usize = SAMPLE_RATE / N_SAMPLES_PER_TOKEN; // 20ms per audio token

 const NO_SPEECH_THRESHOLD: f64 = 0.6;
 const LOGPROB_THRESHOLD: f64 = -1.0;
@@ -44,7 +40,6 @@ const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
 const SOT_TOKEN: u32 = 50257;
 const EOT_TOKEN: u32 = 50256;
 const NO_SPEECH_TOKEN: u32 = 50361;
-const NO_TIMESTAMP_TOKEN: u32 = 50362;
 // From the _get_suppress_tokens function + 50362 (no timestamp)
 // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
 const SUPPRESS_TOKENS: [u32; 91] = [
@@ -56,6 +51,7 @@ const SUPPRESS_TOKENS: [u32; 91] = [
     47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361, 50362,
 ];

+#[allow(dead_code)]
 #[derive(Debug, Clone)]
 struct DecodingResult {
     tokens: Vec<u32>,
@@ -66,6 +62,7 @@ struct DecodingResult {
     compression_ratio: f64,
 }

+#[allow(dead_code)]
 #[derive(Debug, Clone)]
 struct Segment {
     start: f64,
@@ -243,10 +240,25 @@ struct Args {
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
+
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
 }

 fn main() -> Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
     let args = Args::parse();
+    let _guard = if args.tracing {
+        println!("tracing...");
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
     let device = candle_examples::device(args.cpu)?;
     let default_model = "openai/whisper-tiny.en".to_string();
     let path = std::path::PathBuf::from(default_model.clone());
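
For reference, here is a standalone sketch of the tracing setup this diff wires in. It is illustrative, not part of the commit: it assumes `tracing`, `tracing-chrome`, and `tracing-subscriber` are declared in Cargo.toml, and `demo_work` is a hypothetical span name; the builder and subscriber calls mirror the lines added to `main` above.

    // Standalone sketch of the guard pattern added in the diff.
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    fn main() {
        // `build()` returns the layer plus a flush guard. The guard must
        // outlive the traced work: dropping it flushes and finalizes the
        // trace-<timestamp>.json file, so it is bound to a named local
        // rather than discarded with `_`.
        let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();

        // Any span entered from here on is recorded in the trace file.
        tracing::info_span!("demo_work").in_scope(|| {
            let _sum: u64 = (0..1_000_000u64).sum();
        });
    }

The reason the diff stores the guard in an `Option` bound to `_guard` is lifetime: `tracing_chrome::FlushGuard` writes out and closes the trace file when dropped, so it has to live until `main` returns, and it only exists at all when `--tracing` is set. The resulting JSON can be opened in chrome://tracing or Perfetto. Invoking the example with the new flag would presumably look like `cargo run --example whisper --release -- --tracing`.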