summaryrefslogtreecommitdiff
path: root/candle-nn
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-05-23 17:07:21 +0200
committerGitHub <noreply@github.com>2024-05-23 17:07:21 +0200
commit45e235a7473d473df5c1e50f55504a97e28be822 (patch)
tree6e8518249f1bbbc634431327b96d24f3e270ebc5 /candle-nn
parent31cf64147b9ab4a3d68849bef0ea59bdb0c113d6 (diff)
downloadcandle-45e235a7473d473df5c1e50f55504a97e28be822.tar.gz
candle-45e235a7473d473df5c1e50f55504a97e28be822.tar.bz2
candle-45e235a7473d473df5c1e50f55504a97e28be822.zip
Simplify the KvCache api. (#2207)
Diffstat (limited to 'candle-nn')
-rw-r--r--candle-nn/src/kv_cache.rs89
1 files changed, 53 insertions, 36 deletions
diff --git a/candle-nn/src/kv_cache.rs b/candle-nn/src/kv_cache.rs
index 10e9fe5a..eb5dbfdb 100644
--- a/candle-nn/src/kv_cache.rs
+++ b/candle-nn/src/kv_cache.rs
@@ -1,30 +1,25 @@
-use candle::{DType, Device, Result, Shape, Tensor};
+use candle::{Result, Tensor};
#[derive(Debug, Clone)]
pub struct Cache {
- all_data: Tensor,
+ // all_data is an option on a Tensor, this makes it possible to only create the actual tensor
+ // on the first call where the batch size is easily known.
+ // Also this makes it safe to clone a KvCache that has been reseted (as in it will not share
+ // its internal state with the cloned instance).
+ all_data: Option<Tensor>,
dim: usize,
current_seq_len: usize,
max_seq_len: usize,
}
impl Cache {
- pub fn new<S: Into<Shape>, D: candle::shape::Dim>(
- dim: D,
- shape: S,
- dtype: DType,
- dev: &Device,
- ) -> Result<Self> {
- let shape = shape.into();
- let dim = dim.to_index(&shape, "kv-cache")?;
- let max_seq_len = shape.dims()[dim];
- let all_data = Tensor::zeros(shape, dtype, dev)?;
- Ok(Self {
- all_data,
+ pub fn new(dim: usize, max_seq_len: usize) -> Self {
+ Self {
+ all_data: None,
dim,
current_seq_len: 0,
max_seq_len,
- })
+ }
}
pub fn dim(&self) -> usize {
@@ -39,20 +34,34 @@ impl Cache {
self.max_seq_len
}
- pub fn all_data(&self) -> &Tensor {
+ pub fn all_data(&self) -> &Option<Tensor> {
&self.all_data
}
- pub fn current_data(&self) -> Result<Tensor> {
- self.all_data.narrow(self.dim, 0, self.current_seq_len)
+ pub fn current_data(&self) -> Result<Option<Tensor>> {
+ let data = match self.all_data.as_ref() {
+ None => None,
+ Some(d) => Some(d.narrow(self.dim, 0, self.current_seq_len)?),
+ };
+ Ok(data)
}
pub fn reset(&mut self) {
- self.current_seq_len = 0
+ self.current_seq_len = 0;
+ self.all_data = None;
}
pub fn append(&mut self, src: &Tensor) -> Result<()> {
let seq_len = src.dim(self.dim)?;
+ // This doesn't seem very idiomatic but because the creation can fail, it's tricky to use
+ // self.all_data.get_or_insert_with.
+ if self.all_data.is_none() {
+ let mut shape = src.dims().to_vec();
+ shape[self.dim] = self.max_seq_len;
+ let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
+ self.all_data = Some(ad)
+ };
+ let ad = self.all_data.as_mut().unwrap();
if self.current_seq_len + seq_len > self.max_seq_len {
candle::bail!(
"kv-cache: above max-seq-len {}+{seq_len}>{}",
@@ -60,8 +69,7 @@ impl Cache {
self.max_seq_len
)
}
- self.all_data
- .slice_set(src, self.dim, self.current_seq_len)?;
+ ad.slice_set(src, self.dim, self.current_seq_len)?;
self.current_seq_len += seq_len;
Ok(())
}
@@ -74,17 +82,10 @@ pub struct KvCache {
}
impl KvCache {
- pub fn new<S: Into<Shape>, D: candle::shape::Dim>(
- dim: D,
- shape: S,
- dtype: DType,
- dev: &Device,
- ) -> Result<Self> {
- let shape = shape.into();
- let dim = dim.to_index(&shape, "kv-cache")?;
- let k = Cache::new(dim, &shape, dtype, dev)?;
- let v = Cache::new(dim, &shape, dtype, dev)?;
- Ok(Self { k, v })
+ pub fn new(dim: usize, max_seq_len: usize) -> Self {
+ let k = Cache::new(dim, max_seq_len);
+ let v = Cache::new(dim, max_seq_len);
+ Self { k, v }
}
pub fn k_cache(&self) -> &Cache {
@@ -103,19 +104,35 @@ impl KvCache {
&mut self.v
}
- pub fn k(&self) -> Result<Tensor> {
+ pub fn k(&self) -> Result<Option<Tensor>> {
self.k.current_data()
}
- pub fn v(&self) -> Result<Tensor> {
+ pub fn v(&self) -> Result<Option<Tensor>> {
self.v.current_data()
}
pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
self.k.append(k)?;
self.v.append(v)?;
- let k = self.k.current_data()?;
- let v = self.v.current_data()?;
+ let out_k = self.k.current_data()?;
+ let out_v = self.v.current_data()?;
+ let k = match out_k {
+ None => {
+ let mut shape = k.dims().to_vec();
+ shape[self.k.dim] = 0;
+ Tensor::zeros(shape, k.dtype(), k.device())?
+ }
+ Some(k) => k,
+ };
+ let v = match out_v {
+ None => {
+ let mut shape = v.dims().to_vec();
+ shape[self.k.dim] = 0;
+ Tensor::zeros(shape, v.dtype(), v.device())?
+ }
+ Some(v) => v,
+ };
Ok((k, v))
}