summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/device.rs8
-rw-r--r--candle-core/src/utils.rs7
-rw-r--r--candle-examples/examples/bert/main.rs15
-rw-r--r--candle-examples/examples/falcon/main.rs15
-rw-r--r--candle-examples/examples/llama/main.rs15
-rw-r--r--candle-examples/examples/musicgen/main.rs17
-rw-r--r--candle-examples/examples/whisper/main.rs16
-rw-r--r--candle-examples/src/lib.rs12
8 files changed, 33 insertions, 72 deletions
diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index ca408529..53e2de43 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -109,6 +109,14 @@ impl Device {
}
}
+ pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
+ if crate::utils::cuda_is_available() {
+ Self::new_cuda(ordinal)
+ } else {
+ Ok(Self::Cpu)
+ }
+ }
+
pub(crate) fn rand_uniform(
&self,
shape: &Shape,
diff --git a/candle-core/src/utils.rs b/candle-core/src/utils.rs
index b5621e56..895c97e1 100644
--- a/candle-core/src/utils.rs
+++ b/candle-core/src/utils.rs
@@ -17,3 +17,10 @@ pub fn has_mkl() -> bool {
#[cfg(not(feature = "mkl"))]
return false;
}
+
+pub fn cuda_is_available() -> bool {
+ #[cfg(feature = "cuda")]
+ return true;
+ #[cfg(not(feature = "cuda"))]
+ return false;
+}
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index aae8bc50..dca6721b 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -495,20 +495,7 @@ struct Args {
impl Args {
fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
- #[cfg(feature = "cuda")]
- let default_device = Device::new_cuda(0)?;
-
- #[cfg(not(feature = "cuda"))]
- let default_device = {
- println!("Running on CPU, to run on GPU, run this example with `--features cuda`");
- Device::Cpu
- };
-
- let device = if self.cpu {
- Device::Cpu
- } else {
- default_device
- };
+ let device = candle_examples::device(self.cpu)?;
let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
let default_revision = "refs/pr/21".to_string();
let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs
index 7e20c7d2..7d5eaa52 100644
--- a/candle-examples/examples/falcon/main.rs
+++ b/candle-examples/examples/falcon/main.rs
@@ -120,20 +120,7 @@ struct Args {
fn main() -> Result<()> {
let args = Args::parse();
- #[cfg(feature = "cuda")]
- let default_device = Device::new_cuda(0)?;
-
- #[cfg(not(feature = "cuda"))]
- let default_device = {
- println!("Running on CPU, to run on GPU, run this example with `--features cuda`");
- Device::Cpu
- };
- let device = if args.cpu {
- Device::Cpu
- } else {
- default_device
- };
-
+ let device = candle_examples::device(args.cpu)?;
let start = std::time::Instant::now();
let api = Api::new()?;
let repo = Repo::with_revision(args.model_id, RepoType::Model, args.revision);
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 203b4606..aa02299d 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -134,20 +134,7 @@ fn main() -> Result<()> {
let args = Args::parse();
- #[cfg(feature = "cuda")]
- let default_device = Device::new_cuda(0)?;
-
- #[cfg(not(feature = "cuda"))]
- let default_device = {
- println!("Running on CPU, to run on GPU, run this example with `--features cuda`");
- Device::Cpu
- };
-
- let device = if args.cpu {
- Device::Cpu
- } else {
- default_device
- };
+ let device = candle_examples::device(args.cpu)?;
let config = Config::config_7b();
let cache = model::Cache::new(!args.no_kv_cache, &config, &device);
let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
diff --git a/candle-examples/examples/musicgen/main.rs b/candle-examples/examples/musicgen/main.rs
index 90b464c3..3e136e90 100644
--- a/candle-examples/examples/musicgen/main.rs
+++ b/candle-examples/examples/musicgen/main.rs
@@ -16,7 +16,7 @@ use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
use nn::VarBuilder;
use anyhow::{Error as E, Result};
-use candle::{DType, Device};
+use candle::DType;
use clap::Parser;
const DTYPE: DType = DType::F32;
@@ -41,20 +41,7 @@ fn main() -> Result<()> {
use tokenizers::Tokenizer;
let args = Args::parse();
- #[cfg(feature = "cuda")]
- let default_device = Device::new_cuda(0)?;
-
- #[cfg(not(feature = "cuda"))]
- let default_device = {
- println!("Running on CPU, to run on GPU, run this example with `--features cuda`");
- Device::Cpu
- };
- let device = if args.cpu {
- Device::Cpu
- } else {
- default_device
- };
-
+ let device = candle_examples::device(args.cpu)?;
let mut tokenizer = Tokenizer::from_file(args.tokenizer).map_err(E::msg)?;
let _tokenizer = tokenizer.with_padding(None).with_truncation(None);
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 09ef4593..d01fb605 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -257,21 +257,7 @@ struct Args {
fn main() -> Result<()> {
let args = Args::parse();
-
- #[cfg(feature = "cuda")]
- let default_device = Device::new_cuda(0)?;
-
- #[cfg(not(feature = "cuda"))]
- let default_device = {
- println!("Running on CPU, to run on GPU, run this example with `--features cuda`");
- Device::Cpu
- };
-
- let device = if args.cpu {
- Device::Cpu
- } else {
- default_device
- };
+ let device = candle_examples::device(args.cpu)?;
let default_model = "openai/whisper-tiny.en".to_string();
let path = std::path::PathBuf::from(default_model.clone());
let default_revision = "refs/pr/15".to_string();
diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs
index 8b137891..285aee04 100644
--- a/candle-examples/src/lib.rs
+++ b/candle-examples/src/lib.rs
@@ -1 +1,13 @@
+use candle::{Device, Result};
+pub fn device(cpu: bool) -> Result<Device> {
+ if cpu {
+ Ok(Device::Cpu)
+ } else {
+ let device = Device::cuda_if_available(0)?;
+ if !device.is_cuda() {
+ println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
+ }
+ Ok(device)
+ }
+}