summaryrefslogtreecommitdiff
path: root/candle-examples/examples/segment-anything/model_mask_decoder.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-08 19:11:34 +0100
committerGitHub <noreply@github.com>2023-09-08 19:11:34 +0100
commit0906acab9186fbb14a2268e12dd66c13b0877f3e (patch)
tree6263cd54d7ca3e91d72e62b81ab27908b27b760d /candle-examples/examples/segment-anything/model_mask_decoder.rs
parent158ff3c609b22ed998dea5283738cc1ed13aa592 (diff)
downloadcandle-0906acab9186fbb14a2268e12dd66c13b0877f3e.tar.gz
candle-0906acab9186fbb14a2268e12dd66c13b0877f3e.tar.bz2
candle-0906acab9186fbb14a2268e12dd66c13b0877f3e.zip
Automatic mask generation (#779)
* A few more contiguous fixes for cuda. * Mask generation. * Generic bbox. * Generate all the masks.
Diffstat (limited to 'candle-examples/examples/segment-anything/model_mask_decoder.rs')
-rw-r--r--candle-examples/examples/segment-anything/model_mask_decoder.rs2
1 files changed, 1 insertions, 1 deletions
diff --git a/candle-examples/examples/segment-anything/model_mask_decoder.rs b/candle-examples/examples/segment-anything/model_mask_decoder.rs
index 1f6d62a4..c02b44a7 100644
--- a/candle-examples/examples/segment-anything/model_mask_decoder.rs
+++ b/candle-examples/examples/segment-anything/model_mask_decoder.rs
@@ -219,7 +219,7 @@ impl MaskDecoder {
let h = mlp.forward(&mask_tokens_out.i((.., i))?)?;
hyper_in_list.push(h)
}
- let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?;
+ let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?.contiguous()?;
let (b, c, h, w) = upscaled_embedding.dims4()?;
let masks = hyper_in.matmul(&upscaled_embedding.reshape((b, c, h * w))?)?;
let masks = masks.reshape((b, (), h, w))?;