diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-09-23 13:14:32 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-09-23 13:14:32 +0200 |
commit | d01207dbf3fb0ad614e7915c8f5706fbc09902fb (patch) | |
tree | bd8303c3d7acd097777485f220acd14aac73bd83 /candle-nn/tests | |
parent | 8097559c1a293d26cf9cd65d92d9cb4696197c2e (diff) | |
download | candle-d01207dbf3fb0ad614e7915c8f5706fbc09902fb.tar.gz candle-d01207dbf3fb0ad614e7915c8f5706fbc09902fb.tar.bz2 candle-d01207dbf3fb0ad614e7915c8f5706fbc09902fb.zip |
Add a RotatingKVCache. (#2493)
* Add a RotatingKVCache.
* Add some KvCache tests.
* Test the reset too.
* More kv-cache testing.
* More tests for the rotating kv-cache.
* Improve the api for the rotating cache so that the whole src tensor gets returned when it's overlarge.
* Handle contiguity + bugfix + use in mimi.
* Add a way to test the mimi streaming mode.
* Mimi streaming fixes.
* More rotating kv-cache.
* Fix the attn mask generation.
* Handle the abs case.
* Add some tests for the generated mask.
Diffstat (limited to 'candle-nn/tests')
-rw-r--r-- | candle-nn/tests/kv_cache.rs | 110 |
1 files changed, 110 insertions, 0 deletions
diff --git a/candle-nn/tests/kv_cache.rs b/candle-nn/tests/kv_cache.rs new file mode 100644 index 00000000..b8d2ec48 --- /dev/null +++ b/candle-nn/tests/kv_cache.rs @@ -0,0 +1,110 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use candle::{Device, Result, Tensor}; + +#[test] +fn kv_cache() -> Result<()> { + let mut cache = candle_nn::kv_cache::Cache::new(0, 16); + for _ in [0, 1] { + assert_eq!(cache.current_seq_len(), 0); + let data = cache.current_data()?; + assert!(data.is_none()); + let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?; + cache.append(&t)?; + let data = cache.current_data()?.unwrap(); + assert_eq!(data.to_vec1::<f32>()?, [1., 2., 3.]); + let t = Tensor::new(&[4f32], &Device::Cpu)?; + cache.append(&t)?; + let data = cache.current_data()?.unwrap(); + assert_eq!(data.to_vec1::<f32>()?, [1., 2., 3., 4.]); + let t = Tensor::new(&[0f32, 5., 6., 7.], &Device::Cpu)?; + cache.append(&t)?; + let data = cache.current_data()?.unwrap(); + assert_eq!(data.to_vec1::<f32>()?, [1., 2., 3., 4., 0., 5., 6., 7.]); + assert_eq!(cache.current_seq_len(), 8); + cache.reset(); + } + Ok(()) +} + +#[test] +fn rotating_kv_cache() -> Result<()> { + let mut cache = candle_nn::kv_cache::RotatingCache::new(0, 6); + for _ in [0, 1] { + assert_eq!(cache.offset(), 0); + assert_eq!(cache.current_seq_len(), 0); + let data = cache.current_data()?; + assert!(data.is_none()); + let t = Tensor::new(&[1., 2., 3.], &Device::Cpu)?; + let data = cache.append(&t)?; + assert_eq!(data.to_vec1::<f64>()?, [1., 2., 3.]); + let t = Tensor::new(&[4.], &Device::Cpu)?; + let data = cache.append(&t)?; + assert_eq!(data.to_vec1::<f64>()?, [1., 2., 3., 4.]); + let t = Tensor::new(&[0., 5., 6., 7.], &Device::Cpu)?; + let data = cache.append(&t)?; + assert_eq!(data.to_vec1::<f64>()?, [6., 7., 3., 4., 0., 5.]); + assert_eq!(cache.current_seq_len(), 8); + assert_eq!(cache.offset(), 2); + + let t = Tensor::new(&[8.], &Device::Cpu)?; + let data = cache.append(&t)?; + assert_eq!(data.to_vec1::<f64>()?, [6., 7., 8., 4., 0., 5.]); + assert_eq!(cache.current_seq_len(), 9); + assert_eq!(cache.offset(), 3); + + let t = Tensor::new(&[9., 10., 11.], &Device::Cpu)?; + let data = cache.append(&t)?; + assert_eq!(data.to_vec1::<f64>()?, [6., 7., 8., 9., 10., 11.]); + assert_eq!(cache.current_seq_len(), 12); + assert_eq!(cache.offset(), 0); + + let t = Tensor::new(&[12.], &Device::Cpu)?; + let data = cache.append(&t)?; + assert_eq!(data.to_vec1::<f64>()?, [12., 7., 8., 9., 10., 11.]); + assert_eq!(cache.current_seq_len(), 13); + assert_eq!(cache.offset(), 1); + + let mask = cache.attn_mask(2, &Device::Cpu)?.unwrap(); + assert_eq!( + mask.to_vec2::<u8>()?, + &[[0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 0, 0]] + ); + let mask = cache.attn_mask(3, &Device::Cpu)?.unwrap(); + assert_eq!( + mask.to_vec2::<u8>()?, + &[[0, 0, 1, 1, 0, 0], [0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 0]], + ); + let t = Tensor::new(&[0., 1., 2., 3., 4., 5., 6., 7., 8.], &Device::Cpu)?; + let data = cache.append(&t)?; + assert_eq!(data.to_vec1::<f64>()?, [0., 1., 2., 3., 4., 5., 6., 7., 8.]); + assert_eq!(cache.current_seq_len(), 22); + assert_eq!(cache.offset(), 0); + + let mask = cache.attn_mask(1, &Device::Cpu)?; + assert!(mask.is_none()); + let mask = cache.attn_mask(2, &Device::Cpu)?.unwrap(); + assert_eq!( + mask.to_vec2::<u8>()?, + &[[0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]] + ); + let mask = cache.attn_mask(3, &Device::Cpu)?.unwrap(); + assert_eq!( + mask.to_vec2::<u8>()?, + &[[0, 1, 1, 0, 0, 0], [0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 0, 0]] + ); + let t = Tensor::new(&[42.], &Device::Cpu)?; + + let data = cache.append(&t)?; + assert_eq!(data.to_vec1::<f64>()?, [42., 4., 5., 6., 7., 8.]); + assert_eq!(cache.current_seq_len(), 23); + assert_eq!(cache.offset(), 1); + + cache.reset(); + } + Ok(()) +} |