summary refs log tree commit diff
path: root/candle-examples/examples/ggml/main.rs
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2023-08-16 16:41:06 +0100
committer: GitHub <noreply@github.com> 2023-08-16 16:41:06 +0100
commit2e206e269da311cb0c3bde164e6c2ecb9286034e (patch)
tree8b3a5c1e542596fcfe337bee41e3c0716c1718bb /candle-examples/examples/ggml/main.rs
parent575e88a999c107f314263dca4c3d590a8a16c4a3 (diff)
download: candle-2e206e269da311cb0c3bde164e6c2ecb9286034e.tar.gz
candle-2e206e269da311cb0c3bde164e6c2ecb9286034e.tar.bz2
candle-2e206e269da311cb0c3bde164e6c2ecb9286034e.zip
Add the model argument. (#471)
Diffstat (limited to 'candle-examples/examples/ggml/main.rs')
-rw-r--r-- candle-examples/examples/ggml/main.rs | 16
1 file changed, 14 insertions, 2 deletions
diff --git a/candle-examples/examples/ggml/main.rs b/candle-examples/examples/ggml/main.rs
index 68e2267c..7d6ec2ca 100644
--- a/candle-examples/examples/ggml/main.rs
+++ b/candle-examples/examples/ggml/main.rs
@@ -248,7 +248,7 @@ impl ModelWeights {
struct Args {
/// GGML file to load, typically a .bin file generated by the quantize command from llama.cpp
#[arg(long)]
- model: String,
+ model: Option<String>,
/// The initial prompt.
#[arg(long)]
@@ -283,12 +283,24 @@ impl Args {
};
Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
}
+
+ fn model(&self) -> anyhow::Result<std::path::PathBuf> {
+ let model_path = match &self.model {
+ Some(config) => std::path::PathBuf::from(config),
+ None => {
+ let api = hf_hub::api::sync::Api::new()?;
+ let api = api.model("TheBloke/Llama-2-7B-GGML".to_string());
+ api.get("llama-2-7b.ggmlv3.q4_0.bin")?
+ }
+ };
+ Ok(model_path)
+ }
}
fn main() -> anyhow::Result<()> {
let args = Args::parse();
- let mut file = std::fs::File::open(&args.model)?;
+ let mut file = std::fs::File::open(&args.model()?)?;
let start = std::time::Instant::now();
let model = Content::read(&mut file)?;