author     Laurent Mazare <laurent.mazare@gmail.com>    2023-07-22 11:39:27 +0200
committer  GitHub <noreply@github.com>                  2023-07-22 10:39:27 +0100
commit     43c72232927ca80c850a73ce977c2063d5a2dcf5
tree       c93c07984e06b1925313f4f641a8b1a3956fc0ed /candle-examples/examples/llama
parent     52c5d8c087f6a2ee91b807467860eb3e96bb6267
Rename the .r functions to .dims so as to be a bit more explicit. (#220)
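To make the rename concrete: the `dims1()`..`dims4()` helpers destructure a tensor's shape into a fixed-arity tuple and return an error when the rank does not match, where the old spelling went through `shape().r1()`..`r4()`. Below is a minimal sketch, assuming the `candle_core` crate name used by later releases (the core crate may have been named plainly `candle` at the time of this commit):

```rust
// Minimal sketch of the renamed shape accessors. Assumption: the crate is
// referenced as candle_core; adjust the name for the crate layout at this commit.
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    // A rank-3 tensor shaped (batch, sequence length, hidden size).
    let x = Tensor::zeros((2, 4, 8), DType::F32, &Device::Cpu)?;

    // New spelling: dims3() replaces shape().r3() and yields a 3-tuple.
    let (b_sz, seq_len, hidden_size) = x.dims3()?;
    assert_eq!((b_sz, seq_len, hidden_size), (2, 4, 8));

    // Asking for the wrong rank is a recoverable error rather than a panic.
    assert!(x.dims2().is_err());
    Ok(())
}
```

The call sites in the diff below follow the same pattern, just on real model tensors.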
Diffstat (limited to 'candle-examples/examples/llama')
-rw-r--r--  candle-examples/examples/llama/model.rs | 12
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index f3e30ec9..b074e5cb 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -116,11 +116,11 @@ impl RmsNorm {
         let in_dtype = x.dtype();
         // This is a no-op if x's dtype is already f32.
         let x = x.to_dtype(DType::F32)?;
-        let (b_sz, seq_len, hidden_size) = x.shape().r3()?;
+        let (b_sz, seq_len, hidden_size) = x.dims3()?;
         let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
         let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
         let x_normed = (x / (norm_x + 1e-6)?.sqrt()?)?;
-        let size = self.scale.shape().r1()?;
+        let size = self.scale.dims1()?;
         let scale = self
             .scale
             .to_dtype(DType::F32)?
@@ -144,7 +144,7 @@ struct CausalSelfAttention {
 
 impl CausalSelfAttention {
     fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
-        let (b_sz, _, seq_len, n_embd) = x.shape().r4()?;
+        let (b_sz, _, seq_len, n_embd) = x.dims4()?;
         let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
         let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
         let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?;
@@ -158,7 +158,7 @@ impl CausalSelfAttention {
 
     fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
         let x_dtype = x.dtype();
-        let (b_sz, seq_len, n_embd) = x.shape().r3()?;
+        let (b_sz, seq_len, n_embd) = x.dims3()?;
         let q = self.q_proj.forward(x)?;
         let k = self.k_proj.forward(x)?;
         let v = self.v_proj.forward(x)?;
@@ -219,7 +219,7 @@ impl CausalSelfAttention {
         if n_rep == 1 {
             Ok(x)
         } else {
-            let (b_sz, n_kv_head, seq_len, head_dim) = x.shape().r4()?;
+            let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
             let x = x
                 .unsqueeze(2)?
                 .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
@@ -345,7 +345,7 @@ impl Llama {
     }
 
     pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
-        let (_b_sz, seq_len) = x.shape().r2()?;
+        let (_b_sz, seq_len) = x.dims2()?;
         let mut x = self.wte.forward(x)?;
         for (block_idx, block) in self.blocks.iter().enumerate() {
             x = block.forward(&x, index_pos, block_idx)?;
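The hunk around line 219 sits in the grouped-query attention helper that repeats each key/value head `n_rep` times. As a rough, self-contained sketch of that pattern with the renamed accessor — a hypothetical rewrite, not the file's exact code: the trailing `reshape` is an assumption since the hunk context ends before it, and the crate name is again assumed to be `candle_core`:

```rust
// Hypothetical sketch of the repeat-KV expansion touched by the hunk at line 219.
// Assumptions: candle_core crate name; the final reshape is inferred, as the
// hunk context is truncated before it.
use candle_core::{Result, Tensor};

fn repeat_kv(x: Tensor, n_rep: usize) -> Result<Tensor> {
    if n_rep == 1 {
        return Ok(x);
    }
    // (batch, kv heads, sequence length, head dim), via the renamed accessor.
    let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
    x.unsqueeze(2)?                                           // (b, kv, 1, s, d)
        .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))? // (b, kv, rep, s, d)
        .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim)) // (b, kv * rep, s, d)
}
```

Note that `expand` only creates a broadcast view; any copying is deferred to the following `reshape`, so the repetition is materialized in a single pass rather than head by head.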