summaryrefslogtreecommitdiff
path: root/candle-examples/examples/musicgen/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/musicgen/main.rs')
-rw-r--r--candle-examples/examples/musicgen/main.rs17
1 files changed, 2 insertions, 15 deletions
diff --git a/candle-examples/examples/musicgen/main.rs b/candle-examples/examples/musicgen/main.rs
index 90b464c3..3e136e90 100644
--- a/candle-examples/examples/musicgen/main.rs
+++ b/candle-examples/examples/musicgen/main.rs
@@ -16,7 +16,7 @@ use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
use nn::VarBuilder;
use anyhow::{Error as E, Result};
-use candle::{DType, Device};
+use candle::DType;
use clap::Parser;
const DTYPE: DType = DType::F32;
@@ -41,20 +41,7 @@ fn main() -> Result<()> {
use tokenizers::Tokenizer;
let args = Args::parse();
- #[cfg(feature = "cuda")]
- let default_device = Device::new_cuda(0)?;
-
- #[cfg(not(feature = "cuda"))]
- let default_device = {
- println!("Running on CPU, to run on GPU, run this example with `--features cuda`");
- Device::Cpu
- };
- let device = if args.cpu {
- Device::Cpu
- } else {
- default_device
- };
-
+ let device = candle_examples::device(args.cpu)?;
let mut tokenizer = Tokenizer::from_file(args.tokenizer).map_err(E::msg)?;
let _tokenizer = tokenizer.with_padding(None).with_truncation(None);