summaryrefslogtreecommitdiff
path: root/candle-kernels/src
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-04-03 09:02:38 +0200
committerGitHub <noreply@github.com>2024-04-03 09:02:38 +0200
commit318d143224805e490d396874b9e1aaf28991393c (patch)
treeba51a3ef7b1f27734d8b3e5d5a434aab12c2fffd /candle-kernels/src
parent2be1a357102d8f64feb694720e5528d4974ca141 (diff)
downloadcandle-318d143224805e490d396874b9e1aaf28991393c.tar.gz
candle-318d143224805e490d396874b9e1aaf28991393c.tar.bz2
candle-318d143224805e490d396874b9e1aaf28991393c.zip
Relax the contiguous check for cuda kernels. (#2000)
* Relax the contiguous check for cuda kernels. * Ensure contiguity for RNNs. * Unrelated fix for segment anything. * Better error message + allow concatenating empty slices.
Diffstat (limited to 'candle-kernels/src')
-rw-r--r--candle-kernels/src/cuda_utils.cuh2
1 files changed, 1 insertions, 1 deletions
diff --git a/candle-kernels/src/cuda_utils.cuh b/candle-kernels/src/cuda_utils.cuh
index b0a85249..2673b8aa 100644
--- a/candle-kernels/src/cuda_utils.cuh
+++ b/candle-kernels/src/cuda_utils.cuh
@@ -14,7 +14,7 @@ __device__ bool is_contiguous(
size_t acc = 1;
for (unsigned int d = 0; d < num_dims; d++) {
unsigned int dim_idx = num_dims - 1 - d;
- if (acc != strides[dim_idx]) {
+ if (dims[dim_idx] > 1 && acc != strides[dim_idx]) {
return false;
}
acc *= dims[dim_idx];