summaryrefslogtreecommitdiff
path: root/candle-examples/examples/musicgen/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/musicgen/main.rs')
-rw-r--r--candle-examples/examples/musicgen/main.rs27
1 files changed, 24 insertions, 3 deletions
diff --git a/candle-examples/examples/musicgen/main.rs b/candle-examples/examples/musicgen/main.rs
index 8dcef6d2..3794c22d 100644
--- a/candle-examples/examples/musicgen/main.rs
+++ b/candle-examples/examples/musicgen/main.rs
@@ -18,7 +18,7 @@ mod t5_model;
use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
use anyhow::{Error as E, Result};
-use candle::DType;
+use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
@@ -39,6 +39,12 @@ struct Args {
/// The tokenizer config.
#[arg(long)]
tokenizer: Option<String>,
+
+ #[arg(
+ long,
+ default_value = "90s rock song with loud guitars and heavy drums"
+ )]
+ prompt: String,
}
fn main() -> Result<()> {
@@ -53,7 +59,10 @@ fn main() -> Result<()> {
.get("tokenizer.json")?,
};
let mut tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
- let _tokenizer = tokenizer.with_padding(None).with_truncation(None);
+ let tokenizer = tokenizer
+ .with_padding(None)
+ .with_truncation(None)
+ .map_err(E::msg)?;
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
@@ -69,6 +78,18 @@ fn main() -> Result<()> {
let model = model.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
let config = GenConfig::small();
- let _model = MusicgenForConditionalGeneration::load(vb, config)?;
+ let model = MusicgenForConditionalGeneration::load(vb, config)?;
+
+ let tokens = tokenizer
+ .encode(args.prompt.as_str(), true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
+ println!("tokens: {tokens:?}");
+ let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
+ println!("{tokens:?}");
+ let embeds = model.text_encoder.forward(&tokens)?;
+ println!("{embeds}");
+
Ok(())
}