summaryrefslogtreecommitdiff
path: root/candle-flash-attn/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-flash-attn/src/lib.rs')
-rw-r--r--candle-flash-attn/src/lib.rs16
1 files changed, 8 insertions, 8 deletions
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index 0bbb451d..b159aee2 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -118,14 +118,14 @@ impl candle::CustomOp3 for FlashHdim32Sm80 {
/* k_batch_stride */ k_stride[0] as u32,
/* v_batch_stride */ v_stride[0] as u32,
/* o_batch_stride */ o_stride[0] as u32,
- /* q_row_stride */ q_stride[q_rank - 3] as u32,
- /* k_row_stride */ k_stride[k_rank - 3] as u32,
- /* v_row_stride */ v_stride[v_rank - 3] as u32,
- /* o_row_stride */ o_stride[o_rank - 3] as u32,
- /* q_head_stride */ q_stride[q_rank - 2] as u32,
- /* k_head_stride */ k_stride[k_rank - 2] as u32,
- /* v_head_stride */ v_stride[v_rank - 2] as u32,
- /* o_head_stride */ o_stride[o_rank - 2] as u32,
+ /* q_row_stride */ q_stride[q_rank - 3] as u32,
+ /* k_row_stride */ k_stride[k_rank - 3] as u32,
+ /* v_row_stride */ v_stride[v_rank - 3] as u32,
+ /* o_row_stride */ o_stride[o_rank - 3] as u32,
+ /* q_head_stride */ q_stride[q_rank - 2] as u32,
+ /* k_head_stride */ k_stride[k_rank - 2] as u32,
+ /* v_head_stride */ v_stride[v_rank - 2] as u32,
+ /* o_head_stride */ o_stride[o_rank - 2] as u32,
/* b */ b_sz as u32,
/* h */ num_heads as u32,
/* h_k */ num_heads_k as u32,