Diffstat (limited to 'candle-examples/examples/ggml/main.rs')
-rw-r--r-- | candle-examples/examples/ggml/main.rs | 66
1 file changed, 56 insertions, 10 deletions
diff --git a/candle-examples/examples/ggml/main.rs b/candle-examples/examples/ggml/main.rs
index 912bc53a..7d6ec2ca 100644
--- a/candle-examples/examples/ggml/main.rs
+++ b/candle-examples/examples/ggml/main.rs
@@ -2,6 +2,7 @@
 use clap::Parser;
 use std::collections::HashMap;
 use std::io::Write;
+use tokenizers::Tokenizer;
 
 use candle::quantized::ggml_file::Content;
 use candle::quantized::{QMatMul, QTensor};
@@ -52,6 +53,7 @@ struct LayerWeights {
     head_dim: usize,
     cos: Tensor,
     sin: Tensor,
+    kv_cache: Option<(Tensor, Tensor)>,
 }
 
 fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
@@ -75,7 +77,7 @@ impl LayerWeights {
         Ok(rope)
     }
 
-    fn forward_attn(&self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> {
+    fn forward_attn(&mut self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> {
         let (b_sz, seq_len, n_embd) = x.dims3()?;
         let q = self.attention_wq.forward(x)?;
         let k = self.attention_wk.forward(x)?;
@@ -94,7 +96,15 @@ impl LayerWeights {
         let q = self.apply_rotary_emb(&q, index_pos)?;
         let k = self.apply_rotary_emb(&k, index_pos)?;
 
-        // TODO: KV cache.
+        let (k, v) = match &self.kv_cache {
+            None => (k, v),
+            Some((k_cache, v_cache)) => {
+                let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?;
+                let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?;
+                (k, v)
+            }
+        };
+        self.kv_cache = Some((k.clone(), v.clone()));
 
         // If we start supporting MQA, we need to repeat the k and v tensors here.
 
@@ -181,6 +191,7 @@ impl ModelWeights {
                 head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
                 cos: cos.clone(),
                 sin: sin.clone(),
+                kv_cache: None,
             })
         }
         Ok(Self {
@@ -209,7 +220,7 @@ impl ModelWeights {
         let (_b_sz, seq_len) = x.dims2()?;
         let mask = self.mask(seq_len)?;
         let mut layer_in = self.tok_embeddings.forward(x)?;
-        for (_layer_idx, layer) in self.layers.iter().enumerate() {
+        for layer in self.layers.iter_mut() {
            let x = layer_in;
            let residual = &x;
            let x = layer.attention_norm.forward(&x)?;
@@ -237,7 +248,7 @@
 struct Args {
     /// GGML file to load, typically a .bin file generated by the quantize command from llama.cpp
     #[arg(long)]
-    model: String,
+    model: Option<String>,
 
     /// The initial prompt.
     #[arg(long)]
@@ -249,7 +260,7 @@ struct Args {
     /// The tokenizer config in json format.
     #[arg(long)]
-    tokenizer: String,
+    tokenizer: Option<String>,
 
     /// The temperature used to generate samples.
     #[arg(long)]
@@ -260,11 +271,36 @@ struct Args {
     seed: u64,
 }
 
+impl Args {
+    fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
+        let tokenizer_path = match &self.tokenizer {
+            Some(config) => std::path::PathBuf::from(config),
+            None => {
+                let api = hf_hub::api::sync::Api::new()?;
+                let api = api.model("hf-internal-testing/llama-tokenizer".to_string());
+                api.get("tokenizer.json")?
+            }
+        };
+        Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
+    }
+
+    fn model(&self) -> anyhow::Result<std::path::PathBuf> {
+        let model_path = match &self.model {
+            Some(config) => std::path::PathBuf::from(config),
+            None => {
+                let api = hf_hub::api::sync::Api::new()?;
+                let api = api.model("TheBloke/Llama-2-7B-GGML".to_string());
+                api.get("llama-2-7b.ggmlv3.q4_0.bin")?
+            }
+        };
+        Ok(model_path)
+    }
+}
+
 fn main() -> anyhow::Result<()> {
-    use tokenizers::Tokenizer;
     let args = Args::parse();
 
-    let mut file = std::fs::File::open(args.model)?;
+    let mut file = std::fs::File::open(&args.model()?)?;
     let start = std::time::Instant::now();
     let model = Content::read(&mut file)?;
@@ -293,7 +329,7 @@ fn main() -> anyhow::Result<()> {
     let mut model = ModelWeights::new(model)?;
     println!("model built");
 
-    let tokenizer = Tokenizer::from_file(args.tokenizer).map_err(anyhow::Error::msg)?;
+    let tokenizer = args.tokenizer()?;
     let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
     let mut tokens = tokenizer
@@ -302,8 +338,11 @@ fn main() -> anyhow::Result<()> {
         .to_vec();
     let mut index_pos = 0;
     let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
-    for _index in 0..args.sample_len {
-        let context_size = tokens.len();
+    let start_gen = std::time::Instant::now();
+    let mut token_generated = 0;
+    print!("{prompt}");
+    for index in 0..args.sample_len {
+        let context_size = if index == 0 { tokens.len() } else { 1 };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &Device::Cpu)?.unsqueeze(0)?;
         let logits = model.forward(&input, index_pos)?;
@@ -311,6 +350,7 @@ fn main() -> anyhow::Result<()> {
         index_pos += ctxt.len();
 
         let next_token = logits_processor.sample(&logits)?;
+        token_generated += 1;
         tokens.push(next_token);
 
         // Extracting the last token as a string is complicated, here we just apply some simple
@@ -323,5 +363,11 @@ fn main() -> anyhow::Result<()> {
             std::io::stdout().flush()?;
         }
     }
+    let dt = start_gen.elapsed();
+    println!(
+        "\n\n{} tokens generated ({} token/s)\n",
+        token_generated,
+        token_generated as f64 / dt.as_secs_f64(),
+    );
     Ok(())
 }