Diffstat (limited to 'candle-examples/examples/llama/main.rs')
 candle-examples/examples/llama/main.rs | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index c2ed0e25..251c184b 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -33,6 +33,8 @@ enum Which {
     V2,
     #[value(name = "solar-10.7b")]
     Solar10_7B,
+    #[value(name = "tiny-llama-1.1b-chat")]
+    TinyLlama1_1BChat,
 }
 
 #[derive(Parser, Debug)]
@@ -124,6 +126,7 @@ fn main() -> Result<()> {
         Which::V1 => "Narsil/amall-7b".to_string(),
         Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
         Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
+        Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
     });
     println!("loading the model weights from {model_id}");
     let revision = args.revision.unwrap_or("main".to_string());
@@ -134,8 +137,12 @@ fn main() -> Result<()> {
     let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
     let config = config.into_config(args.use_flash_attn);
 
-    let filenames =
-        candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
+    let filenames = match args.which {
+        Which::V1 | Which::V2 | Which::Solar10_7B => {
+            candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
+        }
+        Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
+    };
 
     println!("building the model");
     let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
@@ -158,14 +165,14 @@ fn main() -> Result<()> {
     let mut index_pos = 0;
     let mut token_generated = 0;
     for index in 0..args.sample_len {
-        let context_size = if cache.use_kv_cache && index > 0 {
-            1
+        let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
+            (1, index_pos)
         } else {
-            tokens.len()
+            (tokens.len(), 0)
         };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
-        let logits = llama.forward(&input, index_pos)?;
+        let logits = llama.forward(&input, context_index)?;
         let logits = logits.squeeze(0)?;
         let logits = if args.repeat_penalty == 1. {
             logits
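
Note: the standalone snippet below is a minimal sketch of the context-selection logic introduced in the last hunk, assuming the same use_kv_cache/index_pos bookkeeping as the generation loop in the patch; the helper name context_for_step is illustrative and not part of the patch.

// Sketch of the (context_size, context_index) choice made on each decoding step:
// with the KV cache, only the newest token is fed and `context_index` is the
// absolute position to resume from; without it, the full token history is
// re-sent starting at position 0.
fn context_for_step(
    use_kv_cache: bool,
    step: usize,
    tokens: &[u32],
    index_pos: usize,
) -> (usize, usize) {
    if use_kv_cache && step > 0 {
        (1, index_pos)
    } else {
        (tokens.len(), 0)
    }
}

fn main() {
    let tokens = vec![1u32, 2, 3];
    // Step 0: nothing cached yet, so the whole prompt (3 tokens) goes in at position 0.
    assert_eq!(context_for_step(true, 0, &tokens, 0), (3, 0));
    // Later steps: a single new token, resumed at the running offset.
    assert_eq!(context_for_step(true, 5, &tokens, 7), (1, 7));
    // With the cache disabled, every step replays the full history.
    assert_eq!(context_for_step(false, 5, &tokens, 7), (3, 0));
}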