diff options
Diffstat (limited to 'candle-examples/examples/encodec/main.rs')
-rw-r--r-- | candle-examples/examples/encodec/main.rs | 133 |
1 files changed, 101 insertions, 32 deletions
diff --git a/candle-examples/examples/encodec/main.rs b/candle-examples/examples/encodec/main.rs index 42f2b3f9..fab33651 100644 --- a/candle-examples/examples/encodec/main.rs +++ b/candle-examples/examples/encodec/main.rs @@ -5,15 +5,85 @@ extern crate intel_mkl_src; extern crate accelerate_src; use anyhow::Result; -use candle::{DType, IndexOp}; +use candle::{DType, IndexOp, Tensor}; use candle_nn::VarBuilder; use candle_transformers::models::encodec::{Config, Model}; -use clap::Parser; +use clap::{Parser, ValueEnum}; use hf_hub::api::sync::Api; +fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>) +where + T: symphonia::core::sample::Sample, + f32: symphonia::core::conv::FromSample<T>, +{ + use symphonia::core::audio::Signal; + use symphonia::core::conv::FromSample; + samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v))) +} + +fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<(Vec<f32>, u32)> { + use symphonia::core::audio::{AudioBufferRef, Signal}; + + let src = std::fs::File::open(path)?; + let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default()); + let hint = symphonia::core::probe::Hint::new(); + let meta_opts: symphonia::core::meta::MetadataOptions = Default::default(); + let fmt_opts: symphonia::core::formats::FormatOptions = Default::default(); + let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?; + let mut format = probed.format; + let track = format + .tracks() + .iter() + .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL) + .expect("no supported audio tracks"); + let mut decoder = symphonia::default::get_codecs() + .make(&track.codec_params, &Default::default()) + .expect("unsupported codec"); + let track_id = track.id; + let sample_rate = track.codec_params.sample_rate.unwrap_or(0); + let mut pcm_data = Vec::new(); + while let Ok(packet) = format.next_packet() { + while !format.metadata().is_latest() { + format.metadata().pop(); + } + if packet.track_id() != track_id { + continue; + } + match decoder.decode(&packet)? { + AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)), + AudioBufferRef::U8(data) => conv(&mut pcm_data, data), + AudioBufferRef::U16(data) => conv(&mut pcm_data, data), + AudioBufferRef::U24(data) => conv(&mut pcm_data, data), + AudioBufferRef::U32(data) => conv(&mut pcm_data, data), + AudioBufferRef::S8(data) => conv(&mut pcm_data, data), + AudioBufferRef::S16(data) => conv(&mut pcm_data, data), + AudioBufferRef::S24(data) => conv(&mut pcm_data, data), + AudioBufferRef::S32(data) => conv(&mut pcm_data, data), + AudioBufferRef::F64(data) => conv(&mut pcm_data, data), + } + } + Ok((pcm_data, sample_rate)) +} + +#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)] +enum Action { + AudioToAudio, + AudioToCode, + CodeToAudio, +} + #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Args { + /// The action to be performed, specifies the format for the input and output data. + action: Action, + + /// The input file, either an audio file or some encodec tokens stored as safetensors. + in_file: String, + + /// The output file, either a wave audio file or some encodec tokens stored as safetensors. + out_file: String, + /// Run on CPU rather than on GPU. #[arg(long)] cpu: bool, @@ -21,18 +91,6 @@ struct Args { /// The model weight file, in safetensor format. #[arg(long)] model: Option<String>, - - /// Input file as a safetensors containing the encodec tokens. - #[arg(long)] - code_file: String, - - /// Output file that will be generated in wav format. - #[arg(long)] - out: String, - - /// Do another step of encoding the PCM data and and decoding the resulting codes. - #[arg(long)] - roundtrip: bool, } fn main() -> Result<()> { @@ -48,25 +106,36 @@ fn main() -> Result<()> { let config = Config::default(); let model = Model::new(&config, vb)?; - let codes = candle::safetensors::load(args.code_file, &device)?; - let codes = codes.get("codes").expect("no codes in input file").i(0)?; - println!("codes shape: {:?}", codes.shape()); - let pcm = model.decode(&codes)?; - println!("pcm shape: {:?}", pcm.shape()); - - let pcm = if args.roundtrip { - let codes = model.encode(&pcm)?; - println!("second step codes shape: {:?}", pcm.shape()); - let pcm = model.decode(&codes)?; - println!("second step pcm shape: {:?}", pcm.shape()); - pcm - } else { - pcm + let codes = match args.action { + Action::CodeToAudio => { + let codes = candle::safetensors::load(args.in_file, &device)?; + let codes = codes.get("codes").expect("no codes in input file").i(0)?; + codes + } + Action::AudioToCode | Action::AudioToAudio => { + let (pcm, sample_rate) = pcm_decode(args.in_file)?; + if sample_rate != 24_000 { + println!("WARNING: encodec uses a 24khz sample rate, input uses {sample_rate}") + } + let pcm_len = pcm.len(); + let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?; + println!("input pcm shape: {:?}", pcm.shape()); + model.encode(&pcm)? + } }; + println!("codes shape: {:?}", codes.shape()); - let pcm = pcm.i(0)?.i(0)?.to_vec1::<f32>()?; - let mut output = std::fs::File::create(&args.out)?; - candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?; - + match args.action { + Action::AudioToCode => { + codes.save_safetensors("codes", &args.out_file)?; + } + Action::AudioToAudio | Action::CodeToAudio => { + let pcm = model.decode(&codes)?; + println!("output pcm shape: {:?}", pcm.shape()); + let pcm = pcm.i(0)?.i(0)?.to_vec1::<f32>()?; + let mut output = std::fs::File::create(&args.out_file)?; + candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?; + } + } Ok(()) } |