summaryrefslogtreecommitdiff
path: root/candle-examples/examples/bert/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/bert/main.rs')
-rw-r--r--candle-examples/examples/bert/main.rs15
1 files changed, 1 insertions, 14 deletions
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index aae8bc50..dca6721b 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -495,20 +495,7 @@ struct Args {
impl Args {
fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
- #[cfg(feature = "cuda")]
- let default_device = Device::new_cuda(0)?;
-
- #[cfg(not(feature = "cuda"))]
- let default_device = {
- println!("Running on CPU, to run on GPU, run this example with `--features cuda`");
- Device::Cpu
- };
-
- let device = if self.cpu {
- Device::Cpu
- } else {
- default_device
- };
+ let device = candle_examples::device(self.cpu)?;
let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
let default_revision = "refs/pr/21".to_string();
let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {