diff options
Diffstat (limited to 'candle-examples/examples/llama/main.rs')
-rw-r--r-- | candle-examples/examples/llama/main.rs | 17 |
1 files changed, 8 insertions, 9 deletions
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs index 73db15e0..d254eeed 100644 --- a/candle-examples/examples/llama/main.rs +++ b/candle-examples/examples/llama/main.rs @@ -283,19 +283,18 @@ impl CausalSelfAttention { dims.push(v / 2); dims.push(2); let x = x.reshape(dims)?; - let rank = x.rank(); - let re_x = x.narrow(rank - 1, 0, 1)?; - let im_x = x.narrow(rank - 1, 1, 1)?; + let re_x = x.narrow(candle::D::Minus1, 0, 1)?; + let im_x = x.narrow(candle::D::Minus1, 1, 1)?; let re_f = freqs_cis - .narrow(rank - 1, 0, 1)? + .narrow(candle::D::Minus1, 0, 1)? .broadcast_as(re_x.shape())?; let im_f = freqs_cis - .narrow(rank - 1, 1, 1)? + .narrow(candle::D::Minus1, 1, 1)? .broadcast_as(im_x.shape())?; let re = ((&re_x * &re_f)? - (&im_x * &im_f)?)?; let im = ((&re_x * &im_f)? + (&im_x * &re_f)?)?; - let rope = Tensor::cat(&[&re, &im], rank - 1)?; - let rope = rope.flatten(Some(rope.rank() - 2), None)?; + let rope = Tensor::cat(&[&re, &im], re.rank() - 1)?; + let rope = rope.flatten_from(candle::D::Minus2)?; Ok(rope) } @@ -339,7 +338,7 @@ impl CausalSelfAttention { let att = (q.matmul(&k.t()?)? / (*k_shape.dims().last().unwrap() as f64).sqrt())?; let mask = self.cache.mask(t)?.broadcast_as(att.shape())?; let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; - let att = att.softmax(att.rank() - 1)?; + let att = att.softmax(candle::D::Minus1)?; // Convert to contiguous as matmul doesn't support strided vs for now. let y = att.matmul(&v.contiguous()?)?; let y = y.transpose(0, 1)?.reshape(&[t, c])?; @@ -537,7 +536,7 @@ async fn main() -> Result<()> { let next_token = if let Some(temperature) = args.temperature { println!("Sampling with temperature {temperature:?}"); - let prs = (&logits / temperature)?.softmax(logits.rank() - 1)?; + let prs = (&logits / temperature)?.softmax(candle::D::Minus1)?; let logits_v: Vec<f32> = prs.to_vec1()?; let distr = rand::distributions::WeightedIndex::new(&logits_v)?; |