Diffstat (limited to 'candle-examples/examples/llama_multiprocess/main.rs')
-rw-r--r-- | candle-examples/examples/llama_multiprocess/main.rs | 39 | ++++++++++-----------------------------
1 file changed, 10 insertions(+), 29 deletions(-)
diff --git a/candle-examples/examples/llama_multiprocess/main.rs b/candle-examples/examples/llama_multiprocess/main.rs
index f9e87432..679e5faa 100644
--- a/candle-examples/examples/llama_multiprocess/main.rs
+++ b/candle-examples/examples/llama_multiprocess/main.rs
@@ -20,7 +20,7 @@
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
 use cudarc::driver::safe::CudaDevice;
 use cudarc::nccl::safe::{Comm, Id};
-use hf_hub::{api::sync::Api, Repo, RepoType};
+use hf_hub::api::sync::Api;
 use std::io::Write;
 use std::rc::Rc;
@@ -83,10 +83,6 @@ Upon my target three fair-shining suns.
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
-    /// Run on CPU rather than on GPU.
-    #[arg(long)]
-    cpu: bool,
-
     #[arg(long)]
     num_shards: usize,
 
@@ -113,15 +109,8 @@ struct Args {
     #[arg(long)]
     prompt: Option<String>,
 
-    /// Use f32 computations rather than f16.
-    #[arg(long)]
-    use_f32: bool,
-
     #[arg(long)]
     model_id: Option<String>,
-
-    #[arg(long)]
-    v2: bool,
 }
 
 fn main() -> Result<()> {
@@ -130,26 +119,22 @@ fn main() -> Result<()> {
     let args = Args::parse();
 
     let config = Config::config_7b();
-    let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
+    let dtype = DType::F16;
 
     let api = Api::new()?;
 
-    let model_id = args.model_id.unwrap_or_else(|| {
-        if args.v2 {
-            "meta-llama/Llama-2-7b-hf".to_string()
-        } else {
-            "Narsil/amall-7b".to_string()
-        }
-    });
+    let model_id = args
+        .model_id
+        .unwrap_or_else(|| "meta-llama/Llama-2-7b-hf".to_string());
     println!("loading the model weights from {model_id}");
-    let repo = Repo::new(model_id, RepoType::Model);
-    let tokenizer_filename = api.get(&repo, "tokenizer.json")?;
+    let api = api.model(model_id);
+    let tokenizer_filename = api.get("tokenizer.json")?;
 
     let mut filenames = vec![];
     for rfilename in [
         "model-00001-of-00002.safetensors",
         "model-00002-of-00002.safetensors",
     ] {
-        let filename = api.get(&repo, rfilename)?;
+        let filename = api.get(rfilename)?;
         filenames.push(filename);
     }
@@ -203,7 +188,7 @@ fn main() -> Result<()> {
     println!("Rank {rank:?} spawned");
 
     let device = Device::new_cuda(i)?;
-    let cache = model::Cache::new(!args.no_kv_cache, &config, &device)?;
+    let cache = model::Cache::new(&config, &device)?;
 
     println!("building the model");
     let handles = filenames
@@ -233,11 +218,7 @@ fn main() -> Result<()> {
     let mut index_pos = 0;
     for index in 0..args.sample_len {
         let start_gen = std::time::Instant::now();
-        let context_size = if cache.use_kv_cache && index > 0 {
-            1
-        } else {
-            tokens.len()
-        };
+        let context_size = if index > 0 { 1 } else { tokens.len() };
        let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
        let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
        let logits = llama.forward(&input, index_pos)?;
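For reference, a minimal sketch of the repo-scoped hf_hub download flow this change migrates to. The model id and file names come from the hunks above; the helper name and boxed error type are illustrative, and it is assumed that `Api::new()` picks up a cached Hugging Face token, since the Llama 2 repo is gated:

```rust
use hf_hub::api::sync::Api;
use std::path::PathBuf;

// Hypothetical helper; `fetch_llama_files` is not part of the patched example.
fn fetch_llama_files() -> Result<Vec<PathBuf>, Box<dyn std::error::Error>> {
    // `Api::model` scopes the client to one model repo, replacing the old
    // `Repo::new(model_id, RepoType::Model)` + `api.get(&repo, ...)` pair.
    let repo = Api::new()?.model("meta-llama/Llama-2-7b-hf".to_string());

    // `ApiRepo::get` downloads a file into the local cache (or reuses an
    // existing copy) and returns its path.
    let mut files = vec![repo.get("tokenizer.json")?];
    for shard in [
        "model-00001-of-00002.safetensors",
        "model-00002-of-00002.safetensors",
    ] {
        files.push(repo.get(shard)?);
    }
    Ok(files)
}
```

Scoping the handle with `api.model(...)` removes the need to thread a `Repo` value through every `get` call, which is what lets the import shrink to `hf_hub::api::sync::Api`.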
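The last hunk drops the `use_kv_cache` branch, committing the decode loop to an always-on KV cache: the full prompt is fed on the first step and a single token on every step after that. A self-contained sketch of that shape, with hypothetical token ids and helper:

```rust
// Mirrors the patched loop: with a KV cache, past keys/values live inside the
// model, so only the newest token is fed after the first forward pass.
fn context_size(index: usize, tokens: &[u32]) -> usize {
    if index > 0 { 1 } else { tokens.len() }
}

fn main() {
    let tokens = vec![101u32, 202, 303, 404]; // hypothetical prompt ids
    let mut index_pos = 0;
    for index in 0..3 {
        let size = context_size(index, &tokens);
        let ctxt = &tokens[tokens.len().saturating_sub(size)..];
        println!("step {index}: feed {} token(s) at position {index_pos}", ctxt.len());
        // Advance the absolute position passed to the model, as the example's
        // `llama.forward(&input, index_pos)` call expects.
        index_pos += ctxt.len();
    }
}
```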