author     Laurent Mazare <laurent.mazare@gmail.com>  2024-05-18 15:58:18 +0200
committer  GitHub <noreply@github.com>                2024-05-18 15:58:18 +0200
commit     01545f73038cb8c90426214ddf4bcedd59e291e8 (patch)
tree       2f2af0905b404bd42ee22962f94ea812e3caaa9e /candle-transformers
parent     349c3e806a15399df8289c41b2e24c3fa24b6d84 (diff)
Add a slice_set op. (#2193)
* Add a slice_set op.
* Add some testing.
* Add the dedicated kv-cache module.
* Derive debug and clone.
* Expose more kv-cache functions.
* Return the current data when appending.
* Use the new cache in the quantized phi3 model.
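Taken together, the bullets describe a cache with a constructor that pre-allocates its buffers and an append that returns the data accumulated so far. A minimal sketch of that usage, mirroring the KvCache::new and append calls from the diff below (the `demo` function, shape numbers, and assert expectations are illustrative assumptions; the constructor and append signatures are taken from this commit's diff and may have evolved since):

    use candle::{DType, Device, Result, Tensor};
    use candle_nn::kv_cache::KvCache;

    fn demo(device: &Device) -> Result<()> {
        let (batch, kv_heads, max_seq_len, head_dim) = (1, 32, 4096, 96);
        // Pre-allocate the cache buffers; dim 2 is the sequence axis that
        // `append` grows into, matching the call added in quantized_phi3.
        let mut cache = KvCache::new(
            2,
            (batch, kv_heads, max_seq_len, head_dim),
            DType::F32,
            device,
        )?;

        // Prefill: append the whole prompt's keys/values in one call.
        let k = Tensor::zeros((batch, kv_heads, 7, head_dim), DType::F32, device)?;
        let v = k.clone();
        let (k_all, _v_all) = cache.append(&k, &v)?;
        // Per the commit message, append returns the current (accumulated) data.
        assert_eq!(k_all.dims()[2], 7);

        // Decode: append one token per step; no Tensor::cat re-allocation.
        let k1 = Tensor::zeros((batch, kv_heads, 1, head_dim), DType::F32, device)?;
        let (k_all, _v_all) = cache.append(&k1, &k1)?;
        assert_eq!(k_all.dims()[2], 8);
        Ok(())
    }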
Diffstat (limited to 'candle-transformers')
-rw-r--r--  candle-transformers/src/models/quantized_phi3.rs | 41
1 file changed, 19 insertions(+), 22 deletions(-)
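The diffstat is limited to candle-transformers, so the slice_set op named in the title does not appear in the diff itself. Roughly, slice_set copies a tensor into a same-rank destination at a given offset along one dimension, which is what lets a kv-cache write new keys and values into a pre-allocated buffer instead of concatenating. A sketch of the semantics, assuming the in-place `Tensor::slice_set(&self, src, dim, offset)` form added to the candle core crate by this PR:

    use candle::{DType, Device, IndexOp, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // Destination stands in for a pre-allocated cache slot:
        // (batch, kv_heads, max_seq_len, head_dim).
        let dst = Tensor::zeros((1, 4, 8, 16), DType::F32, &dev)?;
        // Two new positions' worth of data to write along dim 2.
        let src = Tensor::ones((1, 4, 2, 16), DType::F32, &dev)?;
        // Write `src` into `dst` at offset 3 along the sequence dim; all
        // other dims must match. The destination is updated in place.
        dst.slice_set(&src, 2, 3)?;
        assert_eq!(dst.i((0, 0, 3))?.to_vec1::<f32>()?, vec![1f32; 16]);
        assert_eq!(dst.i((0, 0, 0))?.to_vec1::<f32>()?, vec![0f32; 16]);
        Ok(())
    }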
diff --git a/candle-transformers/src/models/quantized_phi3.rs b/candle-transformers/src/models/quantized_phi3.rs
index ef404ca0..04aff8b5 100644
--- a/candle-transformers/src/models/quantized_phi3.rs
+++ b/candle-transformers/src/models/quantized_phi3.rs
@@ -3,9 +3,7 @@ use std::collections::HashMap;
 use candle::quantized::gguf_file;
 use candle::quantized::QTensor;
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
-use candle_nn::{Embedding, RmsNorm};
-
-pub const MAX_SEQ_LEN: usize = 4096;
+use candle_nn::{kv_cache::KvCache, Embedding, RmsNorm};
 
 #[derive(Debug, Clone)]
 struct QLinear {
@@ -70,7 +68,7 @@ struct LayerWeights {
     cos: Tensor,
     sin: Tensor,
     neg_inf: Tensor,
-    kv_cache: Option<(Tensor, Tensor)>,
+    kv_cache: KvCache,
     span_attn: tracing::Span,
     span_rot: tracing::Span,
 }
@@ -122,19 +120,7 @@ impl LayerWeights {
         let q = self.apply_rotary_emb(&q, index_pos)?.contiguous()?;
         let k = self.apply_rotary_emb(&k, index_pos)?;
 
-        let (k, v) = match &self.kv_cache {
-            None => (k.contiguous()?, v.contiguous()?),
-            Some((k_cache, v_cache)) => {
-                if index_pos == 0 {
-                    (k.contiguous()?, v.contiguous()?)
-                } else {
-                    let k = Tensor::cat(&[k_cache, &k], 2)?;
-                    let v = Tensor::cat(&[v_cache, &v], 2)?;
-                    (k.contiguous()?, v.contiguous()?)
-                }
-            }
-        };
-        self.kv_cache = Some((k.clone(), v.clone()));
+        let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;
 
         let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
         let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;
@@ -169,6 +155,7 @@ pub struct ModelWeights {
 
 fn precomput_freqs_cis(
     head_dim: usize,
+    max_seq_len: usize,
     freq_base: f32,
     device: &Device,
 ) -> Result<(Tensor, Tensor)> {
@@ -177,9 +164,9 @@
         .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
         .collect();
     let theta = Tensor::new(theta.as_slice(), device)?;
-    let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
+    let idx_theta = Tensor::arange(0, max_seq_len as u32, device)?
         .to_dtype(DType::F32)?
-        .reshape((MAX_SEQ_LEN, 1))?
+        .reshape((max_seq_len, 1))?
         .matmul(&theta.reshape((1, theta.elem_count()))?)?;
     let cos = idx_theta.cos()?;
     let sin = idx_theta.sin()?;
@@ -188,6 +175,7 @@
 
 impl ModelWeights {
     pub fn from_gguf<R: std::io::Seek + std::io::Read>(
+        batch_size: usize,
         ct: gguf_file::Content,
         reader: &mut R,
         device: &Device,
@@ -202,16 +190,19 @@ impl ModelWeights {
         let head_count_kv = md_get("phi3.attention.head_count_kv")?.to_u32()? as usize;
         let block_count = md_get("phi3.block_count")?.to_u32()? as usize;
         let embedding_length = md_get("phi3.embedding_length")?.to_u32()? as usize;
+        let max_seq_len = md_get("phi3.context_length")?.to_u32()? as usize;
+        let head_dim = embedding_length / head_count;
         let i_size = md_get("phi3.feed_forward_length")?.to_u32()? as usize;
         let rope_dim = md_get("phi3.rope.dimension_count")?.to_u32()? as usize;
         let rms_eps = md_get("phi3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
-        let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device)?;
+        let (cos, sin) = precomput_freqs_cis(rope_dim, max_seq_len, 10_000., device)?;
         let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
         let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
         let tok_embeddings = tok_embeddings.dequantize(device)?;
         let output_norm = rms_norm(ct.tensor(reader, "output_norm.weight", device)?, rms_eps)?;
         let output = QLinear::new(&ct, reader, "output", device)?;
+
         let mut layers = Vec::with_capacity(block_count);
         for layer_idx in 0..block_count {
             let prefix = format!("blk.{layer_idx}");
@@ -232,6 +223,12 @@ impl ModelWeights {
             )?;
             let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
             let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
+            let kv_cache = KvCache::new(
+                2,
+                (batch_size, head_count_kv, max_seq_len, head_dim),
+                DType::F32,
+                device,
+            )?;
             layers.push(LayerWeights {
                 attn_qkv: QLinear::new(&ct, reader, &format!("{prefix}.attn_qkv"), device)?,
                 attn_output: QLinear::new(&ct, reader, &format!("{prefix}.attn_output"), device)?,
@@ -240,11 +237,11 @@ impl ModelWeights {
                 mlp,
                 n_head: head_count,
                 n_kv_head: head_count_kv,
-                head_dim: embedding_length / head_count,
+                head_dim,
                 cos: cos.clone(),
                 sin: sin.clone(),
                 neg_inf: neg_inf.clone(),
-                kv_cache: None,
+                kv_cache,
                 span_attn,
                 span_rot,
             })
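Since from_gguf now takes a leading batch_size argument (used to size the per-layer cache buffers), call sites need a matching update. A hypothetical call for single-sequence decoding (the actual call sites live outside this crate):

    // `content` is a gguf_file::Content and `reader` the GGUF file,
    // both assumed already in scope.
    let model = ModelWeights::from_gguf(1, content, &mut reader, &device)?;

With each layer's cache pre-allocated at (batch_size, head_count_kv, max_seq_len, head_dim), a decode step writes into the existing buffer rather than re-concatenating the whole history, and the removed MAX_SEQ_LEN constant gives way to the model's own phi3.context_length metadata.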