diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-05-23 17:07:21 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-23 17:07:21 +0200 |
commit | 45e235a7473d473df5c1e50f55504a97e28be822 (patch) | |
tree | 6e8518249f1bbbc634431327b96d24f3e270ebc5 /candle-nn | |
parent | 31cf64147b9ab4a3d68849bef0ea59bdb0c113d6 (diff) | |
download | candle-45e235a7473d473df5c1e50f55504a97e28be822.tar.gz candle-45e235a7473d473df5c1e50f55504a97e28be822.tar.bz2 candle-45e235a7473d473df5c1e50f55504a97e28be822.zip |
Simplify the KvCache api. (#2207)
Diffstat (limited to 'candle-nn')
-rw-r--r-- | candle-nn/src/kv_cache.rs | 89 |
1 file changed, 53 insertions, 36 deletions
diff --git a/candle-nn/src/kv_cache.rs b/candle-nn/src/kv_cache.rs index 10e9fe5a..eb5dbfdb 100644 --- a/candle-nn/src/kv_cache.rs +++ b/candle-nn/src/kv_cache.rs @@ -1,30 +1,25 @@ -use candle::{DType, Device, Result, Shape, Tensor}; +use candle::{Result, Tensor}; #[derive(Debug, Clone)] pub struct Cache { - all_data: Tensor, + // all_data is an option on a Tensor, this makes it possible to only create the actual tensor + // on the first call where the batch size is easily known. + // Also this makes it safe to clone a KvCache that has been reseted (as in it will not share + // its internal state with the cloned instance). + all_data: Option<Tensor>, dim: usize, current_seq_len: usize, max_seq_len: usize, } impl Cache { - pub fn new<S: Into<Shape>, D: candle::shape::Dim>( - dim: D, - shape: S, - dtype: DType, - dev: &Device, - ) -> Result<Self> { - let shape = shape.into(); - let dim = dim.to_index(&shape, "kv-cache")?; - let max_seq_len = shape.dims()[dim]; - let all_data = Tensor::zeros(shape, dtype, dev)?; - Ok(Self { - all_data, + pub fn new(dim: usize, max_seq_len: usize) -> Self { + Self { + all_data: None, dim, current_seq_len: 0, max_seq_len, - }) + } } pub fn dim(&self) -> usize { @@ -39,20 +34,34 @@ impl Cache { self.max_seq_len } - pub fn all_data(&self) -> &Tensor { + pub fn all_data(&self) -> &Option<Tensor> { &self.all_data } - pub fn current_data(&self) -> Result<Tensor> { - self.all_data.narrow(self.dim, 0, self.current_seq_len) + pub fn current_data(&self) -> Result<Option<Tensor>> { + let data = match self.all_data.as_ref() { + None => None, + Some(d) => Some(d.narrow(self.dim, 0, self.current_seq_len)?), + }; + Ok(data) } pub fn reset(&mut self) { - self.current_seq_len = 0 + self.current_seq_len = 0; + self.all_data = None; } pub fn append(&mut self, src: &Tensor) -> Result<()> { let seq_len = src.dim(self.dim)?; + // This doesn't seem very idiomatic but because the creation can fail, it's tricky to use + // self.all_data.get_or_insert_with. 
+ if self.all_data.is_none() { + let mut shape = src.dims().to_vec(); + shape[self.dim] = self.max_seq_len; + let ad = Tensor::zeros(shape, src.dtype(), src.device())?; + self.all_data = Some(ad) + }; + let ad = self.all_data.as_mut().unwrap(); if self.current_seq_len + seq_len > self.max_seq_len { candle::bail!( "kv-cache: above max-seq-len {}+{seq_len}>{}", @@ -60,8 +69,7 @@ impl Cache { self.max_seq_len ) } - self.all_data - .slice_set(src, self.dim, self.current_seq_len)?; + ad.slice_set(src, self.dim, self.current_seq_len)?; self.current_seq_len += seq_len; Ok(()) } @@ -74,17 +82,10 @@ pub struct KvCache { } impl KvCache { - pub fn new<S: Into<Shape>, D: candle::shape::Dim>( - dim: D, - shape: S, - dtype: DType, - dev: &Device, - ) -> Result<Self> { - let shape = shape.into(); - let dim = dim.to_index(&shape, "kv-cache")?; - let k = Cache::new(dim, &shape, dtype, dev)?; - let v = Cache::new(dim, &shape, dtype, dev)?; - Ok(Self { k, v }) + pub fn new(dim: usize, max_seq_len: usize) -> Self { + let k = Cache::new(dim, max_seq_len); + let v = Cache::new(dim, max_seq_len); + Self { k, v } } pub fn k_cache(&self) -> &Cache { @@ -103,19 +104,35 @@ impl KvCache { &mut self.v } - pub fn k(&self) -> Result<Tensor> { + pub fn k(&self) -> Result<Option<Tensor>> { self.k.current_data() } - pub fn v(&self) -> Result<Tensor> { + pub fn v(&self) -> Result<Option<Tensor>> { self.v.current_data() } pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> { self.k.append(k)?; self.v.append(v)?; - let k = self.k.current_data()?; - let v = self.v.current_data()?; + let out_k = self.k.current_data()?; + let out_v = self.v.current_data()?; + let k = match out_k { + None => { + let mut shape = k.dims().to_vec(); + shape[self.k.dim] = 0; + Tensor::zeros(shape, k.dtype(), k.device())? + } + Some(k) => k, + }; + let v = match out_v { + None => { + let mut shape = v.dims().to_vec(); + shape[self.k.dim] = 0; + Tensor::zeros(shape, v.dtype(), v.device())? 
+ } + Some(v) => v, + }; Ok((k, v)) } |