summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-05-23 16:21:47 +0200
committerGitHub <noreply@github.com>2024-05-23 16:21:47 +0200
commit31cf64147b9ab4a3d68849bef0ea59bdb0c113d6 (patch)
tree7fb29f43e2fa5527e761e287a930d6cd5e40198b
parent77ea479a1847d909ca5e4f27a36f5c8e302cd529 (diff)
downloadcandle-31cf64147b9ab4a3d68849bef0ea59bdb0c113d6.tar.gz
candle-31cf64147b9ab4a3d68849bef0ea59bdb0c113d6.tar.bz2
candle-31cf64147b9ab4a3d68849bef0ea59bdb0c113d6.zip
Add a couple kv-cache helper functions. (#2206)
-rw-r--r--candle-nn/src/kv_cache.rs29
1 files changed, 29 insertions, 0 deletions
diff --git a/candle-nn/src/kv_cache.rs b/candle-nn/src/kv_cache.rs
index 684053dc..10e9fe5a 100644
--- a/candle-nn/src/kv_cache.rs
+++ b/candle-nn/src/kv_cache.rs
@@ -47,6 +47,10 @@ impl Cache {
self.all_data.narrow(self.dim, 0, self.current_seq_len)
}
+ pub fn reset(&mut self) {
+ self.current_seq_len = 0
+ }
+
pub fn append(&mut self, src: &Tensor) -> Result<()> {
let seq_len = src.dim(self.dim)?;
if self.current_seq_len + seq_len > self.max_seq_len {
@@ -83,6 +87,22 @@ impl KvCache {
Ok(Self { k, v })
}
+ pub fn k_cache(&self) -> &Cache {
+ &self.k
+ }
+
+ pub fn v_cache(&self) -> &Cache {
+ &self.v
+ }
+
+ pub fn k_cache_mut(&mut self) -> &mut Cache {
+ &mut self.k
+ }
+
+ pub fn v_cache_mut(&mut self) -> &mut Cache {
+ &mut self.v
+ }
+
pub fn k(&self) -> Result<Tensor> {
self.k.current_data()
}
@@ -98,4 +118,13 @@ impl KvCache {
let v = self.v.current_data()?;
Ok((k, v))
}
+
+ pub fn current_seq_len(&self) -> usize {
+ self.k.current_seq_len()
+ }
+
+ pub fn reset(&mut self) {
+ self.k.reset();
+ self.v.reset();
+ }
}