Diffstat (limited to 'candle-core/examples/llama/main.rs')
-rw-r--r-- | candle-core/examples/llama/main.rs | 108
1 file changed, 72 insertions, 36 deletions
diff --git a/candle-core/examples/llama/main.rs b/candle-core/examples/llama/main.rs
index e936d6b3..9d70921c 100644
--- a/candle-core/examples/llama/main.rs
+++ b/candle-core/examples/llama/main.rs
@@ -24,6 +24,7 @@ mod var_store;
 mod weights;
 
 const CONTEXT_SIZE: usize = 512;
+const USE_KV_CACHE: bool = true;
 const START_PROMPT: &str = r"
 EDWARD:
 I wonder how our princely father 'scaped,
@@ -218,13 +219,16 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
 #[derive(Clone)]
 struct Cache {
     masks: Arc<Mutex<HashMap<usize, Tensor>>>,
+    #[allow(clippy::type_complexity)]
+    kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
     device: Device,
 }
 
 impl Cache {
-    fn new(device: &Device) -> Self {
+    fn new(config: &Config, device: &Device) -> Self {
         Self {
             masks: Arc::new(Mutex::new(HashMap::new())),
+            kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
             device: device.clone(),
         }
     }
@@ -249,7 +253,6 @@ struct CausalSelfAttention {
     c_attn: Linear,
     c_proj: Linear,
     n_head: usize,
-    // n_embd: usize,
     cache: Cache,
 }
 
@@ -265,6 +268,7 @@ impl CausalSelfAttention {
 
     fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
         let mut dims = x.dims().to_vec();
+        let freqs_cis = freqs_cis.narrow(1, freqs_cis.dims()[1] - dims[1], dims[1])?;
         let v = dims.pop().unwrap();
         dims.push(v / 2);
         dims.push(2);
@@ -285,7 +289,7 @@
         Ok(rope)
     }
 
-    fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
+    fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result<Tensor> {
         let (t, c) = x.shape().r2()?;
         let qkv = self.c_attn.forward(x)?;
         let qkv = qkv.to_dtype(DType::F32)?;
@@ -296,9 +300,31 @@
         let target_dim = [t, self.n_head, c / self.n_head];
         let k = k.reshape(target_dim.as_slice())?.transpose(0, 1)?;
         let q = q.reshape(target_dim.as_slice())?.transpose(0, 1)?;
-        let v = v.reshape(target_dim.as_slice())?.transpose(0, 1)?;
+        let mut v = v.reshape(target_dim.as_slice())?.transpose(0, 1)?;
         let q = self.apply_rotary_emb(&q, freqs_cis)?;
-        let k = self.apply_rotary_emb(&k, freqs_cis)?;
+        let mut k = self.apply_rotary_emb(&k, freqs_cis)?;
+
+        if USE_KV_CACHE {
+            let mut cache = self.cache.kvs.lock().unwrap();
+            if let Some((cache_k, cache_v)) = &cache[block_idx] {
+                k = Tensor::cat(&[cache_k, &k], 1)?;
+                v = Tensor::cat(&[cache_v, &v], 1)?;
+                let k_seq_len = k.dims()[1];
+                if k_seq_len > CONTEXT_SIZE {
+                    k = k
+                        .narrow(1, k_seq_len - CONTEXT_SIZE, CONTEXT_SIZE)?
+                        .contiguous()?
+                }
+                let v_seq_len = v.dims()[1];
+                if v_seq_len > CONTEXT_SIZE {
+                    v = v
+                        .narrow(1, v_seq_len - CONTEXT_SIZE, CONTEXT_SIZE)?
+                        .contiguous()?
+                }
+            }
+            cache[block_idx] = Some((k.clone(), v.clone()))
+        }
+
         let k_shape = k.shape();
         let att = (q.matmul(&k.t()?)? / (*k_shape.dims().last().unwrap() as f64).sqrt())?;
         let mask = self.cache.mask(t)?.broadcast_as(att.shape())?;
@@ -330,8 +356,11 @@ impl Block {
         }
     }
 
-    fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
-        let x = (self.attn.forward(&self.rms_1.forward(x)?, freqs_cis)? + x)?;
+    fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result<Tensor> {
+        let x = (self
+            .attn
+            .forward(&self.rms_1.forward(x)?, freqs_cis, block_idx)?
+            + x)?;
         let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + x)?;
         Ok(x)
     }
@@ -358,8 +387,8 @@ impl Llama {
         // TODO: Support for mini-batches? (i.e. r2)
         let t = x.shape().r1()?;
         let mut x = self.wte.forward(x)?;
-        for block in self.blocks.iter() {
-            x = block.forward(&x, freqs_cis)?;
+        for (block_idx, block) in self.blocks.iter().enumerate() {
+            x = block.forward(&x, freqs_cis, block_idx)?;
        }
         let x = self.ln_f.forward(&x)?;
         let x = x.narrow(0, t - 1, 1)?;
@@ -400,7 +429,7 @@ struct Args {
 
     /// Use npy instead of safetensors
     #[arg(long)]
-    npy: bool,
+    npy: Option<String>,
 
     /// The temperature used to generate samples.
     #[arg(long)]
@@ -426,33 +455,35 @@ async fn main() -> Result<()> {
         Device::new_cuda(0)?
     };
     let config = Config::config_7b();
-    let cache = Cache::new(&device);
+    let cache = Cache::new(&config, &device);
     let start = std::time::Instant::now();
-    let (llama, tokenizer_filename) = if args.npy {
-        println!("building the model (NPY)");
-        (
-            Llama::load_npy(&device, "/data/llama.npz", &cache, &config)?,
-            std::path::Path::new("llama-tokenizer.json").to_path_buf(),
-        )
-    } else {
-        let api = Api::new()?;
-        let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
-        println!("building the model");
-        let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
-        let mut filenames = vec![];
-        for rfilename in [
-            "model-00001-of-00002.safetensors",
-            "model-00002-of-00002.safetensors",
-        ] {
-            let filename = api.get(&repo, rfilename).await?;
-            filenames.push(filename);
+    let (llama, tokenizer_filename) = match args.npy {
+        Some(npy) => {
+            println!("building the model (NPY)");
+            let weights = Llama::load_npy(&device, &npy, &cache, &config)?;
+            let token_path = std::path::Path::new("llama-tokenizer.json").to_path_buf();
+            (weights, token_path)
+        }
+        None => {
+            let api = Api::new()?;
+            let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
+            println!("building the model");
+            let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
+            let mut filenames = vec![];
+            for rfilename in [
+                "model-00001-of-00002.safetensors",
+                "model-00002-of-00002.safetensors",
+            ] {
+                let filename = api.get(&repo, rfilename).await?;
+                filenames.push(filename);
+            }
+
+            println!("building the model (SF)");
+            (
+                Llama::load(&device, &filenames, &cache, &config)?,
+                tokenizer_filename,
+            )
         }
-
-        println!("building the model (SF)");
-        (
-            Llama::load(&device, &filenames, &cache, &config)?,
-            tokenizer_filename,
-        )
     };
     println!("Loaded in {:?}", start.elapsed());
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
@@ -470,7 +501,12 @@ async fn main() -> Result<()> {
     let start_gen = std::time::Instant::now();
    for index in 0..args.sample_len {
         let start_gen = std::time::Instant::now();
-        let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
+        let context_size = if USE_KV_CACHE && index > 0 {
+            1
+        } else {
+            CONTEXT_SIZE
+        };
+        let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &device)?;
         let logits = llama.forward(&input, &freqs_cis)?;
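
The core of this change is a per-layer key/value cache: each block gets a slot in Cache::kvs, and on every forward pass the freshly computed keys/values are concatenated onto the cached ones along the sequence dimension, then clipped to the last CONTEXT_SIZE positions (the apply_rotary_emb change narrows freqs_cis to the trailing dims[1] entries so its length matches the possibly single-token input). Below is a minimal, dependency-free sketch of that append-then-truncate pattern; LayerKvCache, ToyTensor, append, and MAX_SEQ are illustrative names, not part of the candle API, and a Vec of rows stands in for a real tensor.

// Toy model of the per-layer KV cache above; not the candle implementation.
const MAX_SEQ: usize = 512; // plays the role of CONTEXT_SIZE

type ToyTensor = Vec<Vec<f32>>; // [seq_len][dim]

#[derive(Default)]
struct LayerKvCache {
    // One (keys, values) slot, like one entry of Cache::kvs.
    kv: Option<(ToyTensor, ToyTensor)>,
}

impl LayerKvCache {
    // Append this step's keys/values, then keep only the last MAX_SEQ
    // positions, mirroring the Tensor::cat + narrow calls in the diff.
    fn append(&mut self, k_new: ToyTensor, v_new: ToyTensor) -> (ToyTensor, ToyTensor) {
        let (mut k, mut v) = self.kv.take().unwrap_or_default();
        k.extend(k_new); // like Tensor::cat(&[cache_k, &k], 1)
        v.extend(v_new);
        if k.len() > MAX_SEQ {
            k.drain(..k.len() - MAX_SEQ); // like narrow(1, len - CONTEXT_SIZE, CONTEXT_SIZE)
        }
        if v.len() > MAX_SEQ {
            v.drain(..v.len() - MAX_SEQ);
        }
        self.kv = Some((k.clone(), v.clone())); // like cache[block_idx] = Some((k, v))
        (k, v)
    }
}

fn main() {
    let mut cache = LayerKvCache::default();
    // First step: the whole prompt; later steps: one new position each.
    let (k, _v) = cache.append(vec![vec![0.0; 4]; 3], vec![vec![0.0; 4]; 3]);
    assert_eq!(k.len(), 3);
    let (k, _v) = cache.append(vec![vec![1.0; 4]], vec![vec![1.0; 4]]);
    assert_eq!(k.len(), 4); // 3 prompt positions + 1 generated position
}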
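The final hunk is where the cache pays off: once the per-layer caches are warm, each generation step feeds the model only the newest token (context_size drops from CONTEXT_SIZE to 1 after the first iteration), since the attention keys/values for all earlier positions are already stored. A sketch of that loop shape, with a hypothetical forward_pass standing in for llama.forward:

// Decode-loop shape matching the last hunk; forward_pass is a placeholder,
// not a candle call.
fn forward_pass(ctxt: &[u32]) -> u32 {
    // ... run the model on ctxt and sample the next token ...
    ctxt.last().copied().unwrap_or(0) + 1 // placeholder "sampling"
}

fn main() {
    const CONTEXT_SIZE: usize = 512;
    const USE_KV_CACHE: bool = true;
    let mut tokens: Vec<u32> = vec![1, 2, 3]; // prompt tokens
    for index in 0..5 {
        // Full prompt on step 0, a single new token afterwards.
        let context_size = if USE_KV_CACHE && index > 0 { 1 } else { CONTEXT_SIZE };
        let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
        let next = forward_pass(ctxt);
        tokens.push(next);
    }
    println!("{tokens:?}");
}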