use candle::{Device, Result, Tensor}; #[derive(Debug, Clone)] pub struct Cache { // all_data is an option on a Tensor, this makes it possible to only create the actual tensor // on the first call where the batch size is easily known. // Also this makes it safe to clone a KvCache that has been reseted (as in it will not share // its internal state with the cloned instance). all_data: Option, dim: usize, current_seq_len: usize, max_seq_len: usize, } impl Cache { pub fn new(dim: usize, max_seq_len: usize) -> Self { Self { all_data: None, dim, current_seq_len: 0, max_seq_len, } } pub fn dim(&self) -> usize { self.dim } pub fn current_seq_len(&self) -> usize { self.current_seq_len } pub fn max_seq_len(&self) -> usize { self.max_seq_len } pub fn all_data(&self) -> &Option { &self.all_data } pub fn current_data(&self) -> Result> { let data = match self.all_data.as_ref() { None => None, Some(d) => Some(d.narrow(self.dim, 0, self.current_seq_len)?), }; Ok(data) } pub fn reset(&mut self) { self.current_seq_len = 0; self.all_data = None; } pub fn append(&mut self, src: &Tensor) -> Result<()> { let seq_len = src.dim(self.dim)?; // This doesn't seem very idiomatic but because the creation can fail, it's tricky to use // self.all_data.get_or_insert_with. if self.all_data.is_none() { let mut shape = src.dims().to_vec(); shape[self.dim] = self.max_seq_len; let ad = Tensor::zeros(shape, src.dtype(), src.device())?; self.all_data = Some(ad) }; let ad = self.all_data.as_mut().unwrap(); if self.current_seq_len + seq_len > self.max_seq_len { candle::bail!( "kv-cache: above max-seq-len {}+{seq_len}>{}", self.current_seq_len, self.max_seq_len ) } ad.slice_set(src, self.dim, self.current_seq_len)?; self.current_seq_len += seq_len; Ok(()) } } #[derive(Debug, Clone)] pub struct KvCache { k: Cache, v: Cache, } impl KvCache { pub fn new(dim: usize, max_seq_len: usize) -> Self { let k = Cache::new(dim, max_seq_len); let v = Cache::new(dim, max_seq_len); Self { k, v } } pub fn k_cache(&self) -> &Cache { &self.k } pub fn v_cache(&self) -> &Cache { &self.v } pub fn k_cache_mut(&mut self) -> &mut Cache { &mut self.k } pub fn v_cache_mut(&mut self) -> &mut Cache { &mut self.v } pub fn k(&self) -> Result> { self.k.current_data() } pub fn v(&self) -> Result> { self.v.current_data() } pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> { self.k.append(k)?; self.v.append(v)?; let out_k = self.k.current_data()?; let out_v = self.v.current_data()?; let k = match out_k { None => { let mut shape = k.dims().to_vec(); shape[self.k.dim] = 0; Tensor::zeros(shape, k.dtype(), k.device())? } Some(k) => k, }; let v = match out_v { None => { let mut shape = v.dims().to_vec(); shape[self.k.dim] = 0; Tensor::zeros(shape, v.dtype(), v.device())? } Some(v) => v, }; Ok((k, v)) } pub fn current_seq_len(&self) -> usize { self.k.current_seq_len() } pub fn reset(&mut self) { self.k.reset(); self.v.reset(); } } #[derive(Debug, Clone)] pub struct RotatingCache { all_data: Option, dim: usize, // `offset` is the current write index in the buffer offset: usize, // The total size of the sequence seen so far. current_seq_len: usize, // max_seq_len is the size of the rotating buffer, it is actually allowed for the full // sequence to grow past this limit. max_seq_len: usize, } impl RotatingCache { pub fn new(dim: usize, max_seq_len: usize) -> Self { Self { all_data: None, dim, offset: 0, current_seq_len: 0, max_seq_len, } } pub fn offset(&self) -> usize { self.offset } pub fn dim(&self) -> usize { self.dim } pub fn current_seq_len(&self) -> usize { self.current_seq_len } pub fn max_seq_len(&self) -> usize { self.max_seq_len } pub fn all_data(&self) -> &Option { &self.all_data } pub fn current_data(&self) -> Result> { let data = match self.all_data.as_ref() { None => None, Some(d) => { if self.current_seq_len >= self.max_seq_len { Some(d.clone()) } else { Some(d.narrow(self.dim, 0, self.current_seq_len)?) } } }; Ok(data) } pub fn reset(&mut self) { self.offset = 0; self.current_seq_len = 0; self.all_data = None; } pub fn append(&mut self, src: &Tensor) -> Result { let seq_len = src.dim(self.dim)?; // This doesn't seem very idiomatic but because the creation can fail, it's tricky to use // self.all_data.get_or_insert_with. if self.all_data.is_none() { let mut shape = src.dims().to_vec(); shape[self.dim] = self.max_seq_len; let ad = Tensor::zeros(shape, src.dtype(), src.device())?; self.all_data = Some(ad) }; let ad = self.all_data.as_mut().unwrap(); self.current_seq_len += seq_len; if seq_len >= self.max_seq_len { let to_copy = src .narrow(self.dim, seq_len - self.max_seq_len, self.max_seq_len)? .contiguous()?; ad.slice_set(&to_copy, self.dim, 0)?; self.offset = 0; // Here we return `src` rather than `ad` so that all the past can be used. Ok(src.clone()) } else { let rem_len = self.max_seq_len - self.offset; if seq_len <= rem_len { ad.slice_set(&src.contiguous()?, self.dim, self.offset)?; self.offset = (self.offset + seq_len) % self.max_seq_len; } else { // We have to make two copies here as we go over the boundary of the cache. if rem_len > 0 { let src1 = src.narrow(self.dim, 0, rem_len)?.contiguous()?; ad.slice_set(&src1, self.dim, self.offset)?; } let src2 = src .narrow(self.dim, rem_len, seq_len - rem_len)? .contiguous()?; ad.slice_set(&src2, self.dim, 0)?; self.offset = seq_len - rem_len; } if self.current_seq_len >= self.max_seq_len { Ok(ad.clone()) } else { Ok(ad.narrow(self.dim, 0, self.current_seq_len)?) } } } fn get_mask_abs(&self, size1: usize, size2: usize, device: &Device) -> Result { let context = self.max_seq_len; let mask: Vec<_> = (0..size1) .flat_map(|i| { (0..size2).map(move |j| { u8::from(size1 + j > size2 + i || size1 + j + context < size2 + i) }) }) .collect(); Tensor::from_slice(&mask, (size1, size2), device) } fn get_mask_rel(&self, size1: usize, size2: usize, device: &Device) -> Result { let context = self.max_seq_len; let upd_offset = (self.offset + size1) % self.max_seq_len; let mask: Vec<_> = (0..size1) .flat_map(|pos_src| { // The absolute position of the elements that will get added to the cache. let pos_src = self.current_seq_len + pos_src; (0..size2).map(move |pos_cache_rel| { // The absolute position of the cache elements after the addition. let pos_cache = self.current_seq_len + size1 + pos_cache_rel - upd_offset; let pos_cache = if pos_cache_rel < upd_offset { pos_cache } else { pos_cache - self.max_seq_len }; u8::from(pos_cache > pos_src || pos_cache + context < pos_src) }) }) .collect(); Tensor::from_slice(&mask, (size1, size2), device) } /// Returns the attn_mask to be applied *after* adding `seq_len` to the cache. pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result> { let mask = if seq_len == 1 { None } else { let mask = if seq_len < self.max_seq_len { let cache_out_len = (self.current_seq_len + seq_len).min(self.max_seq_len); self.get_mask_rel(seq_len, cache_out_len, device)? } else { self.get_mask_abs(seq_len, seq_len, device)? }; Some(mask) }; Ok(mask) } } #[derive(Debug, Clone)] pub struct RotatingKvCache { k: RotatingCache, v: RotatingCache, } impl RotatingKvCache { pub fn new(dim: usize, max_seq_len: usize) -> Self { let k = RotatingCache::new(dim, max_seq_len); let v = RotatingCache::new(dim, max_seq_len); Self { k, v } } pub fn k_cache(&self) -> &RotatingCache { &self.k } pub fn v_cache(&self) -> &RotatingCache { &self.v } pub fn k_cache_mut(&mut self) -> &mut RotatingCache { &mut self.k } pub fn v_cache_mut(&mut self) -> &mut RotatingCache { &mut self.v } pub fn k(&self) -> Result> { self.k.current_data() } pub fn v(&self) -> Result> { self.v.current_data() } pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> { let out_k = self.k.append(k)?; let out_v = self.v.append(v)?; Ok((out_k, out_v)) } pub fn offset(&self) -> usize { self.k.offset() } pub fn current_seq_len(&self) -> usize { self.k.current_seq_len() } pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result> { self.k.attn_mask(seq_len, device) } pub fn reset(&mut self) { self.k.reset(); self.v.reset(); } }