summaryrefslogtreecommitdiff
path: root/candle-nn
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-04-03 09:02:38 +0200
committerGitHub <noreply@github.com>2024-04-03 09:02:38 +0200
commit318d143224805e490d396874b9e1aaf28991393c (patch)
treeba51a3ef7b1f27734d8b3e5d5a434aab12c2fffd /candle-nn
parent2be1a357102d8f64feb694720e5528d4974ca141 (diff)
downloadcandle-318d143224805e490d396874b9e1aaf28991393c.tar.gz
candle-318d143224805e490d396874b9e1aaf28991393c.tar.bz2
candle-318d143224805e490d396874b9e1aaf28991393c.zip
Relax the contiguous check for cuda kernels. (#2000)
* Relax the contiguous check for cuda kernels. * Ensure contiguity for RNNs. * Unrelated fix for segment anything. * Better error message + allow concatenating empty slices.
Diffstat (limited to 'candle-nn')
-rw-r--r--candle-nn/src/rnn.rs2
1 files changed, 1 insertions, 1 deletions
diff --git a/candle-nn/src/rnn.rs b/candle-nn/src/rnn.rs
index 07795eda..dbfa639b 100644
--- a/candle-nn/src/rnn.rs
+++ b/candle-nn/src/rnn.rs
@@ -31,7 +31,7 @@ pub trait RNN {
let (_b_size, seq_len, _features) = input.dims3()?;
let mut output = Vec::with_capacity(seq_len);
for seq_index in 0..seq_len {
- let input = input.i((.., seq_index, ..))?;
+ let input = input.i((.., seq_index, ..))?.contiguous()?;
let state = if seq_index == 0 {
self.step(&input, init_state)?
} else {