summaryrefslogtreecommitdiff
path: root/candle-examples/examples/musicgen/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/musicgen/main.rs')
-rw-r--r--candle-examples/examples/musicgen/main.rs27
1 files changed, 22 insertions, 5 deletions
diff --git a/candle-examples/examples/musicgen/main.rs b/candle-examples/examples/musicgen/main.rs
index 7598280e..8dcef6d2 100644
--- a/candle-examples/examples/musicgen/main.rs
+++ b/candle-examples/examples/musicgen/main.rs
@@ -16,11 +16,12 @@ mod nn;
mod t5_model;
use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
-use nn::VarBuilder;
use anyhow::{Error as E, Result};
use candle::DType;
+use candle_nn::VarBuilder;
use clap::Parser;
+use hf_hub::{api::sync::Api, Repo, RepoType};
const DTYPE: DType = DType::F32;
@@ -33,11 +34,11 @@ struct Args {
/// The model weight file, in safetensor format.
#[arg(long)]
- model: String,
+ model: Option<String>,
/// The tokenizer config.
#[arg(long)]
- tokenizer: String,
+ tokenizer: Option<String>,
}
fn main() -> Result<()> {
@@ -45,10 +46,26 @@ fn main() -> Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
- let mut tokenizer = Tokenizer::from_file(args.tokenizer).map_err(E::msg)?;
+ let tokenizer = match args.tokenizer {
+ Some(tokenizer) => std::path::PathBuf::from(tokenizer),
+ None => Api::new()?
+ .model("facebook/musicgen-small".to_string())
+ .get("tokenizer.json")?,
+ };
+ let mut tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let _tokenizer = tokenizer.with_padding(None).with_truncation(None);
- let model = unsafe { candle::safetensors::MmapedFile::new(args.model)? };
+ let model = match args.model {
+ Some(model) => std::path::PathBuf::from(model),
+ None => Api::new()?
+ .repo(Repo::with_revision(
+ "facebook/musicgen-small".to_string(),
+ RepoType::Model,
+ "refs/pr/13".to_string(),
+ ))
+ .get("model.safetensors")?,
+ };
+ let model = unsafe { candle::safetensors::MmapedFile::new(model)? };
let model = model.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
let config = GenConfig::small();