author      Laurent Mazare <laurent.mazare@gmail.com>    2023-09-03 19:27:48 +0200
committer   GitHub <noreply@github.com>                  2023-09-03 18:27:48 +0100
commit      26cd266e6569b0640947d4cacb4d6b9c27c01623 (patch)
tree        cb7fa82b7bb5978d69506d00d10cd35b4211cd40
parent      bbec527bb966b5050a9f8a3fe1382ea929e39d41 (diff)
download    candle-26cd266e6569b0640947d4cacb4d6b9c27c01623.tar.gz
            candle-26cd266e6569b0640947d4cacb4d6b9c27c01623.tar.bz2
            candle-26cd266e6569b0640947d4cacb4d6b9c27c01623.zip
Musicgen text embeddings. (#726)
* Musicgen text embeddings.
* Bugfix for layer norm.
* Proper position bias.
* Expose the weights.
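
The layer-norm bugfix is the subtle part of this patch: the old code divided the squared activations by their sum instead of averaging them, so it never computed the T5 RMS statistic. The fix computes the variance as the mean of the squares over the last dimension in f32, with no mean subtraction and no bias term. As a minimal standalone sketch of the corrected computation, using the candle tensor ops that appear in the diff below (the free function t5_rms_norm is an illustrative name, not part of the patch):

use candle::{DType, Result, Tensor, D};

// Illustrative helper, not part of the patch: T5-style RMS layer norm.
// The variance is the mean of the squared activations over the last
// dimension, computed in f32; unlike standard LayerNorm there is no
// mean subtraction and no bias term.
fn t5_rms_norm(xs: &Tensor, weight: &Tensor, eps: f64) -> Result<Tensor> {
    let dtype = xs.dtype();
    let xs_f32 = xs.to_dtype(DType::F32)?;
    // variance = mean(x^2) over the last dim, kept for broadcasting.
    let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
    let xs = xs.broadcast_div(&(variance + eps)?.sqrt()?)?;
    xs.to_dtype(dtype)?.broadcast_mul(weight)
}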
-rw-r--r--   candle-examples/examples/musicgen/main.rs             27
-rw-r--r--   candle-examples/examples/musicgen/musicgen_model.rs    6
-rw-r--r--   candle-examples/examples/musicgen/t5_model.rs         59
-rw-r--r--   candle-nn/src/linear.rs                                8
4 files changed, 85 insertions(+), 15 deletions(-)
diff --git a/candle-examples/examples/musicgen/main.rs b/candle-examples/examples/musicgen/main.rs
index 8dcef6d2..3794c22d 100644
--- a/candle-examples/examples/musicgen/main.rs
+++ b/candle-examples/examples/musicgen/main.rs
@@ -18,7 +18,7 @@ mod t5_model;
 use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
 
 use anyhow::{Error as E, Result};
-use candle::DType;
+use candle::{DType, Tensor};
 use candle_nn::VarBuilder;
 use clap::Parser;
 use hf_hub::{api::sync::Api, Repo, RepoType};
@@ -39,6 +39,12 @@ struct Args {
     /// The tokenizer config.
     #[arg(long)]
     tokenizer: Option<String>,
+
+    #[arg(
+        long,
+        default_value = "90s rock song with loud guitars and heavy drums"
+    )]
+    prompt: String,
 }
 
 fn main() -> Result<()> {
@@ -53,7 +59,10 @@ fn main() -> Result<()> {
             .get("tokenizer.json")?,
     };
     let mut tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
-    let _tokenizer = tokenizer.with_padding(None).with_truncation(None);
+    let tokenizer = tokenizer
+        .with_padding(None)
+        .with_truncation(None)
+        .map_err(E::msg)?;
 
     let model = match args.model {
         Some(model) => std::path::PathBuf::from(model),
@@ -69,6 +78,18 @@ fn main() -> Result<()> {
     let model = model.deserialize()?;
     let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
     let config = GenConfig::small();
-    let _model = MusicgenForConditionalGeneration::load(vb, config)?;
+    let model = MusicgenForConditionalGeneration::load(vb, config)?;
+
+    let tokens = tokenizer
+        .encode(args.prompt.as_str(), true)
+        .map_err(E::msg)?
+        .get_ids()
+        .to_vec();
+    println!("tokens: {tokens:?}");
+    let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
+    println!("{tokens:?}");
+    let embeds = model.text_encoder.forward(&tokens)?;
+    println!("{embeds}");
+
     Ok(())
 }
diff --git a/candle-examples/examples/musicgen/musicgen_model.rs b/candle-examples/examples/musicgen/musicgen_model.rs
index 751e0226..7e272fd7 100644
--- a/candle-examples/examples/musicgen/musicgen_model.rs
+++ b/candle-examples/examples/musicgen/musicgen_model.rs
@@ -370,9 +370,9 @@ impl MusicgenForCausalLM {
 
 #[derive(Debug)]
 pub struct MusicgenForConditionalGeneration {
-    text_encoder: crate::t5_model::T5EncoderModel,
-    audio_encoder: crate::encodec_model::EncodecModel,
-    decoder: MusicgenForCausalLM,
+    pub text_encoder: crate::t5_model::T5EncoderModel,
+    pub audio_encoder: crate::encodec_model::EncodecModel,
+    pub decoder: MusicgenForCausalLM,
     cfg: GenConfig,
 }
 
diff --git a/candle-examples/examples/musicgen/t5_model.rs b/candle-examples/examples/musicgen/t5_model.rs
index 33b11b95..607b5c93 100644
--- a/candle-examples/examples/musicgen/t5_model.rs
+++ b/candle-examples/examples/musicgen/t5_model.rs
@@ -96,10 +96,9 @@ impl T5LayerNorm {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let dtype = xs.dtype();
         let xs_f32 = xs.to_dtype(DType::F32)?;
-        let xs2_f32 = (&xs_f32 * &xs_f32)?;
-        let sum_xs2_f32 = xs2_f32.sum_keepdim(D::Minus1)?;
-        let variance = xs2_f32.broadcast_div(&sum_xs2_f32)?;
-        let xs = (xs / (variance + self.variance_epsilon)?.sqrt()?)?;
+        // variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
+        let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
         let xs = xs.to_dtype(dtype)?;
         let xs = xs.broadcast_mul(&self.weight)?;
         Ok(xs)
@@ -167,6 +166,9 @@ struct T5Attention {
     n_heads: usize,
     d_kv: usize,
     relative_attention_bias: Option<Embedding>,
+    relative_attention_num_buckets: usize,
+    relative_attention_max_distance: usize,
+    inner_dim: usize,
 }
 
 impl T5Attention {
@@ -194,6 +196,9 @@ impl T5Attention {
             n_heads: cfg.num_heads,
             d_kv: cfg.d_kv,
             relative_attention_bias,
+            relative_attention_num_buckets: cfg.relative_attention_num_buckets,
+            relative_attention_max_distance: cfg.relative_attention_max_distance,
+            inner_dim,
         })
     }
 
@@ -206,17 +211,53 @@
         let v = self.v.forward(xs)?;
         let q = q
             .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
-            .transpose(1, 2)?;
+            .transpose(1, 2)?
+            .contiguous()?;
         let k = k
             .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
-            .transpose(1, 2)?;
+            .transpose(1, 2)?
+            .contiguous()?;
         let v = v
             .reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
-            .transpose(1, 2)?;
+            .transpose(1, 2)?
+            .contiguous()?;
         let scores = q.matmul(&k.t()?)?;
-        // TODO: position_bias_masked
+
+        let scores = match &self.relative_attention_bias {
+            None => scores,
+            Some(relative_attention_bias) => {
+                let query_length = seq_len;
+                let key_length = seq_len;
+                // This only handles the bidirectional case.
+                let num_buckets = self.relative_attention_num_buckets / 2;
+                let relative_position = (0..query_length as u32)
+                    .map(|i| {
+                        (0..key_length as u32)
+                            .map(|j| {
+                                if i < j {
+                                    j - i + num_buckets as u32
+                                } else {
+                                    i - j
+                                }
+                            })
+                            .collect::<Vec<u32>>()
+                    })
+                    .collect::<Vec<Vec<_>>>();
+                let relative_buckets = Tensor::new(relative_position, q.device())?;
+                let position_bias = relative_attention_bias
+                    .forward(&relative_buckets)?
+                    .permute((2, 0, 1))?
+                    .unsqueeze(0)?;
+                (scores + position_bias)?
+                // TODO: position_bias_masked?
+            }
+        };
+
         let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
         let attn_output = attn_weights.matmul(&v)?;
+        let attn_output = attn_output
+            .transpose(1, 2)?
+            .reshape((b_sz, seq_len, self.inner_dim))?;
         let attn_output = self.o.forward(&attn_output)?;
         Ok(attn_output)
     }
@@ -324,7 +365,7 @@ impl T5Stack {
 
     fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
         let input_embeds = self.shared.as_ref().forward(input_ids)?;
-        let (_b_sz, _seq_len) = input_embeds.dims2()?;
+        let (_b_sz, _seq_len) = (input_embeds.dim(0)?, input_embeds.dim(1)?);
 
         let mut hidden_states = input_embeds;
         for block in self.block.iter() {
diff --git a/candle-nn/src/linear.rs b/candle-nn/src/linear.rs
index 14250ed2..7028f68c 100644
--- a/candle-nn/src/linear.rs
+++ b/candle-nn/src/linear.rs
@@ -29,6 +29,14 @@ impl Linear {
     pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
         Self { weight, bias }
     }
+
+    pub fn weight(&self) -> &Tensor {
+        &self.weight
+    }
+
+    pub fn bias(&self) -> Option<&Tensor> {
+        self.bias.as_ref()
+    }
 }
 
 impl super::Module for Linear {
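
The position-bias change reduces to a bucket-indexing rule: for each query position i and key position j, a key after the query (j > i) lands in the upper half of the bucket range, offset by num_buckets / 2, while a key at or before it gets the raw distance i - j. The sketch below pulls that indexing out of candle entirely. Note that the patch uses raw distances as bucket indices, without the logarithmic compression of large distances from the full T5 scheme, and the 32-bucket figure here is the upstream T5 default, an assumption rather than something fixed by this diff:

// Standalone sketch of the relative-bucket indexing added in t5_model.rs
// (bidirectional case only; no logarithmic bucketing of large offsets,
// matching the patch).
fn relative_buckets(query_len: u32, key_len: u32, num_buckets: u32) -> Vec<Vec<u32>> {
    let half = num_buckets / 2;
    (0..query_len)
        .map(|i| {
            (0..key_len)
                .map(|j| if i < j { j - i + half } else { i - j })
                .collect()
        })
        .collect()
}

fn main() {
    // Assuming 32 buckets (the T5 default), a 4-token sequence gives:
    // [0, 17, 18, 19]
    // [1, 0, 17, 18]
    // [2, 1, 0, 17]
    // [3, 2, 1, 0]
    for row in relative_buckets(4, 4, 32) {
        println!("{row:?}");
    }
}

With the new prompt argument, the example should then print the T5 text embeddings for a custom prompt, along the lines of: cargo run --example musicgen -- --prompt "90s rock song with loud guitars and heavy drums" (the exact cargo invocation is an assumption, not part of the diff).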