summaryrefslogtreecommitdiff
path: root/candle-examples/examples/bert
diff options
context:
space:
mode:
authorNicolas Patry <patry.nicolas@protonmail.com>2023-07-06 15:15:25 +0200
committerNicolas Patry <patry.nicolas@protonmail.com>2023-07-06 15:15:25 +0200
commit115629fe08e989277bca478339dc16cb1bf731e4 (patch)
treed42197ad042de0ce0e5bf5ec278c3f5baeeb7f61 /candle-examples/examples/bert
parentdd60bd84bb4c3b52698d971ca383a19064d0c7e0 (diff)
downloadcandle-115629fe08e989277bca478339dc16cb1bf731e4.tar.gz
candle-115629fe08e989277bca478339dc16cb1bf731e4.tar.bz2
candle-115629fe08e989277bca478339dc16cb1bf731e4.zip
Creating new sync Api for `candle-hub`.
- `api::Api` -> `api::tokio::api` (And created new `api::sync::Api`). - Remove `tokio` from all our examples. - Using similar codebase for now instead of ureq (for simplicity).
Diffstat (limited to 'candle-examples/examples/bert')
-rw-r--r--candle-examples/examples/bert/main.rs15
1 files changed, 7 insertions, 8 deletions
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index bf99b1bf..72cc65b3 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -5,7 +5,7 @@ extern crate intel_mkl_src;
use anyhow::{anyhow, Error as E, Result};
use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
-use candle_hub::{api::Api, Cache, Repo, RepoType};
+use candle_hub::{api::sync::Api, Cache, Repo, RepoType};
use clap::Parser;
use serde::Deserialize;
use std::collections::HashMap;
@@ -645,7 +645,7 @@ struct Args {
}
impl Args {
- async fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
+ fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
let device = if self.cpu {
Device::Cpu
} else {
@@ -677,9 +677,9 @@ impl Args {
} else {
let api = Api::new()?;
(
- api.get(&repo, "config.json").await?,
- api.get(&repo, "tokenizer.json").await?,
- api.get(&repo, "model.safetensors").await?,
+ api.get(&repo, "config.json")?,
+ api.get(&repo, "tokenizer.json")?,
+ api.get(&repo, "model.safetensors")?,
)
};
let config = std::fs::read_to_string(config_filename)?;
@@ -694,12 +694,11 @@ impl Args {
}
}
-#[tokio::main]
-async fn main() -> Result<()> {
+fn main() -> Result<()> {
let start = std::time::Instant::now();
let args = Args::parse();
- let (model, mut tokenizer) = args.build_model_and_tokenizer().await?;
+ let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
let device = &model.device;
if let Some(prompt) = args.prompt {