summaryrefslogtreecommitdiff
path: root/candle-examples/examples/segment-anything/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/segment-anything/main.rs')
-rw-r--r--candle-examples/examples/segment-anything/main.rs7
1 files changed, 6 insertions, 1 deletions
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs
index 89d5b56c..c53c1010 100644
--- a/candle-examples/examples/segment-anything/main.rs
+++ b/candle-examples/examples/segment-anything/main.rs
@@ -110,7 +110,7 @@ pub fn main() -> anyhow::Result<()> {
let image = if args.image.ends_with(".safetensors") {
let mut tensors = candle::safetensors::load(&args.image, &device)?;
- match tensors.remove("image") {
+ let image = match tensors.remove("image") {
Some(image) => image,
None => {
if tensors.len() != 1 {
@@ -118,6 +118,11 @@ pub fn main() -> anyhow::Result<()> {
}
tensors.into_values().next().unwrap()
}
+ };
+ if image.rank() == 4 {
+ image.get(0)?
+ } else {
+ image
}
} else {
candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?.to_device(&device)?