diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-07 12:56:44 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-07 12:56:44 +0100 |
commit | 05ff1cff669b48deebc0b7f9cf056589c7a448e9 (patch) | |
tree | f1d4f7729981cbedc9db80db3d163ff19e273b67 /candle-examples/examples/falcon | |
parent | 65937612d068ce98bbcbdbc3211f6dd8f293d7cf (diff) | |
download | candle-05ff1cff669b48deebc0b7f9cf056589c7a448e9.tar.gz candle-05ff1cff669b48deebc0b7f9cf056589c7a448e9.tar.bz2 candle-05ff1cff669b48deebc0b7f9cf056589c7a448e9.zip |
Add some caching to the causal mask. (#103)
Diffstat (limited to 'candle-examples/examples/falcon')
-rw-r--r-- | candle-examples/examples/falcon/model.rs | 12 |
1 files changed, 10 insertions, 2 deletions
diff --git a/candle-examples/examples/falcon/model.rs b/candle-examples/examples/falcon/model.rs index 82e89841..4aaa58ea 100644 --- a/candle-examples/examples/falcon/model.rs +++ b/candle-examples/examples/falcon/model.rs @@ -282,6 +282,7 @@ fn rotate_half(x: &Tensor) -> Result<Tensor> { #[derive(Debug)] struct FalconRotaryEmbedding { inv_freq: Tensor, + cache: Option<(usize, Tensor, Tensor)>, } impl FalconRotaryEmbedding { @@ -292,7 +293,8 @@ impl FalconRotaryEmbedding { .map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32)) .collect(); let inv_freq = Tensor::new(inv_freq.as_slice(), &vb.device)?; - Ok(Self { inv_freq }) + let cache = None; + Ok(Self { inv_freq, cache }) } fn cos_sin( @@ -301,7 +303,12 @@ impl FalconRotaryEmbedding { device: &Device, dtype: DType, ) -> Result<(Tensor, Tensor)> { - // TODO: Add the cache. + match &self.cache { + Some((s, cos, sin)) if *s == seq_len => { + return Ok((cos.clone(), sin.clone())); + } + _ => {} + } let t: Vec<_> = (0..seq_len).map(|c| c as u32).collect(); let t = Tensor::new(t.as_slice(), device)?.to_dtype(dtype)?; let inv_freq = self.inv_freq.to_dtype(dtype)?; @@ -309,6 +316,7 @@ impl FalconRotaryEmbedding { let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?; let cos = emb.cos()?; let sin = emb.sin()?; + self.cache = Some((seq_len, cos.clone(), sin.clone())); Ok((cos, sin)) } |