diff options
Diffstat (limited to 'candle-examples/examples/musicgen/musicgen_model.rs')
-rw-r--r-- | candle-examples/examples/musicgen/musicgen_model.rs | 39 |
1 files changed, 32 insertions, 7 deletions
diff --git a/candle-examples/examples/musicgen/musicgen_model.rs b/candle-examples/examples/musicgen/musicgen_model.rs
index c6b52fde..03e96614 100644
--- a/candle-examples/examples/musicgen/musicgen_model.rs
+++ b/candle-examples/examples/musicgen/musicgen_model.rs
@@ -1,10 +1,9 @@
-use crate::encodec_model;
 use candle::{DType, Device, Result, Tensor, D};
 use candle_nn::{
     embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module,
     VarBuilder,
 };
-use candle_transformers::models::t5;
+use candle_transformers::models::{encodec, t5};
 
 // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
 #[derive(Debug, Clone, PartialEq)]
@@ -372,7 +371,7 @@ impl MusicgenForCausalLM {
 #[derive(Debug)]
 pub struct MusicgenForConditionalGeneration {
     pub text_encoder: t5::T5EncoderModel,
-    pub audio_encoder: crate::encodec_model::EncodecModel,
+    pub audio_encoder: encodec::Model,
     pub decoder: MusicgenForCausalLM,
     cfg: GenConfig,
 }
@@ -381,15 +380,42 @@ pub struct MusicgenForConditionalGeneration {
 pub struct GenConfig {
     musicgen: Config,
     t5: t5::Config,
-    encodec: crate::encodec_model::Config,
+    encodec: encodec::Config,
 }
 
 impl GenConfig {
     pub fn small() -> Self {
+        // https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L6
+        let encodec = encodec::Config {
+            audio_channels: 1,
+            chunk_length_s: None,
+            codebook_dim: Some(128),
+            codebook_size: 2048,
+            compress: 2,
+            dilation_growth_rate: 2,
+            hidden_size: 128,
+            kernel_size: 7,
+            last_kernel_size: 7,
+            norm_type: encodec::NormType::WeightNorm,
+            normalize: false,
+            num_filters: 64,
+            num_lstm_layers: 2,
+            num_residual_layers: 1,
+            overlap: None,
+            // This should be Reflect and not Replicate but Reflect does not work yet.
+            pad_mode: encodec::PadMode::Replicate,
+            residual_kernel_size: 3,
+            sampling_rate: 32_000,
+            target_bandwidths: vec![2.2],
+            trim_right_ratio: 1.0,
+            upsampling_ratios: vec![8, 5, 4, 4],
+            use_causal_conv: false,
+            use_conv_shortcut: false,
+        };
         Self {
             musicgen: Config::musicgen_small(),
             t5: t5::Config::musicgen_small(),
-            encodec: encodec_model::Config::musicgen_small(),
+            encodec,
         }
     }
 }
@@ -401,8 +427,7 @@ impl MusicgenForConditionalGeneration {
 
     pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> {
         let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
-        let audio_encoder =
-            encodec_model::EncodecModel::load(vb.pp("audio_encoder"), &cfg.encodec)?;
+        let audio_encoder = encodec::Model::new(&cfg.encodec, vb.pp("audio_encoder"))?;
         let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?;
         Ok(Self {
             text_encoder,