author     Laurent Mazare <laurent.mazare@gmail.com>  2024-09-23 13:14:32 +0200
committer  GitHub <noreply@github.com>                2024-09-23 13:14:32 +0200
commit     d01207dbf3fb0ad614e7915c8f5706fbc09902fb (patch)
tree       bd8303c3d7acd097777485f220acd14aac73bd83 /candle-nn
parent     8097559c1a293d26cf9cd65d92d9cb4696197c2e (diff)
download   candle-d01207dbf3fb0ad614e7915c8f5706fbc09902fb.tar.gz
           candle-d01207dbf3fb0ad614e7915c8f5706fbc09902fb.tar.bz2
           candle-d01207dbf3fb0ad614e7915c8f5706fbc09902fb.zip
Add a RotatingKVCache. (#2493)
* Add a RotatingKVCache.
* Add some KvCache tests.
* Test the reset too.
* More kv-cache testing.
* More tests for the rotating kv-cache.
* Improve the api for the rotating cache so that the whole src tensor gets returned when it's overlarge.
* Handle contiguity + bugfix + use in mimi.
* Add a way to test the mimi streaming mode.
* Mimi streaming fixes.
* More rotating kv-cache.
* Fix the attn mask generation.
* Handle the abs case.
* Add some tests for the generated mask.
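A minimal usage sketch (not part of the commit itself), following the API added in this diff and the 1D tensors used by the new tests below: the mask is requested for the incoming tokens, then `append` returns the keys/values to attend over.

    use candle::{Device, Result, Tensor};
    use candle_nn::kv_cache::RotatingKvCache;

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // Keep at most 6 entries along dim 0; older entries get overwritten.
        let mut cache = RotatingKvCache::new(0, 6);
        let k = Tensor::new(&[1f32, 2., 3.], &dev)?;
        let v = Tensor::new(&[4f32, 5., 6.], &dev)?;
        // Mask for the tokens about to be added; `None` for single-token steps.
        let mask = cache.attn_mask(k.dim(0)?, &dev)?;
        // `append` returns the keys/values to attend over: the valid cache contents,
        // or the whole `src` when it is longer than the rotating buffer.
        let (k_all, v_all) = cache.append(&k, &v)?;
        println!("{k_all}\n{v_all}\n{mask:?}");
        Ok(())
    }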
Diffstat (limited to 'candle-nn')
-rw-r--r--  candle-nn/src/kv_cache.rs    224
-rw-r--r--  candle-nn/tests/kv_cache.rs  110
2 files changed, 333 insertions, 1 deletion
diff --git a/candle-nn/src/kv_cache.rs b/candle-nn/src/kv_cache.rs
index eb5dbfdb..68addb98 100644
--- a/candle-nn/src/kv_cache.rs
+++ b/candle-nn/src/kv_cache.rs
@@ -1,4 +1,4 @@
-use candle::{Result, Tensor};
+use candle::{Device, Result, Tensor};
#[derive(Debug, Clone)]
pub struct Cache {
@@ -145,3 +145,225 @@ impl KvCache {
self.v.reset();
}
}
+
+#[derive(Debug, Clone)]
+pub struct RotatingCache {
+ all_data: Option<Tensor>,
+ dim: usize,
+ // `offset` is the current write index in the buffer
+ offset: usize,
+ // The total size of the sequence seen so far.
+ current_seq_len: usize,
+ // `max_seq_len` is the size of the rotating buffer; the full sequence is allowed to
+ // grow past this limit.
+ max_seq_len: usize,
+}
+
+impl RotatingCache {
+ pub fn new(dim: usize, max_seq_len: usize) -> Self {
+ Self {
+ all_data: None,
+ dim,
+ offset: 0,
+ current_seq_len: 0,
+ max_seq_len,
+ }
+ }
+
+ pub fn offset(&self) -> usize {
+ self.offset
+ }
+
+ pub fn dim(&self) -> usize {
+ self.dim
+ }
+
+ pub fn current_seq_len(&self) -> usize {
+ self.current_seq_len
+ }
+
+ pub fn max_seq_len(&self) -> usize {
+ self.max_seq_len
+ }
+
+ pub fn all_data(&self) -> &Option<Tensor> {
+ &self.all_data
+ }
+
+ pub fn current_data(&self) -> Result<Option<Tensor>> {
+ let data = match self.all_data.as_ref() {
+ None => None,
+ Some(d) => {
+ if self.current_seq_len >= self.max_seq_len {
+ Some(d.clone())
+ } else {
+ Some(d.narrow(self.dim, 0, self.current_seq_len)?)
+ }
+ }
+ };
+ Ok(data)
+ }
+
+ pub fn reset(&mut self) {
+ self.offset = 0;
+ self.current_seq_len = 0;
+ self.all_data = None;
+ }
+
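+ /// Appends `src` along `self.dim` and returns the tensor that attention should use:
+ /// the whole `src` when it is at least `max_seq_len` long, otherwise the valid part
+ /// of the rotating buffer (which may not be in chronological order once it wraps).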
+ pub fn append(&mut self, src: &Tensor) -> Result<Tensor> {
+ let seq_len = src.dim(self.dim)?;
+ // This isn't very idiomatic, but since creating the tensor can fail, it's tricky to
+ // use `self.all_data.get_or_insert_with`.
+ if self.all_data.is_none() {
+ let mut shape = src.dims().to_vec();
+ shape[self.dim] = self.max_seq_len;
+ let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
+ self.all_data = Some(ad)
+ };
+ let ad = self.all_data.as_mut().unwrap();
+
+ self.current_seq_len += seq_len;
+ if seq_len >= self.max_seq_len {
+ let to_copy = src
+ .narrow(self.dim, seq_len - self.max_seq_len, self.max_seq_len)?
+ .contiguous()?;
+ ad.slice_set(&to_copy, self.dim, 0)?;
+ self.offset = 0;
+ // Here we return `src` rather than `ad` so that all the past can be used.
+ Ok(src.clone())
+ } else {
+ let rem_len = self.max_seq_len - self.offset;
+ if seq_len <= rem_len {
+ ad.slice_set(&src.contiguous()?, self.dim, self.offset)?;
+ self.offset = (self.offset + seq_len) % self.max_seq_len;
+ } else {
+ // We have to make two copies here as we go over the boundary of the cache.
+ if rem_len > 0 {
+ let src1 = src.narrow(self.dim, 0, rem_len)?.contiguous()?;
+ ad.slice_set(&src1, self.dim, self.offset)?;
+ }
+ let src2 = src
+ .narrow(self.dim, rem_len, seq_len - rem_len)?
+ .contiguous()?;
+ ad.slice_set(&src2, self.dim, 0)?;
+ self.offset = seq_len - rem_len;
+ }
+ if self.current_seq_len >= self.max_seq_len {
+ Ok(ad.clone())
+ } else {
+ Ok(ad.narrow(self.dim, 0, self.current_seq_len)?)
+ }
+ }
+ }
+
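+ // Causal mask over absolute positions: key j is masked for query i when it lies in
+ // the future or more than `max_seq_len` positions in the past (rows/columns end-aligned).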
+ fn get_mask_abs(&self, size1: usize, size2: usize, device: &Device) -> Result<Tensor> {
+ let context = self.max_seq_len;
+ let mask: Vec<_> = (0..size1)
+ .flat_map(|i| {
+ (0..size2).map(move |j| {
+ u8::from(size1 + j > size2 + i || size1 + j + context < size2 + i)
+ })
+ })
+ .collect();
+ Tensor::from_slice(&mask, (size1, size2), device)
+ }
+
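+ // Same causal/context mask, but for keys stored in the rotating buffer: each slot is
+ // mapped back to its absolute position using the write offset reached after the
+ // `size1` new tokens have been added.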
+ fn get_mask_rel(&self, size1: usize, size2: usize, device: &Device) -> Result<Tensor> {
+ let context = self.max_seq_len;
+ let upd_offset = (self.offset + size1) % self.max_seq_len;
+ let mask: Vec<_> = (0..size1)
+ .flat_map(|pos_src| {
+ // The absolute position of the elements that will get added to the cache.
+ let pos_src = self.current_seq_len + pos_src;
+ (0..size2).map(move |pos_cache_rel| {
+ // The absolute position of the cache elements after the addition.
+ let pos_cache = self.current_seq_len + size1 + pos_cache_rel - upd_offset;
+ let pos_cache = if pos_cache_rel < upd_offset {
+ pos_cache
+ } else {
+ pos_cache - self.max_seq_len
+ };
+ u8::from(pos_cache > pos_src || pos_cache + context < pos_src)
+ })
+ })
+ .collect();
+ Tensor::from_slice(&mask, (size1, size2), device)
+ }
+
+ /// Returns the attn_mask to be applied *after* adding `seq_len` to the cache.
+ pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result<Option<Tensor>> {
+ let mask = if seq_len == 1 {
+ None
+ } else {
+ let mask = if seq_len < self.max_seq_len {
+ let cache_out_len = (self.current_seq_len + seq_len).min(self.max_seq_len);
+ self.get_mask_rel(seq_len, cache_out_len, device)?
+ } else {
+ self.get_mask_abs(seq_len, seq_len, device)?
+ };
+ Some(mask)
+ };
+ Ok(mask)
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct RotatingKvCache {
+ k: RotatingCache,
+ v: RotatingCache,
+}
+
+impl RotatingKvCache {
+ pub fn new(dim: usize, max_seq_len: usize) -> Self {
+ let k = RotatingCache::new(dim, max_seq_len);
+ let v = RotatingCache::new(dim, max_seq_len);
+ Self { k, v }
+ }
+
+ pub fn k_cache(&self) -> &RotatingCache {
+ &self.k
+ }
+
+ pub fn v_cache(&self) -> &RotatingCache {
+ &self.v
+ }
+
+ pub fn k_cache_mut(&mut self) -> &mut RotatingCache {
+ &mut self.k
+ }
+
+ pub fn v_cache_mut(&mut self) -> &mut RotatingCache {
+ &mut self.v
+ }
+
+ pub fn k(&self) -> Result<Option<Tensor>> {
+ self.k.current_data()
+ }
+
+ pub fn v(&self) -> Result<Option<Tensor>> {
+ self.v.current_data()
+ }
+
+ pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
+ let out_k = self.k.append(k)?;
+ let out_v = self.v.append(v)?;
+ Ok((out_k, out_v))
+ }
+
+ pub fn offset(&self) -> usize {
+ self.k.offset()
+ }
+
+ pub fn current_seq_len(&self) -> usize {
+ self.k.current_seq_len()
+ }
+
+ pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result<Option<Tensor>> {
+ self.k.attn_mask(seq_len, device)
+ }
+
+ pub fn reset(&mut self) {
+ self.k.reset();
+ self.v.reset();
+ }
+}
diff --git a/candle-nn/tests/kv_cache.rs b/candle-nn/tests/kv_cache.rs
new file mode 100644
index 00000000..b8d2ec48
--- /dev/null
+++ b/candle-nn/tests/kv_cache.rs
@@ -0,0 +1,110 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use candle::{Device, Result, Tensor};
+
+#[test]
+fn kv_cache() -> Result<()> {
+ let mut cache = candle_nn::kv_cache::Cache::new(0, 16);
+ for _ in [0, 1] {
+ assert_eq!(cache.current_seq_len(), 0);
+ let data = cache.current_data()?;
+ assert!(data.is_none());
+ let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
+ cache.append(&t)?;
+ let data = cache.current_data()?.unwrap();
+ assert_eq!(data.to_vec1::<f32>()?, [1., 2., 3.]);
+ let t = Tensor::new(&[4f32], &Device::Cpu)?;
+ cache.append(&t)?;
+ let data = cache.current_data()?.unwrap();
+ assert_eq!(data.to_vec1::<f32>()?, [1., 2., 3., 4.]);
+ let t = Tensor::new(&[0f32, 5., 6., 7.], &Device::Cpu)?;
+ cache.append(&t)?;
+ let data = cache.current_data()?.unwrap();
+ assert_eq!(data.to_vec1::<f32>()?, [1., 2., 3., 4., 0., 5., 6., 7.]);
+ assert_eq!(cache.current_seq_len(), 8);
+ cache.reset();
+ }
+ Ok(())
+}
+
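+// The rotating cache below has capacity 6 along dim 0: once more than 6 values have
+// been appended, new values overwrite the oldest slots and `append` returns the buffer
+// contents rather than the full sequence.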
+#[test]
+fn rotating_kv_cache() -> Result<()> {
+ let mut cache = candle_nn::kv_cache::RotatingCache::new(0, 6);
+ for _ in [0, 1] {
+ assert_eq!(cache.offset(), 0);
+ assert_eq!(cache.current_seq_len(), 0);
+ let data = cache.current_data()?;
+ assert!(data.is_none());
+ let t = Tensor::new(&[1., 2., 3.], &Device::Cpu)?;
+ let data = cache.append(&t)?;
+ assert_eq!(data.to_vec1::<f64>()?, [1., 2., 3.]);
+ let t = Tensor::new(&[4.], &Device::Cpu)?;
+ let data = cache.append(&t)?;
+ assert_eq!(data.to_vec1::<f64>()?, [1., 2., 3., 4.]);
+ let t = Tensor::new(&[0., 5., 6., 7.], &Device::Cpu)?;
+ let data = cache.append(&t)?;
+ assert_eq!(data.to_vec1::<f64>()?, [6., 7., 3., 4., 0., 5.]);
+ assert_eq!(cache.current_seq_len(), 8);
+ assert_eq!(cache.offset(), 2);
+
+ let t = Tensor::new(&[8.], &Device::Cpu)?;
+ let data = cache.append(&t)?;
+ assert_eq!(data.to_vec1::<f64>()?, [6., 7., 8., 4., 0., 5.]);
+ assert_eq!(cache.current_seq_len(), 9);
+ assert_eq!(cache.offset(), 3);
+
+ let t = Tensor::new(&[9., 10., 11.], &Device::Cpu)?;
+ let data = cache.append(&t)?;
+ assert_eq!(data.to_vec1::<f64>()?, [6., 7., 8., 9., 10., 11.]);
+ assert_eq!(cache.current_seq_len(), 12);
+ assert_eq!(cache.offset(), 0);
+
+ let t = Tensor::new(&[12.], &Device::Cpu)?;
+ let data = cache.append(&t)?;
+ assert_eq!(data.to_vec1::<f64>()?, [12., 7., 8., 9., 10., 11.]);
+ assert_eq!(cache.current_seq_len(), 13);
+ assert_eq!(cache.offset(), 1);
+
+ let mask = cache.attn_mask(2, &Device::Cpu)?.unwrap();
+ assert_eq!(
+ mask.to_vec2::<u8>()?,
+ &[[0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 0, 0]]
+ );
+ let mask = cache.attn_mask(3, &Device::Cpu)?.unwrap();
+ assert_eq!(
+ mask.to_vec2::<u8>()?,
+ &[[0, 0, 1, 1, 0, 0], [0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 0]],
+ );
+ let t = Tensor::new(&[0., 1., 2., 3., 4., 5., 6., 7., 8.], &Device::Cpu)?;
+ let data = cache.append(&t)?;
+ assert_eq!(data.to_vec1::<f64>()?, [0., 1., 2., 3., 4., 5., 6., 7., 8.]);
+ assert_eq!(cache.current_seq_len(), 22);
+ assert_eq!(cache.offset(), 0);
+
+ let mask = cache.attn_mask(1, &Device::Cpu)?;
+ assert!(mask.is_none());
+ let mask = cache.attn_mask(2, &Device::Cpu)?.unwrap();
+ assert_eq!(
+ mask.to_vec2::<u8>()?,
+ &[[0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]]
+ );
+ let mask = cache.attn_mask(3, &Device::Cpu)?.unwrap();
+ assert_eq!(
+ mask.to_vec2::<u8>()?,
+ &[[0, 1, 1, 0, 0, 0], [0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 0, 0]]
+ );
+ let t = Tensor::new(&[42.], &Device::Cpu)?;
+
+ let data = cache.append(&t)?;
+ assert_eq!(data.to_vec1::<f64>()?, [42., 4., 5., 6., 7., 8.]);
+ assert_eq!(cache.current_seq_len(), 23);
+ assert_eq!(cache.offset(), 1);
+
+ cache.reset();
+ }
+ Ok(())
+}