summary | refs | log | tree | commit | diff
path: root/candle-examples/examples/falcon
diff options
context:
space:
mode:
author Laurent Mazare <laurent.mazare@gmail.com> 2023-07-07 12:56:44 +0100
committer GitHub <noreply@github.com> 2023-07-07 12:56:44 +0100
commit 05ff1cff669b48deebc0b7f9cf056589c7a448e9 (patch)
tree f1d4f7729981cbedc9db80db3d163ff19e273b67 /candle-examples/examples/falcon
parent 65937612d068ce98bbcbdbc3211f6dd8f293d7cf (diff)
download candle-05ff1cff669b48deebc0b7f9cf056589c7a448e9.tar.gz
candle-05ff1cff669b48deebc0b7f9cf056589c7a448e9.tar.bz2
candle-05ff1cff669b48deebc0b7f9cf056589c7a448e9.zip
Add some caching to the causal mask. (#103)
Diffstat (limited to 'candle-examples/examples/falcon')
-rw-r--r-- candle-examples/examples/falcon/model.rs 12
1 files changed, 10 insertions, 2 deletions
diff --git a/candle-examples/examples/falcon/model.rs b/candle-examples/examples/falcon/model.rs
index 82e89841..4aaa58ea 100644
--- a/candle-examples/examples/falcon/model.rs
+++ b/candle-examples/examples/falcon/model.rs
@@ -282,6 +282,7 @@ fn rotate_half(x: &Tensor) -> Result<Tensor> {
#[derive(Debug)]
struct FalconRotaryEmbedding {
inv_freq: Tensor,
+ cache: Option<(usize, Tensor, Tensor)>,
}
impl FalconRotaryEmbedding {
@@ -292,7 +293,8 @@ impl FalconRotaryEmbedding {
.map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32))
.collect();
let inv_freq = Tensor::new(inv_freq.as_slice(), &vb.device)?;
- Ok(Self { inv_freq })
+ let cache = None;
+ Ok(Self { inv_freq, cache })
}
fn cos_sin(
@@ -301,7 +303,12 @@ impl FalconRotaryEmbedding {
device: &Device,
dtype: DType,
) -> Result<(Tensor, Tensor)> {
- // TODO: Add the cache.
+ match &self.cache {
+ Some((s, cos, sin)) if *s == seq_len => {
+ return Ok((cos.clone(), sin.clone()));
+ }
+ _ => {}
+ }
let t: Vec<_> = (0..seq_len).map(|c| c as u32).collect();
let t = Tensor::new(t.as_slice(), device)?.to_dtype(dtype)?;
let inv_freq = self.inv_freq.to_dtype(dtype)?;
@@ -309,6 +316,7 @@ impl FalconRotaryEmbedding {
let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
let cos = emb.cos()?;
let sin = emb.sin()?;
+ self.cache = Some((seq_len, cos.clone(), sin.clone()));
Ok((cos, sin))
}