diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-08-19 07:59:51 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-08-19 08:59:51 +0200 |
commit | 236b29ff1555db82fdb78c1be8741c0ac37859d1 (patch) | |
tree | 89c30dcf340e01fce46ed6194877878637cbbcea /candle-transformers | |
parent | 58197e189657b6587a254882abdb232e83e86848 (diff) | |
download | candle-236b29ff1555db82fdb78c1be8741c0ac37859d1.tar.gz candle-236b29ff1555db82fdb78c1be8741c0ac37859d1.tar.bz2 candle-236b29ff1555db82fdb78c1be8741c0ac37859d1.zip |
Add the DAC model. (#2433)
* Add the DAC model.
* More quantization support.
* Handle DAC decoding.
* Plug the DAC decoding in parler-tts.
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/models/dac.rs | 376 | ||||
-rw-r--r-- | candle-transformers/src/models/encodec.rs | 2 | ||||
-rw-r--r-- | candle-transformers/src/models/mod.rs | 1 | ||||
-rw-r--r-- | candle-transformers/src/models/parler_tts.rs | 5 |
4 files changed, 383 insertions, 1 deletions
diff --git a/candle-transformers/src/models/dac.rs b/candle-transformers/src/models/dac.rs new file mode 100644 index 00000000..fa6c8c71 --- /dev/null +++ b/candle-transformers/src/models/dac.rs @@ -0,0 +1,376 @@ +/// Adapted from https://github.com/descriptinc/descript-audio-codec +use crate::models::encodec; +use candle::{IndexOp, Result, Tensor, D}; +use candle_nn::{Conv1d, Conv1dConfig, ConvTranspose1d, ConvTranspose1dConfig, VarBuilder}; + +#[derive(serde::Deserialize, Debug, Clone)] +pub struct Config { + pub num_codebooks: usize, + pub model_bitrate: u32, + pub codebook_size: usize, + pub latent_dim: usize, + pub frame_rate: u32, + pub sampling_rate: u32, +} + +#[derive(Debug, Clone)] +pub struct Snake1d { + alpha: Tensor, +} + +impl Snake1d { + pub fn new(channels: usize, vb: VarBuilder) -> Result<Self> { + let alpha = vb.get((1, channels, 1), "alpha")?; + Ok(Self { alpha }) + } +} + +impl candle::Module for Snake1d { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let xs_shape = xs.shape(); + let xs = xs.flatten_from(2)?; + let sin = self.alpha.broadcast_mul(&xs)?.sin()?; + let sin = (&sin * &sin)?; + (xs + (&self.alpha + 1e-9)?.recip()?.broadcast_mul(&sin)?)?.reshape(xs_shape) + } +} + +#[derive(Debug, Clone)] +pub struct ResidualUnit { + snake1: Snake1d, + conv1: Conv1d, + snake2: Snake1d, + conv2: Conv1d, +} + +impl ResidualUnit { + pub fn new(dim: usize, dilation: usize, vb: VarBuilder) -> Result<Self> { + let pad = ((7 - 1) * dilation) / 2; + let vb = vb.pp("block"); + let snake1 = Snake1d::new(dim, vb.pp(0))?; + let cfg1 = Conv1dConfig { + dilation, + padding: pad, + ..Default::default() + }; + let conv1 = encodec::conv1d_weight_norm(dim, dim, 7, cfg1, vb.pp(1))?; + let snake2 = Snake1d::new(dim, vb.pp(2))?; + let conv2 = encodec::conv1d_weight_norm(dim, dim, 1, Default::default(), vb.pp(3))?; + Ok(Self { + snake1, + conv1, + snake2, + conv2, + }) + } +} + +impl candle::Module for ResidualUnit { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let ys = xs + .apply(&self.snake1)? + .apply(&self.conv1)? + .apply(&self.snake2)? + .apply(&self.conv2)?; + let pad = (xs.dim(D::Minus1)? - ys.dim(D::Minus1)?) / 2; + if pad > 0 { + &ys + xs.narrow(D::Minus1, pad, ys.dim(D::Minus1)?) + } else { + ys + xs + } + } +} + +#[derive(Debug, Clone)] +pub struct EncoderBlock { + res1: ResidualUnit, + res2: ResidualUnit, + res3: ResidualUnit, + snake1: Snake1d, + conv1: Conv1d, +} + +impl EncoderBlock { + pub fn new(dim: usize, stride: usize, vb: VarBuilder) -> Result<Self> { + let vb = vb.pp("block"); + let res1 = ResidualUnit::new(dim / 2, 1, vb.pp(0))?; + let res2 = ResidualUnit::new(dim / 2, 3, vb.pp(1))?; + let res3 = ResidualUnit::new(dim / 2, 9, vb.pp(2))?; + let snake1 = Snake1d::new(dim / 2, vb.pp(3))?; + let cfg1 = Conv1dConfig { + stride, + padding: (stride + 1) / 2, + ..Default::default() + }; + let conv1 = encodec::conv1d_weight_norm(dim / 2, dim, 2 * stride, cfg1, vb.pp(4))?; + Ok(Self { + res1, + res2, + res3, + snake1, + conv1, + }) + } +} + +impl candle::Module for EncoderBlock { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + xs.apply(&self.res1)? + .apply(&self.res2)? + .apply(&self.res3)? + .apply(&self.snake1)? + .apply(&self.conv1) + } +} + +#[derive(Debug, Clone)] +pub struct Encoder { + conv1: Conv1d, + blocks: Vec<EncoderBlock>, + snake1: Snake1d, + conv2: Conv1d, +} + +impl candle::Module for Encoder { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let mut xs = xs.apply(&self.conv1)?; + for block in self.blocks.iter() { + xs = xs.apply(block)? + } + xs.apply(&self.snake1)?.apply(&self.conv2) + } +} + +impl Encoder { + pub fn new( + mut d_model: usize, + strides: &[usize], + d_latent: usize, + vb: VarBuilder, + ) -> Result<Self> { + let vb = vb.pp("block"); + let cfg1 = Conv1dConfig { + padding: 3, + ..Default::default() + }; + let conv1 = encodec::conv1d_weight_norm(1, d_model, 7, cfg1, vb.pp(0))?; + let mut blocks = Vec::with_capacity(strides.len()); + for (block_idx, stride) in strides.iter().enumerate() { + d_model *= 2; + let block = EncoderBlock::new(d_model, *stride, vb.pp(block_idx + 1))?; + blocks.push(block) + } + let snake1 = Snake1d::new(d_model, vb.pp(strides.len() + 1))?; + let cfg2 = Conv1dConfig { + padding: 1, + ..Default::default() + }; + let conv2 = + encodec::conv1d_weight_norm(d_model, d_latent, 3, cfg2, vb.pp(strides.len() + 2))?; + Ok(Self { + conv1, + blocks, + snake1, + conv2, + }) + } +} + +#[derive(Debug, Clone)] +pub struct DecoderBlock { + snake1: Snake1d, + conv_tr1: ConvTranspose1d, + res1: ResidualUnit, + res2: ResidualUnit, + res3: ResidualUnit, +} + +impl DecoderBlock { + pub fn new(in_dim: usize, out_dim: usize, stride: usize, vb: VarBuilder) -> Result<Self> { + let vb = vb.pp("block"); + let snake1 = Snake1d::new(in_dim, vb.pp(0))?; + let cfg = ConvTranspose1dConfig { + stride, + padding: (stride + 1) / 2, + ..Default::default() + }; + let conv_tr1 = encodec::conv_transpose1d_weight_norm( + in_dim, + out_dim, + 2 * stride, + true, + cfg, + vb.pp(1), + )?; + let res1 = ResidualUnit::new(out_dim, 1, vb.pp(2))?; + let res2 = ResidualUnit::new(out_dim, 3, vb.pp(3))?; + let res3 = ResidualUnit::new(out_dim, 9, vb.pp(4))?; + Ok(Self { + snake1, + conv_tr1, + res1, + res2, + res3, + }) + } +} + +impl candle_nn::Module for DecoderBlock { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + xs.apply(&self.snake1)? + .apply(&self.conv_tr1)? + .apply(&self.res1)? + .apply(&self.res2)? + .apply(&self.res3) + } +} + +#[derive(Debug, Clone)] +pub struct Decoder { + conv1: Conv1d, + blocks: Vec<DecoderBlock>, + snake1: Snake1d, + conv2: Conv1d, +} + +impl Decoder { + pub fn new( + in_c: usize, + mut channels: usize, + rates: &[usize], + d_out: usize, + vb: VarBuilder, + ) -> Result<Self> { + let vb = vb.pp("model"); + let cfg1 = Conv1dConfig { + padding: 3, + ..Default::default() + }; + let conv1 = encodec::conv1d_weight_norm(in_c, channels, 7, cfg1, vb.pp(0))?; + let mut blocks = Vec::with_capacity(rates.len()); + for (idx, stride) in rates.iter().enumerate() { + let block = DecoderBlock::new(channels, channels / 2, *stride, vb.pp(idx + 1))?; + channels /= 2; + blocks.push(block) + } + let snake1 = Snake1d::new(channels, vb.pp(rates.len() + 1))?; + let conv2 = encodec::conv1d_weight_norm(channels, d_out, 7, cfg1, vb.pp(rates.len() + 2))?; + Ok(Self { + conv1, + blocks, + snake1, + conv2, + }) + } +} + +impl candle::Module for Decoder { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let mut xs = xs.apply(&self.conv1)?; + for block in self.blocks.iter() { + xs = xs.apply(block)? + } + xs.apply(&self.snake1)?.apply(&self.conv2) + } +} + +#[allow(unused)] +#[derive(Clone, Debug)] +pub struct VectorQuantizer { + in_proj: Conv1d, + out_proj: Conv1d, + codebook: candle_nn::Embedding, +} + +impl VectorQuantizer { + pub fn new(in_dim: usize, cb_size: usize, cb_dim: usize, vb: VarBuilder) -> Result<Self> { + let in_proj = + encodec::conv1d_weight_norm(in_dim, cb_dim, 1, Default::default(), vb.pp("in_proj"))?; + let out_proj = + encodec::conv1d_weight_norm(cb_dim, in_dim, 1, Default::default(), vb.pp("out_proj"))?; + let codebook = candle_nn::embedding(cb_size, cb_dim, vb.pp("codebook"))?; + Ok(Self { + in_proj, + out_proj, + codebook, + }) + } + + pub fn embed_code(&self, embed_id: &Tensor) -> Result<Tensor> { + embed_id.apply(&self.codebook) + } + + pub fn decode_code(&self, embed_id: &Tensor) -> Result<Tensor> { + self.embed_code(embed_id)?.transpose(1, 2) + } +} + +#[derive(Clone, Debug)] +pub struct ResidualVectorQuantizer { + quantizers: Vec<VectorQuantizer>, +} + +impl ResidualVectorQuantizer { + pub fn new( + input_dim: usize, + n_codebooks: usize, + cb_size: usize, + cb_dim: usize, + vb: VarBuilder, + ) -> Result<Self> { + let vb = &vb.pp("quantizers"); + let quantizers = (0..n_codebooks) + .map(|i| VectorQuantizer::new(input_dim, cb_size, cb_dim, vb.pp(i))) + .collect::<Result<Vec<_>>>()?; + Ok(Self { quantizers }) + } + + pub fn from_codes(&self, codes: &Tensor) -> Result<Tensor> { + let mut sum = None; + for (idx, quantizer) in self.quantizers.iter().enumerate() { + let z_p_i = quantizer.decode_code(&codes.i((.., idx))?)?; + let z_q_i = z_p_i.apply(&quantizer.out_proj)?; + let s = match sum { + None => z_q_i, + Some(s) => (s + z_q_i)?, + }; + sum = Some(s) + } + match sum { + Some(s) => Ok(s), + None => candle::bail!("empty codebooks"), + } + } +} + +#[derive(Debug, Clone)] +pub struct Model { + pub encoder: Encoder, + pub quantizer: ResidualVectorQuantizer, + pub decoder: Decoder, +} + +impl Model { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let vb = vb.pp("model"); + let encoder = Encoder::new(64, &[2, 4, 8, 8], cfg.latent_dim, vb.pp("encoder"))?; + let quantizer = ResidualVectorQuantizer::new( + cfg.latent_dim, + cfg.num_codebooks, + cfg.codebook_size, + 8, + vb.pp("quantizer"), + )?; + let decoder = Decoder::new(cfg.latent_dim, 1536, &[8, 8, 4, 2], 1, vb.pp("decoder"))?; + Ok(Self { + encoder, + decoder, + quantizer, + }) + } + + pub fn decode_codes(&self, audio_codes: &Tensor) -> Result<Tensor> { + let audio_values = self.quantizer.from_codes(audio_codes)?; + audio_values.apply(&self.decoder) + } +} diff --git a/candle-transformers/src/models/encodec.rs b/candle-transformers/src/models/encodec.rs index fb70fb52..ba6686f6 100644 --- a/candle-transformers/src/models/encodec.rs +++ b/candle-transformers/src/models/encodec.rs @@ -136,7 +136,7 @@ pub fn conv1d_weight_norm( Ok(Conv1d::new(weight, Some(bias), config)) } -fn conv_transpose1d_weight_norm( +pub fn conv_transpose1d_weight_norm( in_c: usize, out_c: usize, kernel_size: usize, diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 83d13a7b..cc83cf7b 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -9,6 +9,7 @@ pub mod clip; pub mod codegeex4_9b; pub mod convmixer; pub mod convnext; +pub mod dac; pub mod depth_anything_v2; pub mod dinov2; pub mod dinov2reg4; diff --git a/candle-transformers/src/models/parler_tts.rs b/candle-transformers/src/models/parler_tts.rs index 9c66c93a..16023a7c 100644 --- a/candle-transformers/src/models/parler_tts.rs +++ b/candle-transformers/src/models/parler_tts.rs @@ -31,6 +31,7 @@ pub struct Config { pub decoder: DecoderConfig, pub text_encoder: t5::Config, pub vocab_size: usize, + pub audio_encoder: crate::models::dac::Config, } #[derive(Debug, Clone)] @@ -325,6 +326,7 @@ pub struct Model { pub text_encoder: t5::T5EncoderModel, pub decoder_start_token_id: u32, pub pad_token_id: u32, + pub audio_encoder: crate::models::dac::Model, } impl Model { @@ -347,6 +349,8 @@ impl Model { } else { None }; + let audio_encoder = + crate::models::dac::Model::new(&cfg.audio_encoder, vb.pp("audio_encoder"))?; Ok(Self { decoder, text_encoder, @@ -354,6 +358,7 @@ impl Model { enc_to_dec_proj, decoder_start_token_id: cfg.decoder_start_token_id, pad_token_id: cfg.pad_token_id, + audio_encoder, }) } |