diff options
Diffstat (limited to 'candle-flash-attn/src/lib.rs')
-rw-r--r-- | candle-flash-attn/src/lib.rs | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs index 0bbb451d..b159aee2 100644 --- a/candle-flash-attn/src/lib.rs +++ b/candle-flash-attn/src/lib.rs @@ -118,14 +118,14 @@ impl candle::CustomOp3 for FlashHdim32Sm80 { /* k_batch_stride */ k_stride[0] as u32, /* v_batch_stride */ v_stride[0] as u32, /* o_batch_stride */ o_stride[0] as u32, - /* q_row_stride */ q_stride[q_rank - 3] as u32, - /* k_row_stride */ k_stride[k_rank - 3] as u32, - /* v_row_stride */ v_stride[v_rank - 3] as u32, - /* o_row_stride */ o_stride[o_rank - 3] as u32, - /* q_head_stride */ q_stride[q_rank - 2] as u32, - /* k_head_stride */ k_stride[k_rank - 2] as u32, - /* v_head_stride */ v_stride[v_rank - 2] as u32, - /* o_head_stride */ o_stride[o_rank - 2] as u32, + /* q_row_stride */ q_stride[q_rank - 3] as u32, + /* k_row_stride */ k_stride[k_rank - 3] as u32, + /* v_row_stride */ v_stride[v_rank - 3] as u32, + /* o_row_stride */ o_stride[o_rank - 3] as u32, + /* q_head_stride */ q_stride[q_rank - 2] as u32, + /* k_head_stride */ k_stride[k_rank - 2] as u32, + /* v_head_stride */ v_stride[v_rank - 2] as u32, + /* o_head_stride */ o_stride[o_rank - 2] as u32, /* b */ b_sz as u32, /* h */ num_heads as u32, /* h_k */ num_heads_k as u32, |