summaryrefslogtreecommitdiff
path: root/candle-examples/examples/segment-anything/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/segment-anything/main.rs')
-rw-r--r--candle-examples/examples/segment-anything/main.rs19
1 files changed, 13 insertions, 6 deletions
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs
index de16f70c..368b5a33 100644
--- a/candle-examples/examples/segment-anything/main.rs
+++ b/candle-examples/examples/segment-anything/main.rs
@@ -8,9 +8,11 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
-mod model_image_encoder;
-mod model_mask_decoder;
-mod model_transformer;
+pub mod model_image_encoder;
+pub mod model_mask_decoder;
+pub mod model_prompt_encoder;
+pub mod model_sam;
+pub mod model_transformer;
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
@@ -82,7 +84,7 @@ impl Module for MlpBlock {
#[derive(Parser)]
struct Args {
#[arg(long)]
- model: Option<String>,
+ model: String,
#[arg(long)]
image: String,
@@ -95,10 +97,15 @@ struct Args {
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
- let _device = candle_examples::device(args.cpu)?;
+ let device = candle_examples::device(args.cpu)?;
- let image = candle_examples::imagenet::load_image224(args.image)?;
+ let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device);
println!("loaded image {image:?}");
+ let weights = unsafe { candle::safetensors::MmapedFile::new(args.model)? };
+ let weights = weights.deserialize()?;
+ let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
+ let _sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b
+
Ok(())
}