summaryrefslogtreecommitdiff
path: root/candle-examples/examples/llama_multiprocess/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/llama_multiprocess/main.rs')
-rw-r--r--candle-examples/examples/llama_multiprocess/main.rs39
1 files changed, 10 insertions, 29 deletions
diff --git a/candle-examples/examples/llama_multiprocess/main.rs b/candle-examples/examples/llama_multiprocess/main.rs
index f9e87432..679e5faa 100644
--- a/candle-examples/examples/llama_multiprocess/main.rs
+++ b/candle-examples/examples/llama_multiprocess/main.rs
@@ -20,7 +20,7 @@ use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use cudarc::driver::safe::CudaDevice;
use cudarc::nccl::safe::{Comm, Id};
-use hf_hub::{api::sync::Api, Repo, RepoType};
+use hf_hub::api::sync::Api;
use std::io::Write;
use std::rc::Rc;
@@ -83,10 +83,6 @@ Upon my target three fair-shining suns.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
- /// Run on CPU rather than on GPU.
- #[arg(long)]
- cpu: bool,
-
#[arg(long)]
num_shards: usize,
@@ -113,15 +109,8 @@ struct Args {
#[arg(long)]
prompt: Option<String>,
- /// Use f32 computations rather than f16.
- #[arg(long)]
- use_f32: bool,
-
#[arg(long)]
model_id: Option<String>,
-
- #[arg(long)]
- v2: bool,
}
fn main() -> Result<()> {
@@ -130,26 +119,22 @@ fn main() -> Result<()> {
let args = Args::parse();
let config = Config::config_7b();
- let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
+ let dtype = DType::F16;
let api = Api::new()?;
- let model_id = args.model_id.unwrap_or_else(|| {
- if args.v2 {
- "meta-llama/Llama-2-7b-hf".to_string()
- } else {
- "Narsil/amall-7b".to_string()
- }
- });
+ let model_id = args
+ .model_id
+ .unwrap_or_else(|| "meta-llama/Llama-2-7b-hf".to_string());
println!("loading the model weights from {model_id}");
- let repo = Repo::new(model_id, RepoType::Model);
- let tokenizer_filename = api.get(&repo, "tokenizer.json")?;
+ let api = api.model(model_id);
+ let tokenizer_filename = api.get("tokenizer.json")?;
let mut filenames = vec![];
for rfilename in [
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
] {
- let filename = api.get(&repo, rfilename)?;
+ let filename = api.get(rfilename)?;
filenames.push(filename);
}
@@ -203,7 +188,7 @@ fn main() -> Result<()> {
println!("Rank {rank:?} spawned");
let device = Device::new_cuda(i)?;
- let cache = model::Cache::new(!args.no_kv_cache, &config, &device)?;
+ let cache = model::Cache::new(&config, &device)?;
println!("building the model");
let handles = filenames
@@ -233,11 +218,7 @@ fn main() -> Result<()> {
let mut index_pos = 0;
for index in 0..args.sample_len {
let start_gen = std::time::Instant::now();
- let context_size = if cache.use_kv_cache && index > 0 {
- 1
- } else {
- tokens.len()
- };
+ let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = llama.forward(&input, index_pos)?;