diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-22 11:39:27 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-22 10:39:27 +0100 |
commit | 43c72232927ca80c850a73ce977c2063d5a2dcf5 (patch) | |
tree | c93c07984e06b1925313f4f641a8b1a3956fc0ed /candle-examples/examples/falcon/model.rs | |
parent | 52c5d8c087f6a2ee91b807467860eb3e96bb6267 (diff) | |
download | candle-43c72232927ca80c850a73ce977c2063d5a2dcf5.tar.gz candle-43c72232927ca80c850a73ce977c2063d5a2dcf5.tar.bz2 candle-43c72232927ca80c850a73ce977c2063d5a2dcf5.zip |
Rename the .r functions to .dims so as to be a bit more explicit. (#220)
Diffstat (limited to 'candle-examples/examples/falcon/model.rs')
-rw-r--r-- | candle-examples/examples/falcon/model.rs | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/candle-examples/examples/falcon/model.rs b/candle-examples/examples/falcon/model.rs index 60821add..bce93c81 100644 --- a/candle-examples/examples/falcon/model.rs +++ b/candle-examples/examples/falcon/model.rs @@ -182,7 +182,7 @@ impl FalconRotaryEmbedding { key: &Tensor, past_kv_len: usize, ) -> Result<(Tensor, Tensor)> { - let (_batch, seq_len, _head_dim) = query.shape().r3()?; + let (_batch, seq_len, _head_dim) = query.dims3()?; let (cos, sin) = self.cos_sin(MAX_SEQ_LEN, query.device(), query.dtype())?; let cos = cos.narrow(0, past_kv_len, seq_len)?; let sin = sin.narrow(0, past_kv_len, seq_len)?; @@ -245,7 +245,7 @@ impl FalconAttention { } fn split_heads(&self, fused_qkv: &Tensor) -> Result<(Tensor, Tensor, Tensor)> { - let (b_sz, seq_len, _) = fused_qkv.shape().r3()?; + let (b_sz, seq_len, _) = fused_qkv.dims3()?; if !self.multi_query { let fused_qkv = fused_qkv.reshape((b_sz, seq_len, self.num_heads, 3, self.head_dim))?; let q = fused_qkv.narrow(D::Minus2, 0, 1)?.squeeze(D::Minus2)?; @@ -267,7 +267,7 @@ impl FalconAttention { let fused_qkv = self.query_key_value.forward(x)?; let head_dim = self.head_dim; let (query, key, value) = self.split_heads(&fused_qkv)?; - let (b_sz, seq_len, _, _) = query.shape().r4()?; + let (b_sz, seq_len, _, _) = query.dims4()?; let query = query .transpose(1, 2)? .reshape((b_sz * self.num_heads, seq_len, head_dim))?; @@ -465,7 +465,7 @@ impl Falcon { } pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> { - let (b_sz, seq_len) = input_ids.shape().r2()?; + let (b_sz, seq_len) = input_ids.dims2()?; let mut hidden_state = self.word_embeddings.forward(input_ids)?; let past_kv_len = match &self.blocks[0].self_attention.kv_cache { Some((k, _)) => k.dim(1)?, |