diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-02-27 22:59:40 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-27 22:59:40 +0100 |
commit | 0c49e95dfb5f25a17340ff1b690a4b4f6cd0e2d1 (patch) | |
tree | c732811778ea6e15c558dcbe35153cd110eb5959 | |
parent | 205767f9ded3d531822d3702442a52b4a320f72e (diff) | |
download | candle-0c49e95dfb5f25a17340ff1b690a4b4f6cd0e2d1.tar.gz candle-0c49e95dfb5f25a17340ff1b690a4b4f6cd0e2d1.tar.bz2 candle-0c49e95dfb5f25a17340ff1b690a4b4f6cd0e2d1.zip |
Encodec model. (#1771)
* Encodec model.
* Fixes.
* Add the padding functions.
* Get the LSTM bit to work.
* Get the encodec model to generate some tokens (decoder only for now).
* Minor tweak.
* Minor tweak.
-rw-r--r-- | candle-examples/examples/encodec/jfk-codes.safetensors | bin | 0 -> 13328 bytes | |||
-rw-r--r-- | candle-examples/examples/encodec/main.rs | 57 | ||||
-rw-r--r-- | candle-examples/src/lib.rs | 1 | ||||
-rw-r--r-- | candle-examples/src/wav.rs | 56 | ||||
-rw-r--r-- | candle-nn/src/rnn.rs | 2 | ||||
-rw-r--r-- | candle-transformers/src/models/encodec.rs | 718 | ||||
-rw-r--r-- | candle-transformers/src/models/mod.rs | 1 |
7 files changed, 834 insertions, 1 deletions
diff --git a/candle-examples/examples/encodec/jfk-codes.safetensors b/candle-examples/examples/encodec/jfk-codes.safetensors Binary files differnew file mode 100644 index 00000000..b8eb2026 --- /dev/null +++ b/candle-examples/examples/encodec/jfk-codes.safetensors diff --git a/candle-examples/examples/encodec/main.rs b/candle-examples/examples/encodec/main.rs new file mode 100644 index 00000000..47a9ba59 --- /dev/null +++ b/candle-examples/examples/encodec/main.rs @@ -0,0 +1,57 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use anyhow::Result; +use candle::{DType, IndexOp}; +use candle_nn::VarBuilder; +use candle_transformers::models::encodec::{Config, Model}; +use clap::Parser; +use hf_hub::api::sync::Api; + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// The model weight file, in safetensor format. + #[arg(long)] + model: Option<String>, + + /// Input file as a safetensors containing the encodec tokens. + #[arg(long)] + code_file: String, + + /// Output file that will be generated in wav format. + #[arg(long)] + out: String, +} + +fn main() -> Result<()> { + let args = Args::parse(); + let device = candle_examples::device(args.cpu)?; + let model = match args.model { + Some(model) => std::path::PathBuf::from(model), + None => Api::new()? + .model("facebook/encodec_24khz".to_string()) + .get("model.safetensors")?, + }; + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? }; + let config = Config::default(); + let model = Model::new(&config, vb)?; + + let codes = candle::safetensors::load(args.code_file, &device)?; + let codes = codes.get("codes").expect("no codes in input file").i(0)?; + println!("codes shape: {:?}", codes.shape()); + let pcm = model.decode(&codes)?; + let pcm = pcm.i(0)?.i(0)?.to_vec1::<f32>()?; + + let mut output = std::fs::File::create(&args.out)?; + candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?; + + Ok(()) +} diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs index d6dce4a3..7cb8eb01 100644 --- a/candle-examples/src/lib.rs +++ b/candle-examples/src/lib.rs @@ -1,6 +1,7 @@ pub mod coco_classes; pub mod imagenet; pub mod token_output_stream; +pub mod wav; use candle::utils::{cuda_is_available, metal_is_available}; use candle::{Device, Result, Tensor}; diff --git a/candle-examples/src/wav.rs b/candle-examples/src/wav.rs new file mode 100644 index 00000000..df98aa14 --- /dev/null +++ b/candle-examples/src/wav.rs @@ -0,0 +1,56 @@ +use std::io::prelude::*; + +pub trait Sample { + fn to_i16(&self) -> i16; +} + +impl Sample for f32 { + fn to_i16(&self) -> i16 { + (self.clamp(-1.0, 1.0) * 32767.0) as i16 + } +} + +impl Sample for f64 { + fn to_i16(&self) -> i16 { + (self.clamp(-1.0, 1.0) * 32767.0) as i16 + } +} + +impl Sample for i16 { + fn to_i16(&self) -> i16 { + *self + } +} + +pub fn write_pcm_as_wav<W: Write, S: Sample>( + w: &mut W, + samples: &[S], + sample_rate: u32, +) -> std::io::Result<()> { + let len = 12u32; // header + let len = len + 24u32; // fmt + let len = len + samples.len() as u32 * 2 + 8; // data + let n_channels = 1u16; + let bytes_per_second = sample_rate * 2 * n_channels as u32; + w.write_all(b"RIFF")?; + w.write_all(&(len - 8).to_le_bytes())?; // total length minus 8 bytes + w.write_all(b"WAVE")?; + + // Format block + w.write_all(b"fmt ")?; + w.write_all(&16u32.to_le_bytes())?; // block len minus 8 bytes + w.write_all(&1u16.to_le_bytes())?; // PCM + w.write_all(&n_channels.to_le_bytes())?; // one channel + w.write_all(&sample_rate.to_le_bytes())?; + w.write_all(&bytes_per_second.to_le_bytes())?; + w.write_all(&2u16.to_le_bytes())?; // 2 bytes of data per sample + w.write_all(&16u16.to_le_bytes())?; // bits per sample + + // Data block + w.write_all(b"data")?; + w.write_all(&(samples.len() as u32 * 2).to_le_bytes())?; + for sample in samples.iter() { + w.write_all(&sample.to_i16().to_le_bytes())? + } + Ok(()) +} diff --git a/candle-nn/src/rnn.rs b/candle-nn/src/rnn.rs index 9f144cca..07795eda 100644 --- a/candle-nn/src/rnn.rs +++ b/candle-nn/src/rnn.rs @@ -197,7 +197,7 @@ impl RNN for LSTM { fn states_to_tensor(&self, states: &[Self::State]) -> Result<Tensor> { let states = states.iter().map(|s| s.h.clone()).collect::<Vec<_>>(); - Tensor::cat(&states, 1) + Tensor::stack(&states, 1) } } diff --git a/candle-transformers/src/models/encodec.rs b/candle-transformers/src/models/encodec.rs new file mode 100644 index 00000000..68f01d87 --- /dev/null +++ b/candle-transformers/src/models/encodec.rs @@ -0,0 +1,718 @@ +#![allow(unused)] +use candle::{DType, IndexOp, Layout, Module, Result, Shape, Tensor, D}; +use candle_nn::{conv1d, Conv1d, Conv1dConfig, ConvTranspose1d, VarBuilder}; + +// Encodec Model +// https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py + +#[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Deserialize)] +pub enum NormType { + WeightNorm, + TimeGroupNorm, + None, +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Deserialize)] +pub enum PadMode { + Constant, + Reflect, + Replicate, +} + +#[derive(Debug, Clone, PartialEq, serde::Deserialize)] +pub struct Config { + pub target_bandwidths: Vec<f64>, + pub sampling_rate: usize, + pub audio_channels: usize, + pub normalize: bool, + pub chunk_length_s: Option<usize>, + pub overlap: Option<usize>, + pub hidden_size: usize, + pub num_filters: usize, + pub num_residual_layers: usize, + pub upsampling_ratios: Vec<usize>, + pub norm_type: NormType, + pub kernel_size: usize, + pub last_kernel_size: usize, + pub residual_kernel_size: usize, + pub dilation_growth_rate: usize, + pub use_causal_conv: bool, + pub pad_mode: PadMode, + pub compress: usize, + pub num_lstm_layers: usize, + pub trim_right_ratio: f64, + pub codebook_size: usize, + pub codebook_dim: Option<usize>, + pub use_conv_shortcut: bool, +} + +impl Default for Config { + fn default() -> Self { + Self { + target_bandwidths: vec![1.5, 3.0, 6.0, 12.0, 24.0], + sampling_rate: 24_000, + audio_channels: 1, + normalize: false, + chunk_length_s: None, + overlap: None, + hidden_size: 128, + num_filters: 32, + num_residual_layers: 1, + upsampling_ratios: vec![8, 5, 4, 2], + norm_type: NormType::WeightNorm, + kernel_size: 7, + last_kernel_size: 7, + residual_kernel_size: 3, + dilation_growth_rate: 2, + use_causal_conv: true, + // This should be PadMode::Reflect which is currently unsupported in candle. + pad_mode: PadMode::Replicate, + compress: 2, + num_lstm_layers: 2, + trim_right_ratio: 1.0, + codebook_size: 1024, + codebook_dim: None, + use_conv_shortcut: true, + } + } +} + +impl Config { + fn codebook_dim(&self) -> usize { + self.codebook_dim.unwrap_or(self.hidden_size) + } + + fn frame_rate(&self) -> usize { + let hop_length: usize = self.upsampling_ratios.iter().product(); + (self.sampling_rate + hop_length - 1) / hop_length + } + + fn num_quantizers(&self) -> usize { + let num = 1000f64 + * self + .target_bandwidths + .last() + .expect("empty target_bandwidths"); + (num as usize) / (self.frame_rate() * 10) + } +} + +fn get_extra_padding_for_conv1d( + xs: &Tensor, + k_size: usize, + stride: usize, + padding_total: usize, +) -> Result<usize> { + let len = xs.dim(D::Minus1)?; + let n_frames = (len + padding_total).saturating_sub(k_size) as f64 / stride as f64 + 1.0; + let ideal_len = + ((n_frames.ceil() as usize - 1) * stride + k_size).saturating_sub(padding_total); + Ok(ideal_len.saturating_sub(len)) +} + +fn pad1d(xs: &Tensor, pad_l: usize, pad_r: usize, mode: PadMode) -> Result<Tensor> { + match mode { + PadMode::Constant => xs.pad_with_zeros(D::Minus1, pad_l, pad_r), + PadMode::Reflect => candle::bail!("pad-mode 'reflect' is not supported"), + PadMode::Replicate => xs.pad_with_same(D::Minus1, pad_l, pad_r), + } +} + +// Applies weight norm for inference by recomputing the weight tensor. This +// does not apply to training. +// https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html +pub fn conv1d_weight_norm( + in_c: usize, + out_c: usize, + kernel_size: usize, + config: candle_nn::Conv1dConfig, + vb: VarBuilder, +) -> Result<Conv1d> { + let weight_g = vb.get((out_c, 1, 1), "weight_g")?; + let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?; + let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?; + let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?; + let bias = vb.get(out_c, "bias")?; + Ok(Conv1d::new(weight, Some(bias), config)) +} + +fn conv_transpose1d_weight_norm( + in_c: usize, + out_c: usize, + kernel_size: usize, + bias: bool, + config: candle_nn::ConvTranspose1dConfig, + vb: VarBuilder, +) -> Result<ConvTranspose1d> { + let weight_g = vb.get((in_c, 1, 1), "weight_g")?; + let weight_v = vb.get((in_c, out_c, kernel_size), "weight_v")?; + let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?; + let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?; + let bias = if bias { + Some(vb.get(out_c, "bias")?) + } else { + None + }; + Ok(ConvTranspose1d::new(weight, bias, config)) +} + +struct CodebookEncode; + +impl candle::CustomOp2 for CodebookEncode { + fn name(&self) -> &'static str { + "cb" + } + + fn cpu_fwd( + &self, + lhs_storage: &candle::CpuStorage, + lhs_layout: &Layout, + rhs_storage: &candle::CpuStorage, + rhs_layout: &Layout, + ) -> Result<(candle::CpuStorage, Shape)> { + use rayon::prelude::*; + + let (lhs_dim1, lhs_dim2) = lhs_layout.shape().dims2()?; + let (rhs_dim1, rhs_dim2) = rhs_layout.shape().dims2()?; + if lhs_dim2 != rhs_dim2 { + candle::bail!("CodebookEncode, mismatch on last dim, {lhs_layout:?} {rhs_layout:?}"); + } + if lhs_dim2 == 0 { + candle::bail!("CodebookEncode, empty last dim {lhs_layout:?}") + } + let lhs = match lhs_layout.contiguous_offsets() { + None => candle::bail!("CodebookEncode, lhs has to be contiguous, got {lhs_layout:?}"), + Some((o1, o2)) => { + let slice = lhs_storage.as_slice::<f32>()?; + &slice[o1..o2] + } + }; + let rhs = match rhs_layout.contiguous_offsets() { + None => candle::bail!("CodebookEncode, rhs has to be contiguous, got {rhs_layout:?}"), + Some((o1, o2)) => { + let slice = rhs_storage.as_slice::<f32>()?; + &slice[o1..o2] + } + }; + let dst = (0..lhs_dim1) + .into_par_iter() + .map(|idx1| { + let mut where_min = 0; + let mut min_dist = f32::INFINITY; + let lhs = &lhs[idx1 * lhs_dim2..(idx1 + 1) * lhs_dim2]; + for idx2 in 0..rhs_dim1 { + let rhs = &rhs[idx2 * rhs_dim2..(idx2 + 1) * rhs_dim2]; + let mut dist = 0f32; + for (a, b) in lhs.iter().zip(rhs.iter()) { + dist += (a - b) * (a - b) + } + if dist < min_dist { + min_dist = dist; + where_min = idx2; + } + } + where_min as u32 + }) + .collect(); + let storage = candle::WithDType::to_cpu_storage_owned(dst); + Ok((storage, (lhs_dim1,).into())) + } +} + +// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L340 +#[derive(Clone, Debug)] +pub struct EuclideanCodebook { + inited: Tensor, + cluster_size: Tensor, + embed: candle_nn::Embedding, + embed_avg: Tensor, +} + +impl EuclideanCodebook { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let inited = vb.get(1, "inited")?; + let cluster_size = vb.get(cfg.codebook_size, "cluster_size")?; + let e_shape = (cfg.codebook_size, cfg.codebook_dim()); + let embed = vb.get(e_shape, "embed")?; + let embed_avg = vb.get(e_shape, "embed_avg")?; + Ok(Self { + inited, + cluster_size, + embed: candle_nn::Embedding::new(embed, cfg.codebook_dim()), + embed_avg, + }) + } + + pub fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> { + let quantize = self.embed.forward(embed_ind)?; + Ok(quantize) + } +} + +#[derive(Clone, Debug)] +pub struct VectorQuantization { + codebook: EuclideanCodebook, +} + +impl VectorQuantization { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let codebook = EuclideanCodebook::new(cfg, vb.pp("codebook"))?; + Ok(Self { codebook }) + } + + pub fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> { + let quantize = self.codebook.decode(embed_ind)?; + let quantize = quantize.transpose(1, 2)?; + Ok(quantize) + } +} + +#[derive(Clone, Debug)] +pub struct ResidualVectorQuantizer { + layers: Vec<VectorQuantization>, +} + +impl ResidualVectorQuantizer { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let vb = &vb.pp("layers"); + let layers = (0..cfg.num_quantizers()) + .map(|i| VectorQuantization::new(cfg, vb.pp(i))) + .collect::<Result<Vec<_>>>()?; + Ok(Self { layers }) + } + + pub fn decode(&self, codes: &Tensor) -> Result<Tensor> { + let mut quantized_out = Tensor::zeros((), DType::F32, codes.device())?; + let ncodes = codes.dim(0)?; + if ncodes > self.layers.len() { + candle::bail!( + "codes shape {:?} does not match the number of quantization layers {}", + codes.shape(), + self.layers.len() + ) + } + for (i, layer) in self.layers.iter().take(ncodes).enumerate() { + let quantized = layer.decode(&codes.i(i)?)?; + quantized_out = quantized.broadcast_add(&quantized_out)?; + } + Ok(quantized_out) + } +} + +// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L226 +#[derive(Clone, Debug)] +pub struct EncodecLSTM { + layers: Vec<candle_nn::LSTM>, +} + +impl EncodecLSTM { + pub fn new(dim: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> { + let vb = &vb.pp("lstm"); + let mut layers = vec![]; + for layer_idx in 0..cfg.num_lstm_layers { + let config = candle_nn::LSTMConfig { + layer_idx, + ..Default::default() + }; + let lstm = candle_nn::lstm(dim, dim, config, vb.clone())?; + layers.push(lstm) + } + Ok(Self { layers }) + } +} + +impl Module for EncodecLSTM { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + use candle_nn::RNN; + // This is different from the Python transformers version as candle LSTM is batch first. + let xs = xs.t()?; + let residual = &xs; + let mut xs = xs.clone(); + for layer in self.layers.iter() { + let states = layer.seq(&xs)?; + xs = layer.states_to_tensor(&states)?; + } + let xs = (xs + residual)?.t()?; + Ok(xs) + } +} + +#[derive(Clone, Debug)] +pub struct EncodecConvTranspose1d { + conv: ConvTranspose1d, +} + +impl EncodecConvTranspose1d { + fn new( + in_c: usize, + out_c: usize, + k: usize, + stride: usize, + _cfg: &Config, + vb: VarBuilder, + ) -> Result<Self> { + let cfg = candle_nn::ConvTranspose1dConfig { + stride, + ..Default::default() + }; + let conv = conv_transpose1d_weight_norm(in_c, out_c, k, true, cfg, vb.pp("conv"))?; + Ok(Self { conv }) + } +} + +impl Module for EncodecConvTranspose1d { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + xs.apply(&self.conv) + } +} + +#[derive(Clone, Debug)] +pub struct EncodecConv1d { + causal: bool, + conv: Conv1d, + norm: Option<candle_nn::GroupNorm>, + pad_mode: PadMode, +} + +impl EncodecConv1d { + pub fn new( + in_c: usize, + out_c: usize, + kernel_size: usize, + stride: usize, + cfg: &Config, + vb: VarBuilder, + ) -> Result<Self> { + let conv = match cfg.norm_type { + NormType::WeightNorm => conv1d_weight_norm( + in_c, + out_c, + kernel_size, + candle_nn::Conv1dConfig { + padding: 0, + stride, + groups: 1, + dilation: 1, + }, + vb.pp("conv"), + )?, + NormType::None | NormType::TimeGroupNorm => conv1d( + in_c, + out_c, + kernel_size, + candle_nn::Conv1dConfig { + padding: 0, + stride, + groups: 1, + dilation: 1, + }, + vb.pp("conv"), + )?, + }; + let norm = match cfg.norm_type { + NormType::None | NormType::WeightNorm => None, + NormType::TimeGroupNorm => { + let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?; + Some(gn) + } + }; + Ok(Self { + causal: cfg.use_causal_conv, + conv, + norm, + pad_mode: cfg.pad_mode, + }) + } +} + +impl Module for EncodecConv1d { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let (_b, _t, _c) = xs.dims3()?; + let k_size = self.conv.weight().dim(D::Minus1)?; + let conv_cfg = self.conv.config(); + // Effective kernel size with dilations. + let k_size = (k_size - 1) * conv_cfg.dilation + 1; + let padding_total = k_size - conv_cfg.stride; + let extra_padding = + get_extra_padding_for_conv1d(xs, k_size, conv_cfg.stride, padding_total)?; + let xs = if self.causal { + pad1d(xs, padding_total, extra_padding, self.pad_mode)? + } else { + let padding_right = padding_total / 2; + let padding_left = padding_total - padding_right; + pad1d( + xs, + padding_left, + padding_right + extra_padding, + self.pad_mode, + )? + }; + let xs = self.conv.forward(&xs)?; + match &self.norm { + None => Ok(xs), + Some(norm) => xs.apply(norm), + } + } +} + +#[derive(Clone, Debug)] +pub struct EncodecResnetBlock { + block_conv1: EncodecConv1d, + block_conv2: EncodecConv1d, + shortcut: Option<EncodecConv1d>, +} + +impl EncodecResnetBlock { + pub fn new(dim: usize, dilations: &[usize], cfg: &Config, vb: VarBuilder) -> Result<Self> { + let h = dim / cfg.compress; + let mut layer = Layer::new(vb.pp("block")); + if dilations.len() != 2 { + candle::bail!("expected dilations of size 2") + } + // TODO: Apply dilations! + layer.inc(); + let block_conv1 = + EncodecConv1d::new(dim, h, cfg.residual_kernel_size, 1, cfg, layer.next())?; + layer.inc(); + let block_conv2 = EncodecConv1d::new(h, dim, 1, 1, cfg, layer.next())?; + let shortcut = if cfg.use_conv_shortcut { + let conv = EncodecConv1d::new(dim, dim, 1, 1, cfg, vb.pp("shortcut"))?; + Some(conv) + } else { + None + }; + Ok(Self { + block_conv1, + block_conv2, + shortcut, + }) + } +} + +impl Module for EncodecResnetBlock { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let residual = xs.clone(); + let xs = xs.elu(1.)?; + let xs = self.block_conv1.forward(&xs)?; + let xs = xs.elu(1.)?; + let xs = self.block_conv2.forward(&xs)?; + let xs = match &self.shortcut { + None => (xs + residual)?, + Some(shortcut) => xs.add(&shortcut.forward(&residual)?)?, + }; + Ok(xs) + } +} + +struct Layer<'a> { + vb: VarBuilder<'a>, + cnt: usize, +} + +impl<'a> Layer<'a> { + fn new(vb: VarBuilder<'a>) -> Self { + Self { vb, cnt: 0 } + } + + fn inc(&mut self) { + self.cnt += 1; + } + + fn next(&mut self) -> VarBuilder { + let vb = self.vb.pp(&self.cnt.to_string()); + self.cnt += 1; + vb + } +} + +#[derive(Clone, Debug)] +pub struct Encoder { + init_conv: EncodecConv1d, + sampling_layers: Vec<(Vec<EncodecResnetBlock>, EncodecConv1d)>, + final_lstm: EncodecLSTM, + final_conv: EncodecConv1d, +} + +impl Encoder { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let mut layer = Layer::new(vb.pp("layers")); + let init_conv = EncodecConv1d::new( + cfg.audio_channels, + cfg.num_filters, + cfg.kernel_size, + 1, + cfg, + layer.next(), + )?; + let mut sampling_layers = vec![]; + let mut scaling = 1; + for &ratio in cfg.upsampling_ratios.iter().rev() { + let current_scale = scaling * cfg.num_filters; + let mut resnets = vec![]; + for j in 0..(cfg.num_residual_layers as u32) { + let resnet = EncodecResnetBlock::new( + current_scale, + &[cfg.dilation_growth_rate.pow(j), 1], + cfg, + layer.next(), + )?; + resnets.push(resnet) + } + layer.inc(); // ELU + let conv1d = EncodecConv1d::new( + current_scale, + current_scale * 2, + ratio * 2, + ratio, + cfg, + layer.next(), + )?; + sampling_layers.push((resnets, conv1d)); + scaling *= 2; + } + let final_lstm = EncodecLSTM::new(cfg.num_filters * scaling, cfg, layer.next())?; + layer.inc(); // ELU + let final_conv = EncodecConv1d::new( + cfg.num_filters * scaling, + cfg.hidden_size, + cfg.last_kernel_size, + 1, + cfg, + layer.next(), + )?; + Ok(Self { + init_conv, + sampling_layers, + final_conv, + final_lstm, + }) + } +} + +impl Module for Encoder { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let mut xs = xs.apply(&self.init_conv)?; + for (resnets, conv) in self.sampling_layers.iter() { + for resnet in resnets.iter() { + xs = xs.apply(resnet)?; + } + xs = xs.elu(1.0)?.apply(conv)?; + } + xs.apply(&self.final_lstm)? + .elu(1.0)? + .apply(&self.final_conv) + } +} + +#[derive(Clone, Debug)] +pub struct Decoder { + init_conv: EncodecConv1d, + init_lstm: EncodecLSTM, + sampling_layers: Vec<(EncodecConvTranspose1d, Vec<EncodecResnetBlock>)>, + final_conv: EncodecConv1d, +} + +impl Decoder { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let mut layer = Layer::new(vb.pp("layers")); + let mut scaling = usize::pow(2, cfg.upsampling_ratios.len() as u32); + let init_conv = EncodecConv1d::new( + cfg.hidden_size, + cfg.num_filters * scaling, + cfg.last_kernel_size, + 1, + cfg, + layer.next(), + )?; + let init_lstm = EncodecLSTM::new(cfg.num_filters * scaling, cfg, layer.next())?; + let mut sampling_layers = vec![]; + for &ratio in cfg.upsampling_ratios.iter() { + let current_scale = scaling * cfg.num_filters; + layer.inc(); // ELU + let conv1d = EncodecConvTranspose1d::new( + current_scale, + current_scale / 2, + ratio * 2, + ratio, + cfg, + layer.next(), + )?; + let mut resnets = vec![]; + for j in 0..(cfg.num_residual_layers as u32) { + let resnet = EncodecResnetBlock::new( + current_scale / 2, + &[cfg.dilation_growth_rate.pow(j), 1], + cfg, + layer.next(), + )?; + resnets.push(resnet) + } + sampling_layers.push((conv1d, resnets)); + scaling /= 2; + } + layer.inc(); // ELU + let final_conv = EncodecConv1d::new( + cfg.num_filters, + cfg.audio_channels, + cfg.last_kernel_size, + 1, + cfg, + layer.next(), + )?; + Ok(Self { + init_conv, + init_lstm, + sampling_layers, + final_conv, + }) + } +} + +impl Module for Decoder { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let mut xs = xs.apply(&self.init_conv)?.apply(&self.init_lstm)?; + for (conv, resnets) in self.sampling_layers.iter() { + xs = xs.elu(1.)?.apply(conv)?; + for resnet in resnets.iter() { + xs = xs.apply(resnet)? + } + } + xs.elu(1.)?.apply(&self.final_conv) + } +} + +#[derive(Debug)] +pub struct Model { + encoder: Encoder, + decoder: Decoder, + quantizer: ResidualVectorQuantizer, +} + +impl Model { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let encoder = Encoder::new(cfg, vb.pp("encoder"))?; + let decoder = Decoder::new(cfg, vb.pp("decoder"))?; + let quantizer = ResidualVectorQuantizer::new(cfg, vb.pp("quantizer"))?; + Ok(Self { + encoder, + decoder, + quantizer, + }) + } + + pub fn forward(&self, _xs: &Tensor) -> Result<Tensor> { + todo!() + } + + pub fn encode(&self, _xs: &Tensor) -> Result<Tensor> { + todo!() + } + + pub fn decode(&self, codes: &Tensor) -> Result<Tensor> { + let (_b_sz, _codebooks, _seqlen) = codes.dims3()?; + let codes = codes.transpose(0, 1)?; + let embeddings = self.quantizer.decode(&codes)?; + let outputs = self.decoder.forward(&embeddings)?; + Ok(outputs) + } +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 96627683..3624c8f1 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -8,6 +8,7 @@ pub mod convnext; pub mod dinov2; pub mod distilbert; pub mod efficientnet; +pub mod encodec; pub mod falcon; pub mod gemma; pub mod jina_bert; |