summaryrefslogtreecommitdiff
path: root/candle-flash-attn/tests/flash_attn_tests.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-flash-attn/tests/flash_attn_tests.rs')
-rw-r--r--candle-flash-attn/tests/flash_attn_tests.rs52
1 files changed, 52 insertions, 0 deletions
diff --git a/candle-flash-attn/tests/flash_attn_tests.rs b/candle-flash-attn/tests/flash_attn_tests.rs
index 250added..e3058611 100644
--- a/candle-flash-attn/tests/flash_attn_tests.rs
+++ b/candle-flash-attn/tests/flash_attn_tests.rs
@@ -27,6 +27,20 @@ fn fa_acausal(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result<
Ok(output)
}
+fn fa_acausal_softcap(q: &Tensor, k: &Tensor, v: &Tensor, softcap: f32) -> Result<Tensor> {
+ let in_dtype = q.dtype();
+ let q = q.to_dtype(DType::F32)?;
+ let k = k.to_dtype(DType::F32)?;
+ let v = v.to_dtype(DType::F32)?;
+ // let att = (q.matmul(&k.t()?)? * softmax_scale as f64)?;
+ let att = q.matmul(&k.t()?)?;
+ let att = (softcap as f64 * ((att / softcap as f64)?.tanh())?)?;
+ let att = candle_nn::ops::softmax(&att, D::Minus1)?;
+ // Convert to contiguous as matmul doesn't support strided vs for now.
+ let output = att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?;
+ Ok(output)
+}
+
#[test]
fn flash_attn_acausal() -> Result<()> {
let device = Device::new_cuda(0)?;
@@ -90,6 +104,44 @@ fn flash_attn_acausal() -> Result<()> {
}
#[test]
+fn flash_attn_acausal_softcap() -> Result<()> {
+ let device = Device::new_cuda(0)?;
+ let q = Tensor::arange(0u32, 3 * 5 * 8, &device)?
+ .to_dtype(DType::F16)?
+ .reshape((1, 3, 5, 8))?;
+ let k = (&q / 40.)?;
+ let v = (&q / 50.)?;
+ let q = (&q / 30.)?;
+ let softcap = 5.0f32;
+
+ let ys1 = fa_acausal_softcap(&q, &k, &v, softcap.clone())?;
+ let ys1 = ys1.i(0)?.to_dtype(DType::F32)?;
+ let ys2 = {
+ let q = q.transpose(1, 2)?;
+ let k = k.transpose(1, 2)?;
+ let v = v.transpose(1, 2)?;
+ candle_flash_attn::flash_attn_alibi_windowed_softcap(
+ &q,
+ &k,
+ &v,
+ None, // alibi_slopes //
+ 1.0, // softmax //
+ None, // window_size_left //
+ None, // window_size_right //
+ softcap.clone(), // softcap //
+ )?
+ .transpose(1, 2)?
+ };
+ let ys2 = ys2.i(0)?.to_dtype(DType::F32)?;
+ let diff = ys1.sub(&ys2)?.abs()?.flatten_all()?.max(0)?;
+
+ assert_eq!(ys1.dims(), &[3, 5, 8]);
+ assert_eq!(ys2.dims(), &[3, 5, 8]);
+ assert!(diff.to_vec0::<f32>()?.abs() < 1e-3);
+ Ok(())
+}
+
+#[test]
fn flash_attn_varlen() -> Result<()> {
let device = Device::new_cuda(0)?;
let q = Tensor::arange(0u32, 48, &device)?