summaryrefslogtreecommitdiff
path: root/candle-examples/examples/llama
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/llama')
-rw-r--r--candle-examples/examples/llama/model.rs12
1 files changed, 6 insertions, 6 deletions
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index f3e30ec9..b074e5cb 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -116,11 +116,11 @@ impl RmsNorm {
let in_dtype = x.dtype();
// This is a no-op if x's dtype is already f32.
let x = x.to_dtype(DType::F32)?;
- let (b_sz, seq_len, hidden_size) = x.shape().r3()?;
+ let (b_sz, seq_len, hidden_size) = x.dims3()?;
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
let x_normed = (x / (norm_x + 1e-6)?.sqrt()?)?;
- let size = self.scale.shape().r1()?;
+ let size = self.scale.dims1()?;
let scale = self
.scale
.to_dtype(DType::F32)?
@@ -144,7 +144,7 @@ struct CausalSelfAttention {
impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
- let (b_sz, _, seq_len, n_embd) = x.shape().r4()?;
+ let (b_sz, _, seq_len, n_embd) = x.dims4()?;
let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?;
@@ -158,7 +158,7 @@ impl CausalSelfAttention {
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
let x_dtype = x.dtype();
- let (b_sz, seq_len, n_embd) = x.shape().r3()?;
+ let (b_sz, seq_len, n_embd) = x.dims3()?;
let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?;
let v = self.v_proj.forward(x)?;
@@ -219,7 +219,7 @@ impl CausalSelfAttention {
if n_rep == 1 {
Ok(x)
} else {
- let (b_sz, n_kv_head, seq_len, head_dim) = x.shape().r4()?;
+ let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
let x = x
.unsqueeze(2)?
.expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
@@ -345,7 +345,7 @@ impl Llama {
}
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
- let (_b_sz, seq_len) = x.shape().r2()?;
+ let (_b_sz, seq_len) = x.dims2()?;
let mut x = self.wte.forward(x)?;
for (block_idx, block) in self.blocks.iter().enumerate() {
x = block.forward(&x, index_pos, block_idx)?;