author    Laurent Mazare <laurent.mazare@gmail.com>    2023-07-22 11:39:27 +0200
committer GitHub <noreply@github.com>                  2023-07-22 10:39:27 +0100
commit    43c72232927ca80c850a73ce977c2063d5a2dcf5 (patch)
tree      c93c07984e06b1925313f4f641a8b1a3956fc0ed /candle-examples/examples/falcon/model.rs
parent    52c5d8c087f6a2ee91b807467860eb3e96bb6267 (diff)
Rename the .r functions to .dims so as to be a bit more explicit. (#220)
Diffstat (limited to 'candle-examples/examples/falcon/model.rs')
-rw-r--r--  candle-examples/examples/falcon/model.rs  |  8
1 file changed, 4 insertions(+), 4 deletions(-)
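The rename only affects the rank-specific shape accessors, which move from the Shape helper (.r2/.r3/.r4) onto Tensor as .dims2/.dims3/.dims4. Below is a minimal, self-contained sketch of the new calls, not taken from the Falcon example itself; it assumes the core crate is importable as candle_core (older revisions of this repository may expose it under a different crate name).

// Minimal sketch of the renamed accessors; adjust the crate name to the local revision.
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;

    // A (batch, seq_len, head_dim) tensor, as used in FalconRotaryEmbedding::apply.
    let query = Tensor::zeros((2usize, 5, 64), DType::F32, &device)?;
    // Before this commit: let (b, s, h) = query.shape().r3()?;
    let (batch, seq_len, head_dim) = query.dims3()?;
    assert_eq!((batch, seq_len, head_dim), (2, 5, 64));

    // The 2D variant used for token ids in Falcon::forward.
    let input_ids = Tensor::zeros((2usize, 5), DType::U32, &device)?;
    let (b_sz, s_len) = input_ids.dims2()?;
    assert_eq!((b_sz, s_len), (2, 5));

    Ok(())
}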
diff --git a/candle-examples/examples/falcon/model.rs b/candle-examples/examples/falcon/model.rs
index 60821add..bce93c81 100644
--- a/candle-examples/examples/falcon/model.rs
+++ b/candle-examples/examples/falcon/model.rs
@@ -182,7 +182,7 @@ impl FalconRotaryEmbedding {
         key: &Tensor,
         past_kv_len: usize,
     ) -> Result<(Tensor, Tensor)> {
-        let (_batch, seq_len, _head_dim) = query.shape().r3()?;
+        let (_batch, seq_len, _head_dim) = query.dims3()?;
         let (cos, sin) = self.cos_sin(MAX_SEQ_LEN, query.device(), query.dtype())?;
         let cos = cos.narrow(0, past_kv_len, seq_len)?;
         let sin = sin.narrow(0, past_kv_len, seq_len)?;
@@ -245,7 +245,7 @@ impl FalconAttention {
     }

     fn split_heads(&self, fused_qkv: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
-        let (b_sz, seq_len, _) = fused_qkv.shape().r3()?;
+        let (b_sz, seq_len, _) = fused_qkv.dims3()?;
         if !self.multi_query {
             let fused_qkv = fused_qkv.reshape((b_sz, seq_len, self.num_heads, 3, self.head_dim))?;
             let q = fused_qkv.narrow(D::Minus2, 0, 1)?.squeeze(D::Minus2)?;
@@ -267,7 +267,7 @@ impl FalconAttention {
         let fused_qkv = self.query_key_value.forward(x)?;
         let head_dim = self.head_dim;
         let (query, key, value) = self.split_heads(&fused_qkv)?;
-        let (b_sz, seq_len, _, _) = query.shape().r4()?;
+        let (b_sz, seq_len, _, _) = query.dims4()?;
         let query = query
             .transpose(1, 2)?
             .reshape((b_sz * self.num_heads, seq_len, head_dim))?;
@@ -465,7 +465,7 @@ impl Falcon {
     }

     pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
-        let (b_sz, seq_len) = input_ids.shape().r2()?;
+        let (b_sz, seq_len) = input_ids.dims2()?;
         let mut hidden_state = self.word_embeddings.forward(input_ids)?;
         let past_kv_len = match &self.blocks[0].self_attention.kv_cache {
             Some((k, _)) => k.dim(1)?,