summary refs log tree commit diff
diff options
context:
space:
mode:
author Laurent Mazare <laurent.mazare@gmail.com> 2024-10-01 00:24:17 +0200
committer GitHub <noreply@github.com> 2024-10-01 00:24:17 +0200
commit 6110ad8d4ff8272bdd10687eae4edee59a07b517 (patch)
tree a8ef306b9e9316078a10768b369ebaa04cadd981
parent aa35bf2ff5edd9c3534fd7744b333a1abaed4406 (diff)
download candle-6110ad8d4ff8272bdd10687eae4edee59a07b517.tar.gz
candle-6110ad8d4ff8272bdd10687eae4edee59a07b517.tar.bz2
candle-6110ad8d4ff8272bdd10687eae4edee59a07b517.zip
Refactor the whisper microphone example. (#2523)
* Refactor the whisper microphone example.
* Tweak the whisper microphone example more.
-rw-r--r-- candle-examples/Cargo.toml 2
-rw-r--r-- candle-examples/examples/whisper-microphone/main.rs 154
2 files changed, 74 insertions, 82 deletions
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index 543c9666..2c96f87d 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -65,7 +65,7 @@ mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/
nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]
-microphone = ["cpal"]
+microphone = ["cpal", "rubato"]
encodec = ["cpal", "symphonia", "rubato"]
mimi = ["cpal", "symphonia", "rubato"]
depth_anything_v2 = ["palette", "enterpolation"]
diff --git a/candle-examples/examples/whisper-microphone/main.rs b/candle-examples/examples/whisper-microphone/main.rs
index 9f7d5b82..44a64b05 100644
--- a/candle-examples/examples/whisper-microphone/main.rs
+++ b/candle-examples/examples/whisper-microphone/main.rs
@@ -10,7 +10,6 @@ use candle_nn::{ops::softmax, VarBuilder};
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::{distributions::Distribution, SeedableRng};
-use std::iter;
use tokenizers::Tokenizer;
mod multilingual;
@@ -18,7 +17,6 @@ mod multilingual;
use candle_transformers::models::whisper::{self as m, audio, Config};
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
-use std::sync::{Arc, Mutex};
pub enum Model {
Normal(m::model::Whisper),
@@ -479,6 +477,10 @@ struct Args {
/// Print the full DecodingResult structure rather than just the text.
#[arg(long)]
verbose: bool,
+
+ /// The input device to use.
+ #[arg(long)]
+ device: Option<String>,
}
pub fn main() -> Result<()> {
@@ -543,13 +545,12 @@ pub fn main() -> Result<()> {
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? };
Model::Normal(m::model::Whisper::load(&vb, config.clone())?)
};
- let language_token = None;
- let mut dc = Decoder::new(
+ let mut decoder = Decoder::new(
model,
tokenizer.clone(),
args.seed,
&device,
- language_token,
+ /* language_token */ None,
args.task,
args.timestamps,
args.verbose,
@@ -565,47 +566,69 @@ pub fn main() -> Result<()> {
// Set up the input device and stream with the default input config.
let host = cpal::default_host();
- let _device = "default";
- let _device = if _device == "default" {
- host.default_input_device()
- } else {
- host.input_devices()?
- .find(|x| x.name().map(|y| y == _device).unwrap_or(false))
+ let audio_device = match args.device.as_ref() {
+ None => host.default_input_device(),
+ Some(device) => host
+ .input_devices()?
+ .find(|x| x.name().map_or(false, |y| &y == device)),
}
- .expect("failed to find input device");
+ .expect("failed to find the audio input device");
- let _config = _device
+ let audio_config = audio_device
.default_input_config()
.expect("Failed to get default input config");
-
- let channel_count = _config.channels() as usize;
-
- let audio_ring_buffer = Arc::new(Mutex::new(Vec::new()));
- let audio_ring_buffer_2 = audio_ring_buffer.clone();
-
- std::thread::spawn(move || loop {
- let data = record_audio(&_device, &_config, 300).unwrap();
- audio_ring_buffer.lock().unwrap().extend_from_slice(&data);
- let max_len = data.len() * 16;
- let data_len = data.len();
- let len = audio_ring_buffer.lock().unwrap().len();
- if len > max_len {
- let mut data = audio_ring_buffer.lock().unwrap();
- let new_data = data[data_len..].to_vec();
- *data = new_data;
- }
- });
+ println!("audio config {audio_config:?}");
+
+ let channel_count = audio_config.channels() as usize;
+ let in_sample_rate = audio_config.sample_rate().0 as usize;
+ let resample_ratio = 16000. / in_sample_rate as f64;
+ let mut resampler = rubato::FastFixedIn::new(
+ resample_ratio,
+ 10.,
+ rubato::PolynomialDegree::Septic,
+ 1024,
+ 1,
+ )?;
+ let (tx, rx) = std::sync::mpsc::channel();
+ let stream = audio_device.build_input_stream(
+ &audio_config.config(),
+ move |pcm: &[f32], _: &cpal::InputCallbackInfo| {
+ let pcm = pcm
+ .iter()
+ .step_by(channel_count)
+ .copied()
+ .collect::<Vec<f32>>();
+ if !pcm.is_empty() {
+ tx.send(pcm).unwrap()
+ }
+ },
+ move |err| {
+ eprintln!("an error occurred on stream: {}", err);
+ },
+ None,
+ )?;
+ stream.play()?;
// loop to process the audio data forever (until the user stops the program)
- println!("Transcribing audio...");
- for (i, _) in iter::repeat(()).enumerate() {
- std::thread::sleep(std::time::Duration::from_millis(1000));
- let data = audio_ring_buffer_2.lock().unwrap().clone();
- let pcm_data: Vec<_> = data[..data.len() / channel_count as usize]
- .iter()
- .map(|v| *v as f32 / 32768.)
- .collect();
- let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters);
+ println!("transcribing audio...");
+ let mut buffered_pcm = vec![];
+ let mut language_token_set = false;
+ while let Ok(pcm) = rx.recv() {
+ use rubato::Resampler;
+
+ buffered_pcm.extend_from_slice(&pcm);
+ if buffered_pcm.len() < 10 * in_sample_rate {
+ continue;
+ }
+ let mut resampled_pcm = vec![];
+ for buffered_pcm in buffered_pcm.chunks(1024) {
+ let pcm = resampler.process(&[&buffered_pcm], None)?;
+ resampled_pcm.extend_from_slice(&pcm[0])
+ }
+ let pcm = resampled_pcm;
+ println!("{} {}", buffered_pcm.len(), pcm.len());
+ buffered_pcm.clear();
+ let mel = audio::pcm_to_mel(&config, &pcm, &mel_filters);
let mel_len = mel.len();
let mel = Tensor::from_vec(
mel,
@@ -614,9 +637,13 @@ pub fn main() -> Result<()> {
)?;
// on the first iteration, we detect the language and set the language token.
- if i == 0 {
+ if !language_token_set {
let language_token = match (args.model.is_multilingual(), args.language.clone()) {
- (true, None) => Some(multilingual::detect_language(dc.model(), &tokenizer, &mel)?),
+ (true, None) => Some(multilingual::detect_language(
+ decoder.model(),
+ &tokenizer,
+ &mel,
+ )?),
(false, None) => None,
(true, Some(language)) => match token_id(&tokenizer, &format!("<|{language}|>")) {
Ok(token_id) => Some(token_id),
@@ -627,47 +654,12 @@ pub fn main() -> Result<()> {
}
};
println!("language_token: {:?}", language_token);
- dc.set_language_token(language_token);
+ decoder.set_language_token(language_token);
+ language_token_set = true;
}
- dc.run(
- &mel,
- Some((
- i as f64,
- i as f64 + data.len() as f64 / m::SAMPLE_RATE as f64,
- )),
- )?;
- dc.reset_kv_cache();
+ decoder.run(&mel, None)?;
+ decoder.reset_kv_cache();
}
Ok(())
}
-
-fn record_audio(
- device: &cpal::Device,
- config: &cpal::SupportedStreamConfig,
- milliseconds: u64,
-) -> Result<Vec<i16>> {
- let writer = Arc::new(Mutex::new(Vec::new()));
- let writer_2 = writer.clone();
- let stream = device.build_input_stream(
- &config.config(),
- move |data: &[f32], _: &cpal::InputCallbackInfo| {
- let processed = data
- .iter()
- .map(|v| (v * 32768.0) as i16)
- .collect::<Vec<i16>>();
- writer_2.lock().unwrap().extend_from_slice(&processed);
- },
- move |err| {
- eprintln!("an error occurred on stream: {}", err);
- },
- None,
- )?;
- stream.play()?;
- std::thread::sleep(std::time::Duration::from_millis(milliseconds));
- drop(stream);
- let data = writer.lock().unwrap().clone();
- let step = 3;
- let data: Vec<i16> = data.iter().step_by(step).copied().collect();
- Ok(data)
-}