diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-11-10 10:27:27 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-10 10:27:27 +0100 |
commit | d2c3f1477397b6730fbef7225dd9e5fc0a9fa096 (patch) | |
tree | 7e4142da75aeacb4b880a274b04e64763cd4d70a | |
parent | 26c4e5bf1d10532c9b681f07a7b08b2c84844bee (diff) | |
download | candle-d2c3f1477397b6730fbef7225dd9e5fc0a9fa096.tar.gz candle-d2c3f1477397b6730fbef7225dd9e5fc0a9fa096.tar.bz2 candle-d2c3f1477397b6730fbef7225dd9e5fc0a9fa096.zip |
Fix for flash-attn. (#1310)
Co-authored-by: laurent <laurent@par2dc5-ai-prd-cl01dgx02.cm.cluster>
-rw-r--r-- | candle-flash-attn/src/lib.rs | 4 |
1 file changed, 2 insertions, 2 deletions
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs index 61980a58..3395bd0d 100644 --- a/candle-flash-attn/src/lib.rs +++ b/candle-flash-attn/src/lib.rs @@ -233,8 +233,8 @@ impl FlashAttnVarLen { let (seqlens_q, seqlens_q_layout) = self.seqlens_q.storage_and_layout(); let seqlens_q = match &*seqlens_q { - candle::Storage::Cpu(_) => candle::bail!("seqlens_q must be a cuda tensor"), candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32! + _ => candle::bail!("seqlens_q must be a cuda tensor"), }; let seqlens_q = match seqlens_q_layout.contiguous_offsets() { Some((o1, o2)) => seqlens_q.slice(o1..o2), @@ -243,8 +243,8 @@ impl FlashAttnVarLen { let (seqlens_k, seqlens_k_layout) = self.seqlens_k.storage_and_layout(); let seqlens_k = match &*seqlens_k { - candle::Storage::Cpu(_) => candle::bail!("seqlens_k must be a cuda tensor"), candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32! + _ => candle::bail!("seqlens_k must be a cuda tensor"), }; let seqlens_k = match seqlens_k_layout.contiguous_offsets() { Some((o1, o2)) => seqlens_k.slice(o1..o2), |