author	Laurent Mazare <laurent.mazare@gmail.com>	2024-05-18 15:58:18 +0200
committer	GitHub <noreply@github.com>	2024-05-18 15:58:18 +0200
commit	01545f73038cb8c90426214ddf4bcedd59e291e8 (patch)
tree	2f2af0905b404bd42ee22962f94ea812e3caaa9e /candle-transformers
parent	349c3e806a15399df8289c41b2e24c3fa24b6d84 (diff)
Add a slice_set op. (#2193)
* Add a slice_set op.
* Add some testing.
* Add the dedicated kv-cache module.
* Derive debug and clone.
* Expose more kv-cache functions.
* Return the current data when appending.
* Use the new cache in the quantized phi3 model.
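For context, the sketch below shows how the new candle_nn::kv_cache::KvCache is driven, following the KvCache::new and append calls that appear in the diff further down. It is a minimal illustration with made-up toy dimensions, not code from this patch; the one behavior assumed beyond the diff is that append returns the keys/values accumulated so far, as the commit message states (internally the cache is presumably backed by the new slice_set op introduced in the same commit).

use candle::{DType, Device, Result, Tensor};
use candle_nn::kv_cache::KvCache;

fn kv_cache_demo() -> Result<()> {
    let device = Device::Cpu;
    // Toy dimensions: batch 1, 4 kv heads, room for 16 positions, head_dim 8.
    let (b, n_kv_head, max_seq_len, head_dim) = (1, 4, 16, 8);
    // Pre-allocate the cache; dimension 2 is the sequence axis, matching the
    // KvCache::new(2, (batch_size, head_count_kv, max_seq_len, head_dim), ...) call below.
    let mut cache = KvCache::new(2, (b, n_kv_head, max_seq_len, head_dim), DType::F32, &device)?;

    // Prompt step: three positions appended at once.
    let k = Tensor::zeros((b, n_kv_head, 3, head_dim), DType::F32, &device)?;
    let v = Tensor::zeros((b, n_kv_head, 3, head_dim), DType::F32, &device)?;
    let (k_all, _v_all) = cache.append(&k, &v)?;
    assert_eq!(k_all.dims()[2], 3); // the cache now holds 3 positions

    // Decode step: one more position; append returns everything cached so far.
    let k = Tensor::zeros((b, n_kv_head, 1, head_dim), DType::F32, &device)?;
    let v = Tensor::zeros((b, n_kv_head, 1, head_dim), DType::F32, &device)?;
    let (k_all, _v_all) = cache.append(&k, &v)?;
    assert_eq!(k_all.dims()[2], 4);
    Ok(())
}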
Diffstat (limited to 'candle-transformers')
-rw-r--r--	candle-transformers/src/models/quantized_phi3.rs	41
1 file changed, 19 insertions(+), 22 deletions(-)
diff --git a/candle-transformers/src/models/quantized_phi3.rs b/candle-transformers/src/models/quantized_phi3.rs
index ef404ca0..04aff8b5 100644
--- a/candle-transformers/src/models/quantized_phi3.rs
+++ b/candle-transformers/src/models/quantized_phi3.rs
@@ -3,9 +3,7 @@ use std::collections::HashMap;
 use candle::quantized::gguf_file;
 use candle::quantized::QTensor;
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
-use candle_nn::{Embedding, RmsNorm};
-
-pub const MAX_SEQ_LEN: usize = 4096;
+use candle_nn::{kv_cache::KvCache, Embedding, RmsNorm};
 
 #[derive(Debug, Clone)]
 struct QLinear {
@@ -70,7 +68,7 @@ struct LayerWeights {
     cos: Tensor,
     sin: Tensor,
     neg_inf: Tensor,
-    kv_cache: Option<(Tensor, Tensor)>,
+    kv_cache: KvCache,
     span_attn: tracing::Span,
     span_rot: tracing::Span,
 }
@@ -122,19 +120,7 @@ impl LayerWeights {
         let q = self.apply_rotary_emb(&q, index_pos)?.contiguous()?;
         let k = self.apply_rotary_emb(&k, index_pos)?;
 
-        let (k, v) = match &self.kv_cache {
-            None => (k.contiguous()?, v.contiguous()?),
-            Some((k_cache, v_cache)) => {
-                if index_pos == 0 {
-                    (k.contiguous()?, v.contiguous()?)
-                } else {
-                    let k = Tensor::cat(&[k_cache, &k], 2)?;
-                    let v = Tensor::cat(&[v_cache, &v], 2)?;
-                    (k.contiguous()?, v.contiguous()?)
-                }
-            }
-        };
-        self.kv_cache = Some((k.clone(), v.clone()));
+        let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;
 
         let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
         let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;
@@ -169,6 +155,7 @@ pub struct ModelWeights {
 
 fn precomput_freqs_cis(
     head_dim: usize,
+    max_seq_len: usize,
     freq_base: f32,
     device: &Device,
 ) -> Result<(Tensor, Tensor)> {
@@ -177,9 +164,9 @@ fn precomput_freqs_cis(
         .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
         .collect();
     let theta = Tensor::new(theta.as_slice(), device)?;
-    let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
+    let idx_theta = Tensor::arange(0, max_seq_len as u32, device)?
         .to_dtype(DType::F32)?
-        .reshape((MAX_SEQ_LEN, 1))?
+        .reshape((max_seq_len, 1))?
         .matmul(&theta.reshape((1, theta.elem_count()))?)?;
     let cos = idx_theta.cos()?;
     let sin = idx_theta.sin()?;
@@ -188,6 +175,7 @@ fn precomput_freqs_cis(
 
 impl ModelWeights {
     pub fn from_gguf<R: std::io::Seek + std::io::Read>(
+        batch_size: usize,
         ct: gguf_file::Content,
         reader: &mut R,
         device: &Device,
@@ -202,16 +190,19 @@ impl ModelWeights {
         let head_count_kv = md_get("phi3.attention.head_count_kv")?.to_u32()? as usize;
         let block_count = md_get("phi3.block_count")?.to_u32()? as usize;
         let embedding_length = md_get("phi3.embedding_length")?.to_u32()? as usize;
+        let max_seq_len = md_get("phi3.context_length")?.to_u32()? as usize;
+        let head_dim = embedding_length / head_count;
         let i_size = md_get("phi3.feed_forward_length")?.to_u32()? as usize;
         let rope_dim = md_get("phi3.rope.dimension_count")?.to_u32()? as usize;
         let rms_eps = md_get("phi3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
-        let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device)?;
+        let (cos, sin) = precomput_freqs_cis(rope_dim, max_seq_len, 10_000., device)?;
         let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
 
         let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
         let tok_embeddings = tok_embeddings.dequantize(device)?;
         let output_norm = rms_norm(ct.tensor(reader, "output_norm.weight", device)?, rms_eps)?;
         let output = QLinear::new(&ct, reader, "output", device)?;
+
         let mut layers = Vec::with_capacity(block_count);
         for layer_idx in 0..block_count {
             let prefix = format!("blk.{layer_idx}");
@@ -232,6 +223,12 @@ impl ModelWeights {
             )?;
             let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
             let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
+            let kv_cache = KvCache::new(
+                2,
+                (batch_size, head_count_kv, max_seq_len, head_dim),
+                DType::F32,
+                device,
+            )?;
             layers.push(LayerWeights {
                 attn_qkv: QLinear::new(&ct, reader, &format!("{prefix}.attn_qkv"), device)?,
                 attn_output: QLinear::new(&ct, reader, &format!("{prefix}.attn_output"), device)?,
@@ -240,11 +237,11 @@ impl ModelWeights {
                 mlp,
                 n_head: head_count,
                 n_kv_head: head_count_kv,
-                head_dim: embedding_length / head_count,
+                head_dim,
                 cos: cos.clone(),
                 sin: sin.clone(),
                 neg_inf: neg_inf.clone(),
-                kv_cache: None,
+                kv_cache,
                 span_attn,
                 span_rot,
             })
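A caller-visible consequence of the diff: ModelWeights::from_gguf now takes the batch size as its first argument so each layer can size its KvCache up front, and the maximum sequence length is read from the phi3.context_length GGUF metadata instead of the old MAX_SEQ_LEN constant. A rough sketch of an updated call site follows; the file path, the batch size of 1, and the use of anyhow for error handling are placeholders rather than anything from this patch.

use candle::quantized::gguf_file;
use candle::Device;
use candle_transformers::models::quantized_phi3::ModelWeights;

fn load_model() -> anyhow::Result<ModelWeights> {
    // Placeholder GGUF path; any phi3 GGUF checkpoint would do.
    let mut file = std::fs::File::open("phi-3-mini.gguf")?;
    let content = gguf_file::Content::read(&mut file)?;
    let device = Device::Cpu;
    // batch_size = 1: the per-layer KV caches are allocated for a single sequence.
    let model = ModelWeights::from_gguf(1, content, &mut file, &device)?;
    Ok(model)
}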