summary | refs | log | tree | commit | diff
path: root/candle-examples/examples/llama2-c/model.rs
diff options
context:
space:
mode:
author Laurent Mazare <laurent.mazare@gmail.com> 2023-07-24 10:56:08 +0100
committer GitHub <noreply@github.com> 2023-07-24 10:56:08 +0100
commit 550a13a5472fd3aa3975c2453eff4bff6ac1d0bd (patch)
tree 519dae5278086ebaf08ca98fa359047c5b13313b /candle-examples/examples/llama2-c/model.rs
parent 35b65fed8847646bf3f759711d0028b9befa8970 (diff)
download candle-550a13a5472fd3aa3975c2453eff4bff6ac1d0bd.tar.gz
candle-550a13a5472fd3aa3975c2453eff4bff6ac1d0bd.tar.bz2
candle-550a13a5472fd3aa3975c2453eff4bff6ac1d0bd.zip
Use the binary decoder for llama2.c. (#230)
* Use the binary decoder for llama2.c.
* Add the temperature.
* Formatting tweak.
* Fix the rotary embeddings.
Diffstat (limited to 'candle-examples/examples/llama2-c/model.rs')
-rw-r--r-- candle-examples/examples/llama2-c/model.rs 19
1 files changed, 11 insertions, 8 deletions
diff --git a/candle-examples/examples/llama2-c/model.rs b/candle-examples/examples/llama2-c/model.rs
index 2fb4b444..13f939db 100644
--- a/candle-examples/examples/llama2-c/model.rs
+++ b/candle-examples/examples/llama2-c/model.rs
@@ -30,12 +30,14 @@ impl Cache {
pub fn new(use_kv_cache: bool, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let freq_cis_real = vb.get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_real")?;
let freq_cis_imag = vb.get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_imag")?;
+ let cos = freq_cis_real.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
+ let sin = freq_cis_imag.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
Ok(Self {
masks: Arc::new(Mutex::new(HashMap::new())),
use_kv_cache,
kvs: Arc::new(Mutex::new(vec![None; cfg.n_layers])),
- cos: freq_cis_real,
- sin: freq_cis_imag,
+ cos,
+ sin,
device: vb.device().clone(),
})
}
@@ -110,16 +112,17 @@ struct CausalSelfAttention {
impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
- let (b_sz, _, seq_len, n_embd) = x.dims4()?;
+ let (b_sz, seq_len, h, n_embd) = x.dims4()?;
let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
- let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2))?;
- let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2))?;
- let x0 = x.narrow(D::Minus1, 0, n_embd / 2)?;
- let x1 = x.narrow(D::Minus1, n_embd / 2, n_embd / 2)?;
+ let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
+ let sin = sin.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
+ let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
+ let x0 = x.narrow(D::Minus1, 0, 1)?;
+ let x1 = x.narrow(D::Minus1, 1, 1)?;
let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
- let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?;
+ let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?;
Ok(rope)
}