author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-02-28 09:22:33 +0100
committer | GitHub <noreply@github.com> | 2024-02-28 09:22:33 +0100
commit | 60ee5cfd4dbe5893fc16c6addfeeca80f5e2a779 (patch)
tree | 58e5288aef92110fea4689fb92cbed8290f56fa4 /candle-examples/examples
parent | 56e44aabe34371d643162ea421082b46fe229a3f (diff)
Support more modes in the encodec example. (#1777)
* Support more modes in the encodec example.
* Remove the old encodec model from the musicgen bits.
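For quick reference, the three modes added by this change can be invoked as follows (commands taken from the new README in the diff below; `in.mp3`, `out.wav`, and `out.safetensors` are placeholder file names):

```bash
# Decode stored EnCodec tokens into a wav file.
cargo run --example encodec --features symphonia --release -- code-to-audio \
    candle-examples/examples/encodec/jfk-codes.safetensors jfk.wav

# Encode an audio file, then decode the resulting codes back to a wav file.
cargo run --example encodec --features symphonia --release -- audio-to-audio in.mp3 out.wav

# Encode an audio file into EnCodec tokens stored as safetensors.
cargo run --example encodec --features symphonia --release -- audio-to-code in.mp3 out.safetensors
```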
Diffstat (limited to 'candle-examples/examples')
-rw-r--r-- | candle-examples/examples/encodec/README.md | 20
-rw-r--r-- | candle-examples/examples/encodec/main.rs | 133
-rw-r--r-- | candle-examples/examples/musicgen/encodec_model.rs | 580
-rw-r--r-- | candle-examples/examples/musicgen/main.rs | 2
-rw-r--r-- | candle-examples/examples/musicgen/musicgen_model.rs | 39
-rw-r--r-- | candle-examples/examples/musicgen/nn.rs | 20
6 files changed, 153 insertions, 641 deletions
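The `audio-to-code` mode stores the tokens under the `"codes"` key of a safetensors file, and the example indexes away the leading batch dimension before decoding. A minimal sketch of reading such a file back outside the example, assuming an `out.safetensors` produced by `audio-to-code` and using only the `candle` calls that appear in the diff below:

```rust
use candle::{Device, IndexOp};

fn main() -> anyhow::Result<()> {
    let device = Device::Cpu;
    // A safetensors file maps tensor names to tensors; the encodec example
    // saves its tokens under the "codes" key.
    let tensors = candle::safetensors::load("out.safetensors", &device)?;
    // Drop the leading batch dimension, as the example does before decoding.
    let codes = tensors.get("codes").expect("no codes in input file").i(0)?;
    println!("codes shape: {:?}", codes.shape());
    Ok(())
}
```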
diff --git a/candle-examples/examples/encodec/README.md b/candle-examples/examples/encodec/README.md new file mode 100644 index 00000000..3028fb80 --- /dev/null +++ b/candle-examples/examples/encodec/README.md @@ -0,0 +1,20 @@ +# candle-endocec + +[EnCodec](https://huggingface.co/facebook/encodec_24khz) is a high-quality audio +compression model using an encoder/decoder architecture with residual vector +quantization. + +## Running one example + +```bash +cargo run --example encodec --features symphonia --release -- code-to-audio \ + candle-examples/examples/encodec/jfk-codes.safetensors \ + jfk.wav +``` + +This decodes the EnCodec tokens stored in `jfk-codes.safetensors` and generates +an output wav file containing the audio data. Instead of `code-to-audio` one +can use: +- `audio-to-audio in.mp3 out.wav`: encodes the input audio file then decodes it to a wav file. +- `audio-to-code in.mp3 out.safetensors`: generates a safetensors file + containing EnCodec tokens for the input audio file. diff --git a/candle-examples/examples/encodec/main.rs b/candle-examples/examples/encodec/main.rs index 42f2b3f9..fab33651 100644 --- a/candle-examples/examples/encodec/main.rs +++ b/candle-examples/examples/encodec/main.rs @@ -5,15 +5,85 @@ extern crate intel_mkl_src; extern crate accelerate_src; use anyhow::Result; -use candle::{DType, IndexOp}; +use candle::{DType, IndexOp, Tensor}; use candle_nn::VarBuilder; use candle_transformers::models::encodec::{Config, Model}; -use clap::Parser; +use clap::{Parser, ValueEnum}; use hf_hub::api::sync::Api; +fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>) +where + T: symphonia::core::sample::Sample, + f32: symphonia::core::conv::FromSample<T>, +{ + use symphonia::core::audio::Signal; + use symphonia::core::conv::FromSample; + samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v))) +} + +fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<(Vec<f32>, u32)> { + use symphonia::core::audio::{AudioBufferRef, Signal}; + + let src = std::fs::File::open(path)?; + let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default()); + let hint = symphonia::core::probe::Hint::new(); + let meta_opts: symphonia::core::meta::MetadataOptions = Default::default(); + let fmt_opts: symphonia::core::formats::FormatOptions = Default::default(); + let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?; + let mut format = probed.format; + let track = format + .tracks() + .iter() + .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL) + .expect("no supported audio tracks"); + let mut decoder = symphonia::default::get_codecs() + .make(&track.codec_params, &Default::default()) + .expect("unsupported codec"); + let track_id = track.id; + let sample_rate = track.codec_params.sample_rate.unwrap_or(0); + let mut pcm_data = Vec::new(); + while let Ok(packet) = format.next_packet() { + while !format.metadata().is_latest() { + format.metadata().pop(); + } + if packet.track_id() != track_id { + continue; + } + match decoder.decode(&packet)? 
{ + AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)), + AudioBufferRef::U8(data) => conv(&mut pcm_data, data), + AudioBufferRef::U16(data) => conv(&mut pcm_data, data), + AudioBufferRef::U24(data) => conv(&mut pcm_data, data), + AudioBufferRef::U32(data) => conv(&mut pcm_data, data), + AudioBufferRef::S8(data) => conv(&mut pcm_data, data), + AudioBufferRef::S16(data) => conv(&mut pcm_data, data), + AudioBufferRef::S24(data) => conv(&mut pcm_data, data), + AudioBufferRef::S32(data) => conv(&mut pcm_data, data), + AudioBufferRef::F64(data) => conv(&mut pcm_data, data), + } + } + Ok((pcm_data, sample_rate)) +} + +#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)] +enum Action { + AudioToAudio, + AudioToCode, + CodeToAudio, +} + #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Args { + /// The action to be performed, specifies the format for the input and output data. + action: Action, + + /// The input file, either an audio file or some encodec tokens stored as safetensors. + in_file: String, + + /// The output file, either a wave audio file or some encodec tokens stored as safetensors. + out_file: String, + /// Run on CPU rather than on GPU. #[arg(long)] cpu: bool, @@ -21,18 +91,6 @@ struct Args { /// The model weight file, in safetensor format. #[arg(long)] model: Option<String>, - - /// Input file as a safetensors containing the encodec tokens. - #[arg(long)] - code_file: String, - - /// Output file that will be generated in wav format. - #[arg(long)] - out: String, - - /// Do another step of encoding the PCM data and and decoding the resulting codes. - #[arg(long)] - roundtrip: bool, } fn main() -> Result<()> { @@ -48,25 +106,36 @@ fn main() -> Result<()> { let config = Config::default(); let model = Model::new(&config, vb)?; - let codes = candle::safetensors::load(args.code_file, &device)?; - let codes = codes.get("codes").expect("no codes in input file").i(0)?; - println!("codes shape: {:?}", codes.shape()); - let pcm = model.decode(&codes)?; - println!("pcm shape: {:?}", pcm.shape()); - - let pcm = if args.roundtrip { - let codes = model.encode(&pcm)?; - println!("second step codes shape: {:?}", pcm.shape()); - let pcm = model.decode(&codes)?; - println!("second step pcm shape: {:?}", pcm.shape()); - pcm - } else { - pcm + let codes = match args.action { + Action::CodeToAudio => { + let codes = candle::safetensors::load(args.in_file, &device)?; + let codes = codes.get("codes").expect("no codes in input file").i(0)?; + codes + } + Action::AudioToCode | Action::AudioToAudio => { + let (pcm, sample_rate) = pcm_decode(args.in_file)?; + if sample_rate != 24_000 { + println!("WARNING: encodec uses a 24khz sample rate, input uses {sample_rate}") + } + let pcm_len = pcm.len(); + let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?; + println!("input pcm shape: {:?}", pcm.shape()); + model.encode(&pcm)? 
+ } }; + println!("codes shape: {:?}", codes.shape()); - let pcm = pcm.i(0)?.i(0)?.to_vec1::<f32>()?; - let mut output = std::fs::File::create(&args.out)?; - candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?; - + match args.action { + Action::AudioToCode => { + codes.save_safetensors("codes", &args.out_file)?; + } + Action::AudioToAudio | Action::CodeToAudio => { + let pcm = model.decode(&codes)?; + println!("output pcm shape: {:?}", pcm.shape()); + let pcm = pcm.i(0)?.i(0)?.to_vec1::<f32>()?; + let mut output = std::fs::File::create(&args.out_file)?; + candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?; + } + } Ok(()) } diff --git a/candle-examples/examples/musicgen/encodec_model.rs b/candle-examples/examples/musicgen/encodec_model.rs deleted file mode 100644 index 60149e45..00000000 --- a/candle-examples/examples/musicgen/encodec_model.rs +++ /dev/null @@ -1,580 +0,0 @@ -use crate::nn::conv1d_weight_norm; -use candle::{DType, IndexOp, Module, Result, Tensor}; -use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder}; - -// Encodec Model -// https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py - -#[derive(Debug, Clone, PartialEq)] -enum NormType { - WeightNorm, - TimeGroupNorm, - None, -} - -#[derive(Debug, Clone, PartialEq)] -pub struct Config { - target_bandwidths: Vec<f64>, - sampling_rate: usize, - audio_channels: usize, - normalize: bool, - chunk_length_s: Option<usize>, - overlap: Option<usize>, - hidden_size: usize, - num_filters: usize, - num_residual_layers: usize, - upsampling_ratios: Vec<usize>, - norm_type: NormType, - kernel_size: usize, - last_kernel_size: usize, - residual_kernel_size: usize, - dilation_growth_rate: usize, - use_causal_conv: bool, - pad_mode: &'static str, - compress: usize, - num_lstm_layers: usize, - trim_right_ratio: f64, - codebook_size: usize, - codebook_dim: Option<usize>, - use_conv_shortcut: bool, -} - -impl Default for Config { - fn default() -> Self { - Self { - target_bandwidths: vec![1.5, 3.0, 6.0, 12.0, 24.0], - sampling_rate: 24_000, - audio_channels: 1, - normalize: false, - chunk_length_s: None, - overlap: None, - hidden_size: 128, - num_filters: 32, - num_residual_layers: 1, - upsampling_ratios: vec![8, 5, 4, 2], - norm_type: NormType::WeightNorm, - kernel_size: 7, - last_kernel_size: 7, - residual_kernel_size: 3, - dilation_growth_rate: 2, - use_causal_conv: true, - pad_mode: "reflect", - compress: 2, - num_lstm_layers: 2, - trim_right_ratio: 1.0, - codebook_size: 1024, - codebook_dim: None, - use_conv_shortcut: true, - } - } -} - -impl Config { - // https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L6 - pub fn musicgen_small() -> Self { - Self { - audio_channels: 1, - chunk_length_s: None, - codebook_dim: Some(128), - codebook_size: 2048, - compress: 2, - dilation_growth_rate: 2, - hidden_size: 128, - kernel_size: 7, - last_kernel_size: 7, - norm_type: NormType::WeightNorm, - normalize: false, - num_filters: 64, - num_lstm_layers: 2, - num_residual_layers: 1, - overlap: None, - pad_mode: "reflect", - residual_kernel_size: 3, - sampling_rate: 32_000, - target_bandwidths: vec![2.2], - trim_right_ratio: 1.0, - upsampling_ratios: vec![8, 5, 4, 4], - use_causal_conv: false, - use_conv_shortcut: false, - } - } - - fn codebook_dim(&self) -> usize { - self.codebook_dim.unwrap_or(self.codebook_size) - } - - fn frame_rate(&self) -> usize { - let hop_length: usize = self.upsampling_ratios.iter().product(); - 
(self.sampling_rate + hop_length - 1) / hop_length - } - - fn num_quantizers(&self) -> usize { - let num = 1000f64 - * self - .target_bandwidths - .last() - .expect("empty target_bandwidths"); - (num as usize) / (self.frame_rate() * 10) - } -} - -// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L340 -#[derive(Debug)] -struct EncodecEuclideanCodebook { - inited: Tensor, - cluster_size: Tensor, - embed: Tensor, - embed_avg: Tensor, -} - -impl EncodecEuclideanCodebook { - fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { - let inited = vb.get(1, "inited")?; - let cluster_size = vb.get(cfg.codebook_size, "cluster_size")?; - let e_shape = (cfg.codebook_size, cfg.codebook_dim()); - let embed = vb.get(e_shape, "embed")?; - let embed_avg = vb.get(e_shape, "embed_avg")?; - Ok(Self { - inited, - cluster_size, - embed, - embed_avg, - }) - } - - fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> { - let quantize = self.embed.embedding(embed_ind)?; - Ok(quantize) - } -} - -#[derive(Debug)] -struct EncodecVectorQuantization { - codebook: EncodecEuclideanCodebook, -} - -impl EncodecVectorQuantization { - fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { - let codebook = EncodecEuclideanCodebook::load(vb.pp("codebook"), cfg)?; - Ok(Self { codebook }) - } - - fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> { - let quantize = self.codebook.decode(embed_ind)?; - let quantize = quantize.transpose(1, 2)?; - Ok(quantize) - } -} - -#[derive(Debug)] -struct EncodecResidualVectorQuantizer { - layers: Vec<EncodecVectorQuantization>, -} - -impl EncodecResidualVectorQuantizer { - fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { - let vb = &vb.pp("layers"); - let layers = (0..cfg.num_quantizers()) - .map(|i| EncodecVectorQuantization::load(vb.pp(&i.to_string()), cfg)) - .collect::<Result<Vec<_>>>()?; - Ok(Self { layers }) - } - - fn decode(&self, codes: &Tensor) -> Result<Tensor> { - let mut quantized_out = Tensor::zeros((), DType::F32, codes.device())?; - if codes.dim(0)? 
!= self.layers.len() { - candle::bail!( - "codes shape {:?} does not match the number of quantization layers {}", - codes.shape(), - self.layers.len() - ) - } - for (i, layer) in self.layers.iter().enumerate() { - let quantized = layer.decode(&codes.i(i)?)?; - quantized_out = quantized.broadcast_add(&quantized_out)?; - } - Ok(quantized_out) - } -} - -// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L226 -#[derive(Debug)] -struct EncodecLSTM { - layers: Vec<candle_nn::LSTM>, -} - -impl EncodecLSTM { - fn load(dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> { - let vb = &vb.pp("lstm"); - let mut layers = vec![]; - for layer_idx in 0..cfg.num_lstm_layers { - let config = candle_nn::LSTMConfig { - layer_idx, - ..Default::default() - }; - let lstm = candle_nn::lstm(dim, dim, config, vb.clone())?; - layers.push(lstm) - } - Ok(Self { layers }) - } -} - -impl Module for EncodecLSTM { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - use candle_nn::RNN; - let mut xs = xs.clone(); - for layer in self.layers.iter() { - let states = layer.seq(&xs)?; - xs = layer.states_to_tensor(&states)?; - } - Ok(xs) - } -} - -#[derive(Debug)] -struct EncodecConvTranspose1d { - weight_g: Tensor, - weight_v: Tensor, - bias: Tensor, -} - -impl EncodecConvTranspose1d { - fn load( - in_c: usize, - out_c: usize, - k: usize, - _stride: usize, - vb: VarBuilder, - _cfg: &Config, - ) -> Result<Self> { - let vb = &vb.pp("conv"); - let weight_g = vb.get((in_c, 1, 1), "weight_g")?; - let weight_v = vb.get((in_c, out_c, k), "weight_v")?; - let bias = vb.get(out_c, "bias")?; - Ok(Self { - weight_g, - weight_v, - bias, - }) - } -} - -impl Module for EncodecConvTranspose1d { - fn forward(&self, _xs: &Tensor) -> Result<Tensor> { - todo!() - } -} - -#[derive(Debug)] -struct EncodecConv1d { - causal: bool, - conv: Conv1d, - norm: Option<candle_nn::GroupNorm>, -} - -impl EncodecConv1d { - fn load( - in_c: usize, - out_c: usize, - kernel_size: usize, - stride: usize, - vb: VarBuilder, - cfg: &Config, - ) -> Result<Self> { - let conv = match cfg.norm_type { - NormType::WeightNorm => conv1d_weight_norm( - in_c, - out_c, - kernel_size, - Conv1dConfig { - padding: 0, - stride, - groups: 1, - dilation: 1, - }, - vb.pp("conv"), - )?, - NormType::None | NormType::TimeGroupNorm => conv1d( - in_c, - out_c, - kernel_size, - Conv1dConfig { - padding: 0, - stride, - groups: 1, - dilation: 1, - }, - vb.pp("conv"), - )?, - }; - let norm = match cfg.norm_type { - NormType::None | NormType::WeightNorm => None, - NormType::TimeGroupNorm => { - let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?; - Some(gn) - } - }; - Ok(Self { - causal: cfg.use_causal_conv, - conv, - norm, - }) - } -} - -impl Module for EncodecConv1d { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - // TODO: padding, depending on causal. - let xs = self.conv.forward(xs)?; - match &self.norm { - None => Ok(xs), - Some(norm) => xs.apply(norm), - } - } -} - -#[derive(Debug)] -struct EncodecResnetBlock { - block_conv1: EncodecConv1d, - block_conv2: EncodecConv1d, - shortcut: Option<EncodecConv1d>, -} - -impl EncodecResnetBlock { - fn load(dim: usize, dilations: &[usize], vb: VarBuilder, cfg: &Config) -> Result<Self> { - let h = dim / cfg.compress; - let mut layer = Layer::new(vb.pp("block")); - if dilations.len() != 2 { - candle::bail!("expected dilations of size 2") - } - // TODO: Apply dilations! 
- layer.inc(); - let block_conv1 = - EncodecConv1d::load(dim, h, cfg.residual_kernel_size, 1, layer.next(), cfg)?; - layer.inc(); - let block_conv2 = EncodecConv1d::load(h, dim, 1, 1, layer.next(), cfg)?; - let shortcut = if cfg.use_conv_shortcut { - let conv = EncodecConv1d::load(dim, dim, 1, 1, vb.pp("shortcut"), cfg)?; - Some(conv) - } else { - None - }; - Ok(Self { - block_conv1, - block_conv2, - shortcut, - }) - } -} - -impl Module for EncodecResnetBlock { - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let residual = xs.clone(); - let xs = xs.elu(1.)?; - let xs = self.block_conv1.forward(&xs)?; - let xs = xs.elu(1.)?; - let xs = self.block_conv2.forward(&xs)?; - let xs = match &self.shortcut { - None => (xs + residual)?, - Some(shortcut) => xs.add(&shortcut.forward(&residual)?)?, - }; - Ok(xs) - } -} - -struct Layer<'a> { - vb: VarBuilder<'a>, - cnt: usize, -} - -impl<'a> Layer<'a> { - fn new(vb: VarBuilder<'a>) -> Self { - Self { vb, cnt: 0 } - } - - fn inc(&mut self) { - self.cnt += 1; - } - - fn next(&mut self) -> VarBuilder { - let vb = self.vb.pp(&self.cnt.to_string()); - self.cnt += 1; - vb - } -} - -#[derive(Debug)] -struct EncodecEncoder { - init_conv: EncodecConv1d, - sampling_layers: Vec<(Vec<EncodecResnetBlock>, EncodecConv1d)>, - final_lstm: EncodecLSTM, - final_conv: EncodecConv1d, -} - -impl EncodecEncoder { - fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { - let mut layer = Layer::new(vb.pp("layers")); - let init_conv = EncodecConv1d::load( - cfg.audio_channels, - cfg.num_filters, - cfg.kernel_size, - 1, - layer.next(), - cfg, - )?; - let mut sampling_layers = vec![]; - let mut scaling = 1; - for &ratio in cfg.upsampling_ratios.iter().rev() { - let current_scale = scaling * cfg.num_filters; - let mut resnets = vec![]; - for j in 0..(cfg.num_residual_layers as u32) { - let resnet = EncodecResnetBlock::load( - current_scale, - &[cfg.dilation_growth_rate.pow(j), 1], - layer.next(), - cfg, - )?; - resnets.push(resnet) - } - layer.inc(); // ELU - let conv1d = EncodecConv1d::load( - current_scale, - current_scale * 2, - ratio * 2, - ratio, - layer.next(), - cfg, - )?; - sampling_layers.push((resnets, conv1d)); - scaling *= 2; - } - let final_lstm = EncodecLSTM::load(cfg.num_filters * scaling, layer.next(), cfg)?; - layer.inc(); // ELU - let final_conv = EncodecConv1d::load( - cfg.num_filters * scaling, - cfg.hidden_size, - cfg.last_kernel_size, - 1, - layer.next(), - cfg, - )?; - Ok(Self { - init_conv, - sampling_layers, - final_conv, - final_lstm, - }) - } - - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let mut xs = xs.apply(&self.init_conv)?; - for (resnets, conv) in self.sampling_layers.iter() { - for resnet in resnets.iter() { - xs = xs.apply(resnet)?; - } - xs = xs.elu(1.0)?.apply(conv)?; - } - xs.apply(&self.final_lstm)? - .elu(1.0)? 
- .apply(&self.final_conv) - } -} - -#[derive(Debug)] -struct EncodecDecoder { - init_conv: EncodecConv1d, - init_lstm: EncodecLSTM, - sampling_layers: Vec<(EncodecConvTranspose1d, Vec<EncodecResnetBlock>)>, - final_conv: EncodecConv1d, -} - -impl EncodecDecoder { - fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { - let mut layer = Layer::new(vb.pp("layers")); - let mut scaling = usize::pow(2, cfg.upsampling_ratios.len() as u32); - let init_conv = EncodecConv1d::load( - cfg.hidden_size, - cfg.num_filters * scaling, - cfg.last_kernel_size, - 1, - layer.next(), - cfg, - )?; - let init_lstm = EncodecLSTM::load(cfg.num_filters * scaling, layer.next(), cfg)?; - let mut sampling_layers = vec![]; - for &ratio in cfg.upsampling_ratios.iter() { - let current_scale = scaling * cfg.num_filters; - layer.inc(); // ELU - let conv1d = EncodecConvTranspose1d::load( - current_scale, - current_scale / 2, - ratio * 2, - ratio, - layer.next(), - cfg, - )?; - let mut resnets = vec![]; - for j in 0..(cfg.num_residual_layers as u32) { - let resnet = EncodecResnetBlock::load( - current_scale / 2, - &[cfg.dilation_growth_rate.pow(j), 1], - layer.next(), - cfg, - )?; - resnets.push(resnet) - } - sampling_layers.push((conv1d, resnets)); - scaling /= 2; - } - layer.inc(); // ELU - let final_conv = EncodecConv1d::load( - cfg.num_filters, - cfg.audio_channels, - cfg.last_kernel_size, - 1, - layer.next(), - cfg, - )?; - Ok(Self { - init_conv, - init_lstm, - sampling_layers, - final_conv, - }) - } - - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let mut xs = xs.apply(&self.init_conv)?.apply(&self.init_lstm)?; - for (conv, resnets) in self.sampling_layers.iter() { - xs = xs.elu(1.)?.apply(conv)?; - for resnet in resnets.iter() { - xs = xs.apply(resnet)? - } - } - xs.elu(1.)?.apply(&self.final_conv) - } -} - -#[derive(Debug)] -pub struct EncodecModel { - encoder: EncodecEncoder, - decoder: EncodecDecoder, - quantizer: EncodecResidualVectorQuantizer, -} - -impl EncodecModel { - pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { - let encoder = EncodecEncoder::load(vb.pp("encoder"), cfg)?; - let decoder = EncodecDecoder::load(vb.pp("decoder"), cfg)?; - let quantizer = EncodecResidualVectorQuantizer::load(vb.pp("quantizer"), cfg)?; - Ok(Self { - encoder, - decoder, - quantizer, - }) - } - - pub fn forward(&self, _xs: &Tensor) -> Result<Tensor> { - todo!() - } -} diff --git a/candle-examples/examples/musicgen/main.rs b/candle-examples/examples/musicgen/main.rs index a39cfec2..7e081429 100644 --- a/candle-examples/examples/musicgen/main.rs +++ b/candle-examples/examples/musicgen/main.rs @@ -10,9 +10,7 @@ extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; -mod encodec_model; mod musicgen_model; -mod nn; use musicgen_model::{GenConfig, MusicgenForConditionalGeneration}; diff --git a/candle-examples/examples/musicgen/musicgen_model.rs b/candle-examples/examples/musicgen/musicgen_model.rs index c6b52fde..03e96614 100644 --- a/candle-examples/examples/musicgen/musicgen_model.rs +++ b/candle-examples/examples/musicgen/musicgen_model.rs @@ -1,10 +1,9 @@ -use crate::encodec_model; use candle::{DType, Device, Result, Tensor, D}; use candle_nn::{ embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module, VarBuilder, }; -use candle_transformers::models::t5; +use candle_transformers::models::{encodec, t5}; // 
https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83 #[derive(Debug, Clone, PartialEq)] @@ -372,7 +371,7 @@ impl MusicgenForCausalLM { #[derive(Debug)] pub struct MusicgenForConditionalGeneration { pub text_encoder: t5::T5EncoderModel, - pub audio_encoder: crate::encodec_model::EncodecModel, + pub audio_encoder: encodec::Model, pub decoder: MusicgenForCausalLM, cfg: GenConfig, } @@ -381,15 +380,42 @@ pub struct MusicgenForConditionalGeneration { pub struct GenConfig { musicgen: Config, t5: t5::Config, - encodec: crate::encodec_model::Config, + encodec: encodec::Config, } impl GenConfig { pub fn small() -> Self { + // https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L6 + let encodec = encodec::Config { + audio_channels: 1, + chunk_length_s: None, + codebook_dim: Some(128), + codebook_size: 2048, + compress: 2, + dilation_growth_rate: 2, + hidden_size: 128, + kernel_size: 7, + last_kernel_size: 7, + norm_type: encodec::NormType::WeightNorm, + normalize: false, + num_filters: 64, + num_lstm_layers: 2, + num_residual_layers: 1, + overlap: None, + // This should be Reflect and not Replicate but Reflect does not work yet. + pad_mode: encodec::PadMode::Replicate, + residual_kernel_size: 3, + sampling_rate: 32_000, + target_bandwidths: vec![2.2], + trim_right_ratio: 1.0, + upsampling_ratios: vec![8, 5, 4, 4], + use_causal_conv: false, + use_conv_shortcut: false, + }; Self { musicgen: Config::musicgen_small(), t5: t5::Config::musicgen_small(), - encodec: encodec_model::Config::musicgen_small(), + encodec, } } } @@ -401,8 +427,7 @@ impl MusicgenForConditionalGeneration { pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> { let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?; - let audio_encoder = - encodec_model::EncodecModel::load(vb.pp("audio_encoder"), &cfg.encodec)?; + let audio_encoder = encodec::Model::new(&cfg.encodec, vb.pp("audio_encoder"))?; let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?; Ok(Self { text_encoder, diff --git a/candle-examples/examples/musicgen/nn.rs b/candle-examples/examples/musicgen/nn.rs deleted file mode 100644 index 282b3a05..00000000 --- a/candle-examples/examples/musicgen/nn.rs +++ /dev/null @@ -1,20 +0,0 @@ -use candle::Result; -use candle_nn::{Conv1d, Conv1dConfig, VarBuilder}; - -// Applies weight norm for inference by recomputing the weight tensor. This -// does not apply to training. -// https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html -pub fn conv1d_weight_norm( - in_c: usize, - out_c: usize, - kernel_size: usize, - config: Conv1dConfig, - vb: VarBuilder, -) -> Result<Conv1d> { - let weight_g = vb.get((out_c, 1, 1), "weight_g")?; - let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?; - let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?; - let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?; - let bias = vb.get(out_c, "bias")?; - Ok(Conv1d::new(weight, Some(bias), config)) -} |
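As a sanity check on the quantizer arithmetic from the removed `encodec_model.rs` (the same musicgen-small settings now live in `GenConfig::small()`), here is a small sketch of the `frame_rate`/`num_quantizers` computation, using the constant values that appear in the diff above:

```rust
fn main() {
    // musicgen-small settings from the config in the diff.
    let upsampling_ratios = [8usize, 5, 4, 4];
    let sampling_rate = 32_000usize;
    let max_bandwidth = 2.2f64; // target_bandwidths.last()

    // hop_length is the product of the upsampling ratios: 8 * 5 * 4 * 4 = 640.
    let hop_length: usize = upsampling_ratios.iter().product();
    // frame_rate = ceil(sampling_rate / hop_length) = ceil(32000 / 640) = 50.
    let frame_rate = (sampling_rate + hop_length - 1) / hop_length;
    // num_quantizers = (1000 * 2.2) as usize / (50 * 10) = 2200 / 500 = 4.
    let num_quantizers = (1000. * max_bandwidth) as usize / (frame_rate * 10);

    println!("hop_length={hop_length} frame_rate={frame_rate} num_quantizers={num_quantizers}");
}
```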