diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-28 13:13:01 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-28 13:13:01 +0100 |
commit | 3eb2bc6d07f192a5ce73ab6964745275f2c15213 (patch) | |
tree | e5a682d0e40f3c258f668652082ff7fa45918e32 /candle-nn/tests | |
parent | 68eab38de6e5cabf17159a5dcf45ec703fbea441 (diff) | |
download | candle-3eb2bc6d07f192a5ce73ab6964745275f2c15213.tar.gz candle-3eb2bc6d07f192a5ce73ab6964745275f2c15213.tar.bz2 candle-3eb2bc6d07f192a5ce73ab6964745275f2c15213.zip |
Softmax numerical stability. (#267)
* Softmax numerical stability.
* Fix the flash-attn test.
Diffstat (limited to 'candle-nn/tests')
-rw-r--r-- | candle-nn/tests/ops.rs | 62 |
1 file changed, 62 insertions(+), 0 deletions(-)
diff --git a/candle-nn/tests/ops.rs b/candle-nn/tests/ops.rs new file mode 100644 index 00000000..ca82dd1f --- /dev/null +++ b/candle-nn/tests/ops.rs @@ -0,0 +1,62 @@ +use candle::{Device, Result, Tensor}; + +pub fn to_vec3_round(t: Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> { + let b = 10f32.powi(digits); + let t = t.to_vec3::<f32>()?; + let t = t + .iter() + .map(|t| { + t.iter() + .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect()) + .collect() + }) + .collect(); + Ok(t) +} + +#[test] +fn softmax() -> Result<()> { + let device = &Device::Cpu; + let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]]; + let tensor = Tensor::new(data, device)?; + let t0 = candle_nn::ops::softmax(&tensor.log()?, 0)?; + let t1 = candle_nn::ops::softmax(&tensor.log()?, 1)?; + let t2 = candle_nn::ops::softmax(&tensor.log()?, 2)?; + assert_eq!( + to_vec3_round(t0, 4)?, + &[ + // 3/5, 1/2, 4/11 + [[0.6, 0.5, 0.3636], [0.1111, 0.7143, 0.5294]], + // 2/5, 1/2, 7/11 + [[0.4, 0.5, 0.6364], [0.8889, 0.2857, 0.4706]] + ] + ); + assert_eq!( + to_vec3_round(t1, 4)?, + &[ + // 3/4, 1/6, 4/13 + [[0.75, 0.1667, 0.3077], [0.25, 0.8333, 0.6923]], + // 2/10, 1/3, 7/15 + [[0.2, 0.3333, 0.4667], [0.8, 0.6667, 0.5333]] + ] + ); + assert_eq!( + to_vec3_round(t2, 4)?, + &[ + // (3, 1, 4) / 8, (1, 5, 9) / 15 + [[0.375, 0.125, 0.5], [0.0667, 0.3333, 0.6]], + // (2, 1, 7) / 10, (8, 2, 8) / 18 + [[0.2, 0.1, 0.7], [0.4444, 0.1111, 0.4444]] + ] + ); + Ok(()) +} + +#[test] +fn softmax_numerical_stability() -> Result<()> { + let dev = &Device::Cpu; + let xs = Tensor::new(&[1234f32, 0.], dev)?; + let softmax = candle_nn::ops::softmax(&xs, 0)?; + assert_eq!(softmax.to_vec1::<f32>()?, &[1f32, 0.]); + Ok(()) +} |