summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-11-10 10:27:27 +0100
committerGitHub <noreply@github.com>2023-11-10 10:27:27 +0100
commitd2c3f1477397b6730fbef7225dd9e5fc0a9fa096 (patch)
tree7e4142da75aeacb4b880a274b04e64763cd4d70a
parent26c4e5bf1d10532c9b681f07a7b08b2c84844bee (diff)
downloadcandle-d2c3f1477397b6730fbef7225dd9e5fc0a9fa096.tar.gz
candle-d2c3f1477397b6730fbef7225dd9e5fc0a9fa096.tar.bz2
candle-d2c3f1477397b6730fbef7225dd9e5fc0a9fa096.zip
Fix for flash-attn. (#1310)
Co-authored-by: laurent <laurent@par2dc5-ai-prd-cl01dgx02.cm.cluster>
-rw-r--r--candle-flash-attn/src/lib.rs4
1 files changed, 2 insertions, 2 deletions
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index 61980a58..3395bd0d 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -233,8 +233,8 @@ impl FlashAttnVarLen {
let (seqlens_q, seqlens_q_layout) = self.seqlens_q.storage_and_layout();
let seqlens_q = match &*seqlens_q {
- candle::Storage::Cpu(_) => candle::bail!("seqlens_q must be a cuda tensor"),
candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32!
+ _ => candle::bail!("seqlens_q must be a cuda tensor"),
};
let seqlens_q = match seqlens_q_layout.contiguous_offsets() {
Some((o1, o2)) => seqlens_q.slice(o1..o2),
@@ -243,8 +243,8 @@ impl FlashAttnVarLen {
let (seqlens_k, seqlens_k_layout) = self.seqlens_k.storage_and_layout();
let seqlens_k = match &*seqlens_k {
- candle::Storage::Cpu(_) => candle::bail!("seqlens_k must be a cuda tensor"),
candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32!
+ _ => candle::bail!("seqlens_k must be a cuda tensor"),
};
let seqlens_k = match seqlens_k_layout.contiguous_offsets() {
Some((o1, o2)) => seqlens_k.slice(o1..o2),