summaryrefslogtreecommitdiff
path: root/candle-examples/examples/ggml/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/ggml/main.rs')
-rw-r--r--candle-examples/examples/ggml/main.rs66
1 files changed, 56 insertions, 10 deletions
diff --git a/candle-examples/examples/ggml/main.rs b/candle-examples/examples/ggml/main.rs
index 912bc53a..7d6ec2ca 100644
--- a/candle-examples/examples/ggml/main.rs
+++ b/candle-examples/examples/ggml/main.rs
@@ -2,6 +2,7 @@
use clap::Parser;
use std::collections::HashMap;
use std::io::Write;
+use tokenizers::Tokenizer;
use candle::quantized::ggml_file::Content;
use candle::quantized::{QMatMul, QTensor};
@@ -52,6 +53,7 @@ struct LayerWeights {
head_dim: usize,
cos: Tensor,
sin: Tensor,
+ kv_cache: Option<(Tensor, Tensor)>,
}
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
@@ -75,7 +77,7 @@ impl LayerWeights {
Ok(rope)
}
- fn forward_attn(&self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> {
+ fn forward_attn(&mut self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> {
let (b_sz, seq_len, n_embd) = x.dims3()?;
let q = self.attention_wq.forward(x)?;
let k = self.attention_wk.forward(x)?;
@@ -94,7 +96,15 @@ impl LayerWeights {
let q = self.apply_rotary_emb(&q, index_pos)?;
let k = self.apply_rotary_emb(&k, index_pos)?;
- // TODO: KV cache.
+ let (k, v) = match &self.kv_cache {
+ None => (k, v),
+ Some((k_cache, v_cache)) => {
+ let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?;
+ let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?;
+ (k, v)
+ }
+ };
+ self.kv_cache = Some((k.clone(), v.clone()));
// If we start supporting MQA, we need to repeat the k and v tensors here.
@@ -181,6 +191,7 @@ impl ModelWeights {
head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
cos: cos.clone(),
sin: sin.clone(),
+ kv_cache: None,
})
}
Ok(Self {
@@ -209,7 +220,7 @@ impl ModelWeights {
let (_b_sz, seq_len) = x.dims2()?;
let mask = self.mask(seq_len)?;
let mut layer_in = self.tok_embeddings.forward(x)?;
- for (_layer_idx, layer) in self.layers.iter().enumerate() {
+ for layer in self.layers.iter_mut() {
let x = layer_in;
let residual = &x;
let x = layer.attention_norm.forward(&x)?;
@@ -237,7 +248,7 @@ impl ModelWeights {
struct Args {
/// GGML file to load, typically a .bin file generated by the quantize command from llama.cpp
#[arg(long)]
- model: String,
+ model: Option<String>,
/// The initial prompt.
#[arg(long)]
@@ -249,7 +260,7 @@ struct Args {
/// The tokenizer config in json format.
#[arg(long)]
- tokenizer: String,
+ tokenizer: Option<String>,
/// The temperature used to generate samples.
#[arg(long)]
@@ -260,11 +271,36 @@ struct Args {
seed: u64,
}
+impl Args {
+ fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
+ let tokenizer_path = match &self.tokenizer {
+ Some(config) => std::path::PathBuf::from(config),
+ None => {
+ let api = hf_hub::api::sync::Api::new()?;
+ let api = api.model("hf-internal-testing/llama-tokenizer".to_string());
+ api.get("tokenizer.json")?
+ }
+ };
+ Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
+ }
+
+ fn model(&self) -> anyhow::Result<std::path::PathBuf> {
+ let model_path = match &self.model {
+ Some(config) => std::path::PathBuf::from(config),
+ None => {
+ let api = hf_hub::api::sync::Api::new()?;
+ let api = api.model("TheBloke/Llama-2-7B-GGML".to_string());
+ api.get("llama-2-7b.ggmlv3.q4_0.bin")?
+ }
+ };
+ Ok(model_path)
+ }
+}
+
fn main() -> anyhow::Result<()> {
- use tokenizers::Tokenizer;
let args = Args::parse();
- let mut file = std::fs::File::open(args.model)?;
+ let mut file = std::fs::File::open(&args.model()?)?;
let start = std::time::Instant::now();
let model = Content::read(&mut file)?;
@@ -293,7 +329,7 @@ fn main() -> anyhow::Result<()> {
let mut model = ModelWeights::new(model)?;
println!("model built");
- let tokenizer = Tokenizer::from_file(args.tokenizer).map_err(anyhow::Error::msg)?;
+ let tokenizer = args.tokenizer()?;
let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
let mut tokens = tokenizer
.encode(prompt, true)
@@ -302,8 +338,11 @@ fn main() -> anyhow::Result<()> {
.to_vec();
let mut index_pos = 0;
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
- for _index in 0..args.sample_len {
- let context_size = tokens.len();
+ let start_gen = std::time::Instant::now();
+ let mut token_generated = 0;
+ print!("{prompt}");
+ for index in 0..args.sample_len {
+ let context_size = if index == 0 { tokens.len() } else { 1 };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &Device::Cpu)?.unsqueeze(0)?;
let logits = model.forward(&input, index_pos)?;
@@ -311,6 +350,7 @@ fn main() -> anyhow::Result<()> {
index_pos += ctxt.len();
let next_token = logits_processor.sample(&logits)?;
+ token_generated += 1;
tokens.push(next_token);
// Extracting the last token as a string is complicated, here we just apply some simple
@@ -323,5 +363,11 @@ fn main() -> anyhow::Result<()> {
std::io::stdout().flush()?;
}
}
+ let dt = start_gen.elapsed();
+ println!(
+ "\n\n{} tokens generated ({} token/s)\n",
+ token_generated,
+ token_generated as f64 / dt.as_secs_f64(),
+ );
Ok(())
}