summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-04-03 09:02:38 +0200
committerGitHub <noreply@github.com>2024-04-03 09:02:38 +0200
commit318d143224805e490d396874b9e1aaf28991393c (patch)
treeba51a3ef7b1f27734d8b3e5d5a434aab12c2fffd
parent2be1a357102d8f64feb694720e5528d4974ca141 (diff)
downloadcandle-318d143224805e490d396874b9e1aaf28991393c.tar.gz
candle-318d143224805e490d396874b9e1aaf28991393c.tar.bz2
candle-318d143224805e490d396874b9e1aaf28991393c.zip
Relax the contiguous check for cuda kernels. (#2000)
* Relax the contiguous check for cuda kernels. * Ensure contiguity for RNNs. * Unrelated fix for segment anything. * Better error message + allow concatenating empty slices.
-rw-r--r--candle-core/src/cuda_backend/mod.rs7
-rw-r--r--candle-kernels/src/cuda_utils.cuh2
-rw-r--r--candle-nn/src/rnn.rs2
-rw-r--r--candle-transformers/src/models/segment_anything/prompt_encoder.rs3
4 files changed, 10 insertions, 4 deletions
diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs
index 3690e0dc..6a9e73f8 100644
--- a/candle-core/src/cuda_backend/mod.rs
+++ b/candle-core/src/cuda_backend/mod.rs
@@ -99,7 +99,7 @@ pub trait WrapErr<O> {
impl<O, E: Into<CudaError>> WrapErr<O> for std::result::Result<O, E> {
fn w(self) -> std::result::Result<O, crate::Error> {
- self.map_err(|e| crate::Error::Cuda(Box::new(e.into())))
+ self.map_err(|e| crate::Error::Cuda(Box::new(e.into())).bt())
}
}
@@ -1761,6 +1761,11 @@ impl BackendStorage for CudaStorage {
let dev = &self.device;
let d1 = d1 as u32;
let d2 = d2 as u32;
+ // Nothing to copy so we exit early to avoid launching a kernel and some potential invalid
+ // argument with a null pointer.
+ if d1 == 0 || d2 == 0 {
+ return Ok(());
+ }
let dst_s = dst_s as u32;
let src_s = src_s as u32;
let (src, dst, kname) = match (&self.slice, &mut dst.slice) {
diff --git a/candle-kernels/src/cuda_utils.cuh b/candle-kernels/src/cuda_utils.cuh
index b0a85249..2673b8aa 100644
--- a/candle-kernels/src/cuda_utils.cuh
+++ b/candle-kernels/src/cuda_utils.cuh
@@ -14,7 +14,7 @@ __device__ bool is_contiguous(
size_t acc = 1;
for (unsigned int d = 0; d < num_dims; d++) {
unsigned int dim_idx = num_dims - 1 - d;
- if (acc != strides[dim_idx]) {
+ if (dims[dim_idx] > 1 && acc != strides[dim_idx]) {
return false;
}
acc *= dims[dim_idx];
diff --git a/candle-nn/src/rnn.rs b/candle-nn/src/rnn.rs
index 07795eda..dbfa639b 100644
--- a/candle-nn/src/rnn.rs
+++ b/candle-nn/src/rnn.rs
@@ -31,7 +31,7 @@ pub trait RNN {
let (_b_size, seq_len, _features) = input.dims3()?;
let mut output = Vec::with_capacity(seq_len);
for seq_index in 0..seq_len {
- let input = input.i((.., seq_index, ..))?;
+ let input = input.i((.., seq_index, ..))?.contiguous()?;
let state = if seq_index == 0 {
self.step(&input, init_state)?
} else {
diff --git a/candle-transformers/src/models/segment_anything/prompt_encoder.rs b/candle-transformers/src/models/segment_anything/prompt_encoder.rs
index 16e8a4e8..258fb5aa 100644
--- a/candle-transformers/src/models/segment_anything/prompt_encoder.rs
+++ b/candle-transformers/src/models/segment_anything/prompt_encoder.rs
@@ -218,7 +218,8 @@ impl PromptEncoder {
(Some(se_points), None) => se_points,
(None, Some(se_boxes)) => se_boxes,
(None, None) => {
- Tensor::zeros((1, 0, self.embed_dim), DType::F32, &candle::Device::Cpu)?
+ let dev = self.no_mask_embed.embeddings().device();
+ Tensor::zeros((1, 0, self.embed_dim), DType::F32, dev)?
}
};