summaryrefslogtreecommitdiff
path: root/candle-examples/examples/mimi
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-09-20 14:31:20 -0600
committerGitHub <noreply@github.com>2024-09-20 14:31:20 -0600
commitc58c5d5b01b1457997ac68b3a873b64ca98afcb6 (patch)
tree661ead8f3d8fe70c7a55170899f2f97f6c02f132 /candle-examples/examples/mimi
parent382c6b51af46f3906a191435d198f404f66de95e (diff)
downloadcandle-c58c5d5b01b1457997ac68b3a873b64ca98afcb6.tar.gz
candle-c58c5d5b01b1457997ac68b3a873b64ca98afcb6.tar.bz2
candle-c58c5d5b01b1457997ac68b3a873b64ca98afcb6.zip
Add the mimi audio-tokenizer. (#2488)
* Add the mimi audio-tokenizer. * Formatting tweaks. * Add a full example. * Use the transformers names. * More renamings. * Get encoding and decoding to work. * Clippy fixes.
Diffstat (limited to 'candle-examples/examples/mimi')
-rw-r--r--candle-examples/examples/mimi/README.md20
-rw-r--r--candle-examples/examples/mimi/audio_io.rs275
-rw-r--r--candle-examples/examples/mimi/main.rs131
3 files changed, 426 insertions, 0 deletions
diff --git a/candle-examples/examples/mimi/README.md b/candle-examples/examples/mimi/README.md
new file mode 100644
index 00000000..bbcfcdb7
--- /dev/null
+++ b/candle-examples/examples/mimi/README.md
@@ -0,0 +1,20 @@
+# candle-mimi
+
+[Mimi](https://huggingface.co/kyutai/mimi) is a state of the art audio
+compression model using an encoder/decoder architecture with residual vector
+quantization. The candle implementation supports streaming, meaning that it's
+possible to encode or decode a stream of audio tokens on the fly to provide
+low-latency interaction with an audio model.
+
+## Running one example
+
+Generating some audio tokens from an audio file.
+```bash
+wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
+cargo run --example mimi --features mimi --release -- audio-to-code bria.mp3 bria.safetensors
+```
+
+And decoding the audio tokens back into a sound file.
+```bash
+cargo run --example mimi --features mimi --release -- code-to-audio bria.safetensors bria.wav
+```
diff --git a/candle-examples/examples/mimi/audio_io.rs b/candle-examples/examples/mimi/audio_io.rs
new file mode 100644
index 00000000..2103dd4a
--- /dev/null
+++ b/candle-examples/examples/mimi/audio_io.rs
@@ -0,0 +1,275 @@
+#![allow(unused)]
+use anyhow::{Context, Result};
+use std::sync::{Arc, Mutex};
+
+pub const SAMPLE_RATE: usize = 24_000;
+
/// Shared state for streaming resampling between an audio device and the model.
///
/// Incoming samples are staged in `input_buffer`, run through the rubato
/// resampler in fixed-size chunks, and the resampled output is queued in
/// `resampled_data` (pushed at the front, consumed from the back, i.e. FIFO).
pub(crate) struct AudioOutputData_ {
    // FIFO of resampled samples: producers push_front, consumers pop_back.
    resampled_data: std::collections::VecDeque<f32>,
    resampler: rubato::FastFixedIn<f32>,
    // Scratch buffer the resampler writes its output into.
    output_buffer: Vec<f32>,
    // Fixed-size staging buffer for incoming samples awaiting resampling.
    input_buffer: Vec<f32>,
    // Number of valid samples currently staged in `input_buffer`.
    input_len: usize,
}
+
+impl AudioOutputData_ {
+ pub(crate) fn new(input_sample_rate: usize, output_sample_rate: usize) -> Result<Self> {
+ use rubato::Resampler;
+
+ let resampled_data = std::collections::VecDeque::with_capacity(output_sample_rate * 10);
+ let resample_ratio = output_sample_rate as f64 / input_sample_rate as f64;
+ let resampler = rubato::FastFixedIn::new(
+ resample_ratio,
+ f64::max(resample_ratio, 1.0),
+ rubato::PolynomialDegree::Septic,
+ 1024,
+ 1,
+ )?;
+ let input_buffer = resampler.input_buffer_allocate(true).remove(0);
+ let output_buffer = resampler.output_buffer_allocate(true).remove(0);
+ Ok(Self {
+ resampled_data,
+ resampler,
+ input_buffer,
+ output_buffer,
+ input_len: 0,
+ })
+ }
+
+ pub fn reset(&mut self) {
+ use rubato::Resampler;
+ self.output_buffer.fill(0.);
+ self.input_buffer.fill(0.);
+ self.resampler.reset();
+ self.resampled_data.clear();
+ }
+
+ pub(crate) fn take_all(&mut self) -> Vec<f32> {
+ let mut data = Vec::with_capacity(self.resampled_data.len());
+ while let Some(elem) = self.resampled_data.pop_back() {
+ data.push(elem);
+ }
+ data
+ }
+
+ pub(crate) fn is_empty(&self) -> bool {
+ self.resampled_data.is_empty()
+ }
+
+ // Assumes that the input buffer is large enough.
+ fn push_input_buffer(&mut self, samples: &[f32]) {
+ self.input_buffer[self.input_len..self.input_len + samples.len()].copy_from_slice(samples);
+ self.input_len += samples.len()
+ }
+
+ pub(crate) fn push_samples(&mut self, samples: &[f32]) -> Result<()> {
+ use rubato::Resampler;
+
+ let mut pos_in = 0;
+ loop {
+ let rem = self.input_buffer.len() - self.input_len;
+ let pos_end = usize::min(pos_in + rem, samples.len());
+ self.push_input_buffer(&samples[pos_in..pos_end]);
+ pos_in = pos_end;
+ if self.input_len < self.input_buffer.len() {
+ break;
+ }
+ let (_, out_len) = self.resampler.process_into_buffer(
+ &[&self.input_buffer],
+ &mut [&mut self.output_buffer],
+ None,
+ )?;
+ for &elem in self.output_buffer[..out_len].iter() {
+ self.resampled_data.push_front(elem)
+ }
+ self.input_len = 0;
+ }
+ Ok(())
+ }
+}
+
+type AudioOutputData = Arc<Mutex<AudioOutputData_>>;
+
/// Opens the default audio output device and starts playback.
///
/// The returned shared buffer is the playback queue: samples pushed into it
/// (at [`SAMPLE_RATE`]) are resampled to the device rate and drained by the
/// cpal callback. The returned `Stream` must be kept alive for playback to
/// continue.
pub(crate) fn setup_output_stream() -> Result<(cpal::Stream, AudioOutputData)> {
    use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};

    println!("Setup audio output stream!");
    let host = cpal::default_host();
    let device = host
        .default_output_device()
        .context("no output device available")?;
    let mut supported_configs_range = device.supported_output_configs()?;
    // Prefer a mono output config; otherwise fall back to the first config
    // available and duplicate the mono signal across its channels below.
    let config_range = match supported_configs_range.find(|c| c.channels() == 1) {
        // On macOS, it's commonly the case that there are only stereo outputs.
        None => device
            .supported_output_configs()?
            .next()
            .context("no audio output available")?,
        Some(config_range) => config_range,
    };
    // Ask for the model sample rate, clamped to what the device supports.
    let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
        config_range.min_sample_rate(),
        config_range.max_sample_rate(),
    );
    let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
    let channels = config.channels as usize;
    println!(
        "cpal device: {} {} {config:?}",
        device.name().unwrap_or_else(|_| "unk".to_string()),
        config.sample_rate.0
    );
    // Resample from the model rate to whatever rate the device ended up with.
    let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
        SAMPLE_RATE,
        config.sample_rate.0 as usize,
    )?));
    let ad = audio_data.clone();
    let stream = device.build_output_stream(
        &config,
        move |data: &mut [f32], _: &cpal::OutputCallbackInfo| {
            // Zero first so we play silence when the queue underruns.
            data.fill(0.);
            let mut ad = ad.lock().unwrap();
            let mut last_elem = 0f32;
            for (idx, elem) in data.iter_mut().enumerate() {
                if idx % channels == 0 {
                    // First channel of the frame: pull the next mono sample.
                    match ad.resampled_data.pop_back() {
                        None => break,
                        Some(v) => {
                            last_elem = v;
                            *elem = v
                        }
                    }
                } else {
                    // Remaining channels of the frame: duplicate the sample.
                    *elem = last_elem
                }
            }
        },
        move |err| eprintln!("cpal error: {err}"),
        None, // None=blocking, Some(Duration)=timeout
    )?;
    stream.play()?;
    Ok((stream, audio_data))
}
+
+pub(crate) fn setup_input_stream() -> Result<(cpal::Stream, AudioOutputData)> {
+ use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
+
+ println!("Setup audio input stream!");
+ let host = cpal::default_host();
+ let device = host
+ .default_input_device()
+ .context("no input device available")?;
+ let mut supported_configs_range = device.supported_input_configs()?;
+ let config_range = supported_configs_range
+ .find(|c| c.channels() == 1)
+ .context("no audio input available")?;
+ let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
+ config_range.min_sample_rate(),
+ config_range.max_sample_rate(),
+ );
+ let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
+ println!(
+ "cpal device: {} {} {config:?}",
+ device.name().unwrap_or_else(|_| "unk".to_string()),
+ config.sample_rate.0
+ );
+ let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
+ config.sample_rate.0 as usize,
+ SAMPLE_RATE,
+ )?));
+ let ad = audio_data.clone();
+ let stream = device.build_input_stream(
+ &config,
+ move |data: &[f32], _: &cpal::InputCallbackInfo| {
+ let mut ad = ad.lock().unwrap();
+ if let Err(err) = ad.push_samples(data) {
+ eprintln!("error processing audio input {err:?}")
+ }
+ },
+ move |err| eprintln!("cpal error: {err}"),
+ None, // None=blocking, Some(Duration)=timeout
+ )?;
+ stream.play()?;
+ Ok((stream, audio_data))
+}
+
+fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
+where
+ T: symphonia::core::sample::Sample,
+ f32: symphonia::core::conv::FromSample<T>,
+{
+ use symphonia::core::audio::Signal;
+ use symphonia::core::conv::FromSample;
+ samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
+}
+
/// Decodes an audio file (any format symphonia supports) into mono f32 PCM.
///
/// Only the first channel of the first decodable track is kept. Returns the
/// samples together with the source sample rate (0 if the codec does not
/// report one). The read loop stops at the first `next_packet` error, which
/// is also how symphonia signals end-of-stream.
pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> Result<(Vec<f32>, u32)> {
    use symphonia::core::audio::{AudioBufferRef, Signal};

    let src = std::fs::File::open(path)?;
    let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
    let hint = symphonia::core::probe::Hint::new();
    let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
    let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
    // Probe the container format from the stream contents.
    let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
    let mut format = probed.format;
    // Pick the first track that has an actual audio codec.
    let track = format
        .tracks()
        .iter()
        .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
        .expect("no supported audio tracks");
    let mut decoder = symphonia::default::get_codecs()
        .make(&track.codec_params, &Default::default())
        .expect("unsupported codec");
    let track_id = track.id;
    let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
    let mut pcm_data = Vec::new();
    while let Ok(packet) = format.next_packet() {
        // Discard queued metadata revisions; only the audio matters here.
        while !format.metadata().is_latest() {
            format.metadata().pop();
        }
        // Skip packets belonging to other tracks.
        if packet.track_id() != track_id {
            continue;
        }
        // Convert whatever sample format the decoder produced to f32,
        // keeping only channel 0.
        match decoder.decode(&packet)? {
            AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
            AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
            AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
        }
    }
    Ok((pcm_data, sample_rate))
}
+
+pub(crate) fn resample(pcm_in: &[f32], sr_in: usize, sr_out: usize) -> Result<Vec<f32>> {
+ use rubato::Resampler;
+
+ let mut pcm_out =
+ Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024);
+
+ let mut resampler = rubato::FftFixedInOut::<f32>::new(sr_in, sr_out, 1024, 1)?;
+ let mut output_buffer = resampler.output_buffer_allocate(true);
+ let mut pos_in = 0;
+ while pos_in + resampler.input_frames_next() < pcm_in.len() {
+ let (in_len, out_len) =
+ resampler.process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)?;
+ pos_in += in_len;
+ pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
+ }
+
+ if pos_in < pcm_in.len() {
+ let (_in_len, out_len) = resampler.process_partial_into_buffer(
+ Some(&[&pcm_in[pos_in..]]),
+ &mut output_buffer,
+ None,
+ )?;
+ pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
+ }
+
+ Ok(pcm_out)
+}
diff --git a/candle-examples/examples/mimi/main.rs b/candle-examples/examples/mimi/main.rs
new file mode 100644
index 00000000..cfc1a553
--- /dev/null
+++ b/candle-examples/examples/mimi/main.rs
@@ -0,0 +1,131 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use anyhow::Result;
+use candle::{DType, IndexOp, Tensor};
+use candle_nn::VarBuilder;
+use candle_transformers::models::mimi::{Config, Model};
+use clap::{Parser, ValueEnum};
+use hf_hub::api::sync::Api;
+
+mod audio_io;
+
// The conversion performed by the example. Plain `//` comments are used on
// purpose: `///` doc comments on ValueEnum variants would become clap help
// text and change the CLI output.
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Action {
    // Round-trip: encode the input audio to tokens, then decode back to audio.
    AudioToAudio,
    // Encode input audio into mimi tokens (safetensors output).
    AudioToCode,
    // Decode mimi tokens (safetensors input) back into audio.
    CodeToAudio,
}
+
// Command-line arguments. A `//` comment is used here on purpose: with
// `about` listed in `#[command(...)]`, a `///` doc comment on the struct
// would become the CLI about text and change the generated help.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// The action to be performed, specifies the format for the input and output data.
    action: Action,

    /// The input file, either an audio file or some mimi tokens stored as safetensors.
    in_file: String,

    /// The output file, either a wave audio file or some mimi tokens stored as safetensors.
    out_file: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// The model weight file, in safetensor format.
    #[arg(long)]
    model: Option<String>,
}
+
/// Entry point: loads the mimi model, then runs the requested conversion
/// (audio -> tokens, tokens -> audio, or a full round-trip). `-` as the
/// input file records from the microphone; `-` as the output file plays
/// through the default audio device.
fn main() -> Result<()> {
    let args = Args::parse();
    let device = candle_examples::device(args.cpu)?;
    // Use the provided weight file, or fetch it from the HuggingFace hub.
    let model = match args.model {
        Some(model) => std::path::PathBuf::from(model),
        None => Api::new()?
            .model("kyutai/mimi".to_string())
            .get("model.safetensors")?,
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
    let config = Config::v0_1(None);
    let mut model = Model::new(config, vb)?;

    // First stage: obtain the audio tokens, either by loading them from a
    // safetensors file or by encoding some input audio.
    let codes = match args.action {
        Action::CodeToAudio => {
            let codes = candle::safetensors::load(args.in_file, &device)?;
            codes.get("codes").expect("no codes in input file").clone()
        }
        Action::AudioToCode | Action::AudioToAudio => {
            let pcm = if args.in_file == "-" {
                // Record from the default microphone until enter is pressed.
                println!(">>>> RECORDING AUDIO, PRESS ENTER ONCE DONE <<<<");
                let (stream, input_audio) = audio_io::setup_input_stream()?;
                let mut pcms = vec![];
                // Watch stdin on a separate thread so recording can proceed.
                let stdin = std::thread::spawn(|| {
                    let mut s = String::new();
                    std::io::stdin().read_line(&mut s)
                });
                while !stdin.is_finished() {
                    // Poll the capture buffer, sleeping while it is empty.
                    let input = input_audio.lock().unwrap().take_all();
                    if input.is_empty() {
                        std::thread::sleep(std::time::Duration::from_millis(100));
                        continue;
                    }
                    pcms.push(input)
                }
                // Dropping the stream stops the capture.
                drop(stream);
                pcms.concat()
            } else {
                // Decode the input file, resampling to 24kHz if needed.
                let (pcm, sample_rate) = audio_io::pcm_decode(args.in_file)?;
                if sample_rate != 24_000 {
                    println!("WARNING: mimi uses a 24khz sample rate, input uses {sample_rate}, resampling...");
                    audio_io::resample(&pcm, sample_rate as usize, 24_000)?
                } else {
                    pcm
                }
            };
            let pcm_len = pcm.len();
            // Shape the pcm data as (batch=1, channels=1, samples).
            let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
            println!("input pcm shape: {:?}", pcm.shape());
            model.encode(&pcm)?
        }
    };
    println!("codes shape: {:?}", codes.shape());

    // Second stage: either save the tokens or decode them back to audio.
    match args.action {
        Action::AudioToCode => {
            codes.save_safetensors("codes", &args.out_file)?;
        }
        Action::AudioToAudio | Action::CodeToAudio => {
            let pcm = model.decode(&codes)?;
            println!("output pcm shape: {:?}", pcm.shape());
            // Extract the single (batch, channel) pair as a 1D sample vector.
            let pcm = pcm.i(0)?.i(0)?;
            let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
            let pcm = pcm.to_vec1::<f32>()?;
            if args.out_file == "-" {
                // Play the decoded audio on the default output device.
                let (stream, ad) = audio_io::setup_output_stream()?;
                {
                    let mut ad = ad.lock().unwrap();
                    ad.push_samples(&pcm)?;
                }
                // Busy-wait until the playback queue drains.
                loop {
                    let ad = ad.lock().unwrap();
                    if ad.is_empty() {
                        break;
                    }
                    // That's very weird, calling thread::sleep here triggers the stream to stop
                    // playing (the callback doesn't seem to be called anymore).
                    // std::thread::sleep(std::time::Duration::from_millis(100));
                }
                drop(stream)
            } else {
                let mut output = std::fs::File::create(&args.out_file)?;
                candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
            }
        }
    }
    Ok(())
}