diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-03 19:27:48 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-03 18:27:48 +0100 |
commit | 26cd266e6569b0640947d4cacb4d6b9c27c01623 (patch) | |
tree | cb7fa82b7bb5978d69506d00d10cd35b4211cd40 /candle-examples/examples/musicgen/main.rs | |
parent | bbec527bb966b5050a9f8a3fe1382ea929e39d41 (diff) | |
download | candle-26cd266e6569b0640947d4cacb4d6b9c27c01623.tar.gz candle-26cd266e6569b0640947d4cacb4d6b9c27c01623.tar.bz2 candle-26cd266e6569b0640947d4cacb4d6b9c27c01623.zip |
Musicgen text embeddings. (#726)
* Musicgen text embeddings.
* Bugfix for layer norm.
* Proper position bias.
* Expose the weights.
Diffstat (limited to 'candle-examples/examples/musicgen/main.rs')
-rw-r--r-- | candle-examples/examples/musicgen/main.rs | 27 |
1 files changed, 24 insertions, 3 deletions
diff --git a/candle-examples/examples/musicgen/main.rs b/candle-examples/examples/musicgen/main.rs index 8dcef6d2..3794c22d 100644 --- a/candle-examples/examples/musicgen/main.rs +++ b/candle-examples/examples/musicgen/main.rs @@ -18,7 +18,7 @@ mod t5_model; use musicgen_model::{GenConfig, MusicgenForConditionalGeneration}; use anyhow::{Error as E, Result}; -use candle::DType; +use candle::{DType, Tensor}; use candle_nn::VarBuilder; use clap::Parser; use hf_hub::{api::sync::Api, Repo, RepoType}; @@ -39,6 +39,12 @@ struct Args { /// The tokenizer config. #[arg(long)] tokenizer: Option<String>, + + #[arg( + long, + default_value = "90s rock song with loud guitars and heavy drums" + )] + prompt: String, } fn main() -> Result<()> { @@ -53,7 +59,10 @@ fn main() -> Result<()> { .get("tokenizer.json")?, }; let mut tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?; - let _tokenizer = tokenizer.with_padding(None).with_truncation(None); + let tokenizer = tokenizer + .with_padding(None) + .with_truncation(None) + .map_err(E::msg)?; let model = match args.model { Some(model) => std::path::PathBuf::from(model), @@ -69,6 +78,18 @@ fn main() -> Result<()> { let model = model.deserialize()?; let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device); let config = GenConfig::small(); - let _model = MusicgenForConditionalGeneration::load(vb, config)?; + let model = MusicgenForConditionalGeneration::load(vb, config)?; + + let tokens = tokenizer + .encode(args.prompt.as_str(), true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + println!("tokens: {tokens:?}"); + let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?; + println!("{tokens:?}"); + let embeds = model.text_encoder.forward(&tokens)?; + println!("{embeds}"); + Ok(()) } |