diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-24 10:56:08 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-24 10:56:08 +0100 |
commit | 550a13a5472fd3aa3975c2453eff4bff6ac1d0bd (patch) | |
tree | 519dae5278086ebaf08ca98fa359047c5b13313b /candle-examples/examples/llama2-c/model.rs | |
parent | 35b65fed8847646bf3f759711d0028b9befa8970 (diff) | |
download | candle-550a13a5472fd3aa3975c2453eff4bff6ac1d0bd.tar.gz candle-550a13a5472fd3aa3975c2453eff4bff6ac1d0bd.tar.bz2 candle-550a13a5472fd3aa3975c2453eff4bff6ac1d0bd.zip |
Use the binary decoder for llama2.c. (#230)
* Use the binary decoder for llama2.c.
* Add the temperature.
* Formatting tweak.
* Fix the rotary embeddings.
Diffstat (limited to 'candle-examples/examples/llama2-c/model.rs')
-rw-r--r-- | candle-examples/examples/llama2-c/model.rs | 19 |
1 file changed, 11 insertions, 8 deletions
```diff
diff --git a/candle-examples/examples/llama2-c/model.rs b/candle-examples/examples/llama2-c/model.rs
index 2fb4b444..13f939db 100644
--- a/candle-examples/examples/llama2-c/model.rs
+++ b/candle-examples/examples/llama2-c/model.rs
@@ -30,12 +30,14 @@ impl Cache {
     pub fn new(use_kv_cache: bool, cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let freq_cis_real = vb.get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_real")?;
         let freq_cis_imag = vb.get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_imag")?;
+        let cos = freq_cis_real.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
+        let sin = freq_cis_imag.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
         Ok(Self {
             masks: Arc::new(Mutex::new(HashMap::new())),
             use_kv_cache,
             kvs: Arc::new(Mutex::new(vec![None; cfg.n_layers])),
-            cos: freq_cis_real,
-            sin: freq_cis_imag,
+            cos,
+            sin,
             device: vb.device().clone(),
         })
     }
@@ -110,16 +112,17 @@ struct CausalSelfAttention {
 
 impl CausalSelfAttention {
     fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
-        let (b_sz, _, seq_len, n_embd) = x.dims4()?;
+        let (b_sz, seq_len, h, n_embd) = x.dims4()?;
         let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
         let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
-        let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2))?;
-        let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2))?;
-        let x0 = x.narrow(D::Minus1, 0, n_embd / 2)?;
-        let x1 = x.narrow(D::Minus1, n_embd / 2, n_embd / 2)?;
+        let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
+        let sin = sin.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
+        let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
+        let x0 = x.narrow(D::Minus1, 0, 1)?;
+        let x1 = x.narrow(D::Minus1, 1, 1)?;
         let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
         let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
-        let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?;
+        let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?;
         Ok(rope)
     }
 
```