diff options
Diffstat (limited to 'candle-examples/examples/musicgen/main.rs')
-rw-r--r-- | candle-examples/examples/musicgen/main.rs | 27 |
1 files changed, 24 insertions, 3 deletions
diff --git a/candle-examples/examples/musicgen/main.rs b/candle-examples/examples/musicgen/main.rs index 8dcef6d2..3794c22d 100644 --- a/candle-examples/examples/musicgen/main.rs +++ b/candle-examples/examples/musicgen/main.rs @@ -18,7 +18,7 @@ mod t5_model; use musicgen_model::{GenConfig, MusicgenForConditionalGeneration}; use anyhow::{Error as E, Result}; -use candle::DType; +use candle::{DType, Tensor}; use candle_nn::VarBuilder; use clap::Parser; use hf_hub::{api::sync::Api, Repo, RepoType}; @@ -39,6 +39,12 @@ struct Args { /// The tokenizer config. #[arg(long)] tokenizer: Option<String>, + + #[arg( + long, + default_value = "90s rock song with loud guitars and heavy drums" + )] + prompt: String, } fn main() -> Result<()> { @@ -53,7 +59,10 @@ fn main() -> Result<()> { .get("tokenizer.json")?, }; let mut tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?; - let _tokenizer = tokenizer.with_padding(None).with_truncation(None); + let tokenizer = tokenizer + .with_padding(None) + .with_truncation(None) + .map_err(E::msg)?; let model = match args.model { Some(model) => std::path::PathBuf::from(model), @@ -69,6 +78,18 @@ fn main() -> Result<()> { let model = model.deserialize()?; let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device); let config = GenConfig::small(); - let _model = MusicgenForConditionalGeneration::load(vb, config)?; + let model = MusicgenForConditionalGeneration::load(vb, config)?; + + let tokens = tokenizer + .encode(args.prompt.as_str(), true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + println!("tokens: {tokens:?}"); + let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?; + println!("{tokens:?}"); + let embeds = model.text_encoder.forward(&tokens)?; + println!("{embeds}"); + Ok(()) } |