diff options
Diffstat (limited to 'candle-examples/examples/musicgen/main.rs')
-rw-r--r-- | candle-examples/examples/musicgen/main.rs | 27 |
1 files changed, 22 insertions, 5 deletions
diff --git a/candle-examples/examples/musicgen/main.rs b/candle-examples/examples/musicgen/main.rs index 7598280e..8dcef6d2 100644 --- a/candle-examples/examples/musicgen/main.rs +++ b/candle-examples/examples/musicgen/main.rs @@ -16,11 +16,12 @@ mod nn; mod t5_model; use musicgen_model::{GenConfig, MusicgenForConditionalGeneration}; -use nn::VarBuilder; use anyhow::{Error as E, Result}; use candle::DType; +use candle_nn::VarBuilder; use clap::Parser; +use hf_hub::{api::sync::Api, Repo, RepoType}; const DTYPE: DType = DType::F32; @@ -33,11 +34,11 @@ struct Args { /// The model weight file, in safetensor format. #[arg(long)] - model: String, + model: Option<String>, /// The tokenizer config. #[arg(long)] - tokenizer: String, + tokenizer: Option<String>, } fn main() -> Result<()> { @@ -45,10 +46,26 @@ fn main() -> Result<()> { let args = Args::parse(); let device = candle_examples::device(args.cpu)?; - let mut tokenizer = Tokenizer::from_file(args.tokenizer).map_err(E::msg)?; + let tokenizer = match args.tokenizer { + Some(tokenizer) => std::path::PathBuf::from(tokenizer), + None => Api::new()? + .model("facebook/musicgen-small".to_string()) + .get("tokenizer.json")?, + }; + let mut tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?; let _tokenizer = tokenizer.with_padding(None).with_truncation(None); - let model = unsafe { candle::safetensors::MmapedFile::new(args.model)? }; + let model = match args.model { + Some(model) => std::path::PathBuf::from(model), + None => Api::new()? + .repo(Repo::with_revision( + "facebook/musicgen-small".to_string(), + RepoType::Model, + "refs/pr/13".to_string(), + )) + .get("model.safetensors")?, + }; + let model = unsafe { candle::safetensors::MmapedFile::new(model)? }; let model = model.deserialize()?; let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device); let config = GenConfig::small(); |