diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-08 09:39:10 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-08 09:39:10 +0100 |
commit | c1453f00b11c9dd12c5aa81fb4355ce47d22d477 (patch) | |
tree | 70fac003f636d26db3b1df5ef56980500700c0f1 /candle-examples/examples/segment-anything/main.rs | |
parent | 989a4807b151f08c651b5027cc1b547a59adf966 (diff) | |
download | candle-c1453f00b11c9dd12c5aa81fb4355ce47d22d477.tar.gz candle-c1453f00b11c9dd12c5aa81fb4355ce47d22d477.tar.bz2 candle-c1453f00b11c9dd12c5aa81fb4355ce47d22d477.zip |
Improve the safetensor loading in the segment-anything example. (#772)
* Improve the safetensor loading in the segment-anything example.
* Properly handle the labels when embedding the point prompts.
Diffstat (limited to 'candle-examples/examples/segment-anything/main.rs')
-rw-r--r-- | candle-examples/examples/segment-anything/main.rs | 7 |
1 files changed, 6 insertions, 1 deletions
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs index 89d5b56c..c53c1010 100644 --- a/candle-examples/examples/segment-anything/main.rs +++ b/candle-examples/examples/segment-anything/main.rs @@ -110,7 +110,7 @@ pub fn main() -> anyhow::Result<()> { let image = if args.image.ends_with(".safetensors") { let mut tensors = candle::safetensors::load(&args.image, &device)?; - match tensors.remove("image") { + let image = match tensors.remove("image") { Some(image) => image, None => { if tensors.len() != 1 { @@ -118,6 +118,11 @@ pub fn main() -> anyhow::Result<()> { } tensors.into_values().next().unwrap() } + }; + if image.rank() == 4 { + image.get(0)? + } else { + image } } else { candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?.to_device(&device)? |