diff options
author    : Laurent Mazare <laurent.mazare@gmail.com>  2023-07-28 13:13:01 +0100
committer : GitHub <noreply@github.com>  2023-07-28 13:13:01 +0100
commit    : 3eb2bc6d07f192a5ce73ab6964745275f2c15213 (patch)
tree      : e5a682d0e40f3c258f668652082ff7fa45918e32  /candle-wasm-examples/llama2-c
parent    : 68eab38de6e5cabf17159a5dcf45ec703fbea441 (diff)
download  : candle-3eb2bc6d07f192a5ce73ab6964745275f2c15213.tar.gz
            candle-3eb2bc6d07f192a5ce73ab6964745275f2c15213.tar.bz2
            candle-3eb2bc6d07f192a5ce73ab6964745275f2c15213.zip
Softmax numerical stability. (#267)
* Softmax numerical stability.
* Fix the flash-attn test.
Diffstat (limited to 'candle-wasm-examples/llama2-c')
-rw-r--r--  candle-wasm-examples/llama2-c/src/model.rs   | 2
-rw-r--r--  candle-wasm-examples/llama2-c/src/worker.rs  | 4
2 files changed, 3 insertions, 3 deletions
diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs
index 8b0b3c3e..d95672b9 100644
--- a/candle-wasm-examples/llama2-c/src/model.rs
+++ b/candle-wasm-examples/llama2-c/src/model.rs
@@ -158,7 +158,7 @@ impl CausalSelfAttention {
         let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
         let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
         let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
-        let att = att.softmax(D::Minus1)?;
+        let att = candle_nn::ops::softmax(&att, D::Minus1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = att.matmul(&v.contiguous()?)?;
         let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs
index d64da8c6..79f7c1fd 100644
--- a/candle-wasm-examples/llama2-c/src/worker.rs
+++ b/candle-wasm-examples/llama2-c/src/worker.rs
@@ -1,7 +1,7 @@
 use crate::model::{Cache, Config, Llama};
 use byteorder::{LittleEndian, ReadBytesExt};
 use candle::{DType, Device, IndexOp, Result, Shape, Tensor, D};
-use candle_nn::VarBuilder;
+use candle_nn::{ops::softmax, VarBuilder};
 use rand::{distributions::Distribution, SeedableRng};
 use serde::{Deserialize, Serialize};
 use wasm_bindgen::prelude::*;
@@ -88,7 +88,7 @@ impl LogitsProcessor {
     pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
         let logits = logits.to_dtype(DType::F32)?;
         let next_token = if let Some(temperature) = self.temperature {
-            let prs = (&logits / temperature)?.softmax(D::Minus1)?;
+            let prs = softmax(&(&logits / temperature)?, D::Minus1)?;
             let prs: Vec<f32> = prs.to_vec1()?;
             let distr =
                 rand::distributions::WeightedIndex::new(prs).map_err(candle::Error::wrap)?;