summaryrefslogtreecommitdiff
path: root/candle-wasm-examples/llama2-c
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-28 13:13:01 +0100
committerGitHub <noreply@github.com>2023-07-28 13:13:01 +0100
commit3eb2bc6d07f192a5ce73ab6964745275f2c15213 (patch)
treee5a682d0e40f3c258f668652082ff7fa45918e32 /candle-wasm-examples/llama2-c
parent68eab38de6e5cabf17159a5dcf45ec703fbea441 (diff)
downloadcandle-3eb2bc6d07f192a5ce73ab6964745275f2c15213.tar.gz
candle-3eb2bc6d07f192a5ce73ab6964745275f2c15213.tar.bz2
candle-3eb2bc6d07f192a5ce73ab6964745275f2c15213.zip
Softmax numerical stability. (#267)
* Softmax numerical stability. * Fix the flash-attn test.
Diffstat (limited to 'candle-wasm-examples/llama2-c')
-rw-r--r--candle-wasm-examples/llama2-c/src/model.rs2
-rw-r--r--candle-wasm-examples/llama2-c/src/worker.rs4
2 files changed, 3 insertions, 3 deletions
diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs
index 8b0b3c3e..d95672b9 100644
--- a/candle-wasm-examples/llama2-c/src/model.rs
+++ b/candle-wasm-examples/llama2-c/src/model.rs
@@ -158,7 +158,7 @@ impl CausalSelfAttention {
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
- let att = att.softmax(D::Minus1)?;
+ let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let y = att.matmul(&v.contiguous()?)?;
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs
index d64da8c6..79f7c1fd 100644
--- a/candle-wasm-examples/llama2-c/src/worker.rs
+++ b/candle-wasm-examples/llama2-c/src/worker.rs
@@ -1,7 +1,7 @@
use crate::model::{Cache, Config, Llama};
use byteorder::{LittleEndian, ReadBytesExt};
use candle::{DType, Device, IndexOp, Result, Shape, Tensor, D};
-use candle_nn::VarBuilder;
+use candle_nn::{ops::softmax, VarBuilder};
use rand::{distributions::Distribution, SeedableRng};
use serde::{Deserialize, Serialize};
use wasm_bindgen::prelude::*;
@@ -88,7 +88,7 @@ impl LogitsProcessor {
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
let logits = logits.to_dtype(DType::F32)?;
let next_token = if let Some(temperature) = self.temperature {
- let prs = (&logits / temperature)?.softmax(D::Minus1)?;
+ let prs = softmax(&(&logits / temperature)?, D::Minus1)?;
let prs: Vec<f32> = prs.to_vec1()?;
let distr =
rand::distributions::WeightedIndex::new(prs).map_err(candle::Error::wrap)?;