summaryrefslogtreecommitdiff
path: root/candle-nn
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-05-18 15:58:18 +0200
committerGitHub <noreply@github.com>2024-05-18 15:58:18 +0200
commit01545f73038cb8c90426214ddf4bcedd59e291e8 (patch)
tree2f2af0905b404bd42ee22962f94ea812e3caaa9e /candle-nn
parent349c3e806a15399df8289c41b2e24c3fa24b6d84 (diff)
downloadcandle-01545f73038cb8c90426214ddf4bcedd59e291e8.tar.gz
candle-01545f73038cb8c90426214ddf4bcedd59e291e8.tar.bz2
candle-01545f73038cb8c90426214ddf4bcedd59e291e8.zip
Add a slice_set op. (#2193)
* Add a slice_set op. * Add some testing. * Add the dedicated kv-cache module. * Derive debug and clone. * Expose more kv-cache functions. * Return the current data when appending. * Use the new cache in the quantized phi3 model.
Diffstat (limited to 'candle-nn')
-rw-r--r--candle-nn/src/kv_cache.rs101
-rw-r--r--candle-nn/src/lib.rs1
2 files changed, 102 insertions, 0 deletions
diff --git a/candle-nn/src/kv_cache.rs b/candle-nn/src/kv_cache.rs
new file mode 100644
index 00000000..684053dc
--- /dev/null
+++ b/candle-nn/src/kv_cache.rs
@@ -0,0 +1,101 @@
+use candle::{DType, Device, Result, Shape, Tensor};
+
+#[derive(Debug, Clone)]
+pub struct Cache {
+ all_data: Tensor,
+ dim: usize,
+ current_seq_len: usize,
+ max_seq_len: usize,
+}
+
+impl Cache {
+ pub fn new<S: Into<Shape>, D: candle::shape::Dim>(
+ dim: D,
+ shape: S,
+ dtype: DType,
+ dev: &Device,
+ ) -> Result<Self> {
+ let shape = shape.into();
+ let dim = dim.to_index(&shape, "kv-cache")?;
+ let max_seq_len = shape.dims()[dim];
+ let all_data = Tensor::zeros(shape, dtype, dev)?;
+ Ok(Self {
+ all_data,
+ dim,
+ current_seq_len: 0,
+ max_seq_len,
+ })
+ }
+
+ pub fn dim(&self) -> usize {
+ self.dim
+ }
+
+ pub fn current_seq_len(&self) -> usize {
+ self.current_seq_len
+ }
+
+ pub fn max_seq_len(&self) -> usize {
+ self.max_seq_len
+ }
+
+ pub fn all_data(&self) -> &Tensor {
+ &self.all_data
+ }
+
+ pub fn current_data(&self) -> Result<Tensor> {
+ self.all_data.narrow(self.dim, 0, self.current_seq_len)
+ }
+
+ pub fn append(&mut self, src: &Tensor) -> Result<()> {
+ let seq_len = src.dim(self.dim)?;
+ if self.current_seq_len + seq_len > self.max_seq_len {
+ candle::bail!(
+ "kv-cache: above max-seq-len {}+{seq_len}>{}",
+ self.current_seq_len,
+ self.max_seq_len
+ )
+ }
+ self.all_data
+ .slice_set(src, self.dim, self.current_seq_len)?;
+ self.current_seq_len += seq_len;
+ Ok(())
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct KvCache {
+ k: Cache,
+ v: Cache,
+}
+
+impl KvCache {
+ pub fn new<S: Into<Shape>, D: candle::shape::Dim>(
+ dim: D,
+ shape: S,
+ dtype: DType,
+ dev: &Device,
+ ) -> Result<Self> {
+ let shape = shape.into();
+ let dim = dim.to_index(&shape, "kv-cache")?;
+ let k = Cache::new(dim, &shape, dtype, dev)?;
+ let v = Cache::new(dim, &shape, dtype, dev)?;
+ Ok(Self { k, v })
+ }
+
+ pub fn k(&self) -> Result<Tensor> {
+ self.k.current_data()
+ }
+
+ pub fn v(&self) -> Result<Tensor> {
+ self.v.current_data()
+ }
+
+ pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
+ self.k.append(k)?;
+ self.v.append(v)?;
+ let k = self.k.current_data()?;
+ let v = self.v.current_data()?;
+ Ok((k, v))
+ }
+}
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index 5c0fbb37..fcac5830 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -6,6 +6,7 @@ pub mod encoding;
pub mod func;
pub mod group_norm;
pub mod init;
+pub mod kv_cache;
pub mod layer_norm;
pub mod linear;
pub mod loss;