-rw-r--r--  candle-examples/examples/musicgen/main.rs  27
-rw-r--r--  candle-examples/examples/musicgen/musicgen_model.rs  6
-rw-r--r--  candle-examples/examples/musicgen/t5_model.rs  59
-rw-r--r--  candle-nn/src/linear.rs  8
4 files changed, 85 insertions, 15 deletions
diff --git a/candle-examples/examples/musicgen/main.rs b/candle-examples/examples/musicgen/main.rs
index 8dcef6d2..3794c22d 100644
--- a/candle-examples/examples/musicgen/main.rs
+++ b/candle-examples/examples/musicgen/main.rs
@@ -18,7 +18,7 @@ mod t5_model;
use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
use anyhow::{Error as E, Result};
-use candle::DType;
+use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
@@ -39,6 +39,12 @@ struct Args {
/// The tokenizer config.
#[arg(long)]
tokenizer: Option<String>,
+
+ #[arg(
+ long,
+ default_value = "90s rock song with loud guitars and heavy drums"
+ )]
+ prompt: String,
}
fn main() -> Result<()> {
@@ -53,7 +59,10 @@ fn main() -> Result<()> {
.get("tokenizer.json")?,
};
let mut tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
- let _tokenizer = tokenizer.with_padding(None).with_truncation(None);
+ let tokenizer = tokenizer
+ .with_padding(None)
+ .with_truncation(None)
+ .map_err(E::msg)?;
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
@@ -69,6 +78,18 @@ fn main() -> Result<()> {
let model = model.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
let config = GenConfig::small();
- let _model = MusicgenForConditionalGeneration::load(vb, config)?;
+ let model = MusicgenForConditionalGeneration::load(vb, config)?;
+
+ let tokens = tokenizer
+ .encode(args.prompt.as_str(), true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
+ println!("tokens: {tokens:?}");
+ let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
+ println!("{tokens:?}");
+ let embeds = model.text_encoder.forward(&tokens)?;
+ println!("{embeds}");
+
Ok(())
}
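The front half of the new flow in `main.rs` (prompt → token ids → `(1, seq_len)` tensor) can be isolated as a small helper. A sketch under the assumption that the tokenizer file is already on disk; the `encode_prompt` name and the path argument are illustrative, not part of the example, which fetches `tokenizer.json` through hf-hub:

```rust
use anyhow::{Error as E, Result};
use candle::{Device, Tensor};
use tokenizers::Tokenizer;

// Sketch only: prompt -> token ids -> (1, seq_len) id tensor for the T5 text encoder.
fn encode_prompt(prompt: &str, tokenizer_path: &str, device: &Device) -> Result<Tensor> {
    let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(E::msg)?;
    let ids = tokenizer
        .encode(prompt, true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
    // unsqueeze(0) adds the batch dimension the encoder expects: (1, seq_len).
    Ok(Tensor::new(ids.as_slice(), device)?.unsqueeze(0)?)
}
```

Feeding the resulting tensor through `model.text_encoder.forward` should yield `(1, seq_len, d_model)` hidden states, which is what the new `println!("{embeds}")` prints.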
diff --git a/candle-examples/examples/musicgen/musicgen_model.rs b/candle-examples/examples/musicgen/musicgen_model.rs
index 751e0226..7e272fd7 100644
--- a/candle-examples/examples/musicgen/musicgen_model.rs
+++ b/candle-examples/examples/musicgen/musicgen_model.rs
@@ -370,9 +370,9 @@ impl MusicgenForCausalLM {
#[derive(Debug)]
pub struct MusicgenForConditionalGeneration {
- text_encoder: crate::t5_model::T5EncoderModel,
- audio_encoder: crate::encodec_model::EncodecModel,
- decoder: MusicgenForCausalLM,
+ pub text_encoder: crate::t5_model::T5EncoderModel,
+ pub audio_encoder: crate::encodec_model::EncodecModel,
+ pub decoder: MusicgenForCausalLM,
cfg: GenConfig,
}
diff --git a/candle-examples/examples/musicgen/t5_model.rs b/candle-examples/examples/musicgen/t5_model.rs
index 33b11b95..607b5c93 100644
--- a/candle-examples/examples/musicgen/t5_model.rs
+++ b/candle-examples/examples/musicgen/t5_model.rs
@@ -96,10 +96,9 @@ impl T5LayerNorm {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let dtype = xs.dtype();
let xs_f32 = xs.to_dtype(DType::F32)?;
- let xs2_f32 = (&xs_f32 * &xs_f32)?;
- let sum_xs2_f32 = xs2_f32.sum_keepdim(D::Minus1)?;
- let variance = xs2_f32.broadcast_div(&sum_xs2_f32)?;
- let xs = (xs / (variance + self.variance_epsilon)?.sqrt()?)?;
+ // variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+ let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
+ let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
let xs = xs.to_dtype(dtype)?;
let xs = xs.broadcast_mul(&self.weight)?;
Ok(xs)
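The layer-norm fix is worth spelling out: the old code divided `x²` by its sum rather than taking its mean, so it was not computing the RMS statistic T5 uses. The new code matches the reference formula `y = w · x / sqrt(mean(x²) + ε)` (no mean subtraction). A self-contained sketch of the same computation with a tiny hand-checked input; the `rms_norm` helper and the values are illustrative only:

```rust
use candle::{D, DType, Device, Result, Tensor};

// Standalone RMS-norm sketch mirroring the fixed T5LayerNorm::forward:
// compute mean(x^2) in f32, rescale, cast back, then apply the learned weight.
fn rms_norm(xs: &Tensor, weight: &Tensor, eps: f64) -> Result<Tensor> {
    let dtype = xs.dtype();
    let xs_f32 = xs.to_dtype(DType::F32)?;
    let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
    let xs = xs_f32.broadcast_div(&(variance + eps)?.sqrt()?)?;
    xs.to_dtype(dtype)?.broadcast_mul(weight)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let xs = Tensor::new(&[[3f32, 4.0]], &dev)?; // mean(x^2) = 12.5
    let w = Tensor::new(&[1f32, 1.0], &dev)?;
    println!("{}", rms_norm(&xs, &w, 1e-6)?); // ~[0.8485, 1.1314]
    Ok(())
}
```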
@@ -167,6 +166,9 @@ struct T5Attention {
n_heads: usize,
d_kv: usize,
relative_attention_bias: Option<Embedding>,
+ relative_attention_num_buckets: usize,
+ relative_attention_max_distance: usize,
+ inner_dim: usize,
}
impl T5Attention {
@@ -194,6 +196,9 @@ impl T5Attention {
n_heads: cfg.num_heads,
d_kv: cfg.d_kv,
relative_attention_bias,
+ relative_attention_num_buckets: cfg.relative_attention_num_buckets,
+ relative_attention_max_distance: cfg.relative_attention_max_distance,
+ inner_dim,
})
}
@@ -206,17 +211,53 @@ impl T5Attention {
let v = self.v.forward(xs)?;
let q = q
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
- .transpose(1, 2)?;
+ .transpose(1, 2)?
+ .contiguous()?;
let k = k
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
- .transpose(1, 2)?;
+ .transpose(1, 2)?
+ .contiguous()?;
let v = v
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
- .transpose(1, 2)?;
+ .transpose(1, 2)?
+ .contiguous()?;
let scores = q.matmul(&k.t()?)?;
- // TODO: position_bias_masked
+
+ let scores = match &self.relative_attention_bias {
+ None => scores,
+ Some(relative_attention_bias) => {
+ let query_length = seq_len;
+ let key_length = seq_len;
+ // This only handles the bidirectional case.
+ let num_buckets = self.relative_attention_num_buckets / 2;
+ let relative_position = (0..query_length as u32)
+ .map(|i| {
+ (0..key_length as u32)
+ .map(|j| {
+ if i < j {
+ j - i + num_buckets as u32
+ } else {
+ i - j
+ }
+ })
+ .collect::<Vec<u32>>()
+ })
+ .collect::<Vec<Vec<_>>>();
+ let relative_buckets = Tensor::new(relative_position, q.device())?;
+ let position_bias = relative_attention_bias
+ .forward(&relative_buckets)?
+ .permute((2, 0, 1))?
+ .unsqueeze(0)?;
+ (scores + position_bias)?
+ // TODO: position_bias_masked?
+ }
+ };
+
let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
let attn_output = attn_weights.matmul(&v)?;
+ let attn_output = attn_output
+ .transpose(1, 2)?
+ .reshape((b_sz, seq_len, self.inner_dim))?;
let attn_output = self.o.forward(&attn_output)?;
Ok(attn_output)
}
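The new `position_bias` block adds T5's relative-position bias for the bidirectional (encoder) case using exact buckets only: when the key precedes the query the bucket is `i - j`, and when it follows the query it is `j - i + num_buckets`. The logarithmic bucketing of larger distances from the reference implementation is not reproduced in this hunk, which is presumably why `relative_attention_max_distance` is stored but not yet used. A small sketch of the bucket matrix the loop produces, for a hypothetical `seq_len = 4` with `relative_attention_num_buckets = 32` (so `num_buckets = 16` after halving):

```rust
// Row i, column j is the bucket used when query position i attends to key position j.
fn main() {
    let (seq_len, num_buckets) = (4u32, 16u32);
    for i in 0..seq_len {
        let row: Vec<u32> = (0..seq_len)
            .map(|j| if i < j { j - i + num_buckets } else { i - j })
            .collect();
        println!("{row:?}");
    }
    // Prints:
    // [0, 17, 18, 19]
    // [1, 0, 17, 18]
    // [2, 1, 0, 17]
    // [3, 2, 1, 0]
}
```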
@@ -324,7 +365,7 @@ impl T5Stack {
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let input_embeds = self.shared.as_ref().forward(input_ids)?;
- let (_b_sz, _seq_len) = input_embeds.dims2()?;
+ let (_b_sz, _seq_len) = (input_embeds.dim(0)?, input_embeds.dim(1)?);
let mut hidden_states = input_embeds;
for block in self.block.iter() {
diff --git a/candle-nn/src/linear.rs b/candle-nn/src/linear.rs
index 14250ed2..7028f68c 100644
--- a/candle-nn/src/linear.rs
+++ b/candle-nn/src/linear.rs
@@ -29,6 +29,14 @@ impl Linear {
pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
Self { weight, bias }
}
+
+ pub fn weight(&self) -> &Tensor {
+ &self.weight
+ }
+
+ pub fn bias(&self) -> Option<&Tensor> {
+ self.bias.as_ref()
+ }
}
impl super::Module for Linear {
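The `candle-nn` change adds read-only accessors so callers can inspect a `Linear`'s parameters without the fields being made public. A minimal usage sketch; here `candle` refers to the `candle-core` crate (renamed as in the examples) and the tensor values are placeholders:

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::{Linear, Module};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A hand-built 3-in / 2-out layer; weights loaded through VarBuilder behave the same.
    let weight = Tensor::ones((2, 3), DType::F32, &dev)?;
    let linear = Linear::new(weight, None);
    println!("weight shape: {:?}", linear.weight().dims()); // [2, 3]
    println!("has bias: {}", linear.bias().is_some()); // false
    let xs = Tensor::ones((4, 3), DType::F32, &dev)?;
    println!("output shape: {:?}", linear.forward(&xs)?.dims()); // [4, 2]
    Ok(())
}
```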