diff options
Diffstat (limited to 'candle-examples/examples/segment-anything/main.rs')
-rw-r--r-- | candle-examples/examples/segment-anything/main.rs | 19 |
1 files changed, 13 insertions, 6 deletions
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs index de16f70c..368b5a33 100644 --- a/candle-examples/examples/segment-anything/main.rs +++ b/candle-examples/examples/segment-anything/main.rs @@ -8,9 +8,11 @@ extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; -mod model_image_encoder; -mod model_mask_decoder; -mod model_transformer; +pub mod model_image_encoder; +pub mod model_mask_decoder; +pub mod model_prompt_encoder; +pub mod model_sam; +pub mod model_transformer; use candle::{DType, IndexOp, Result, Tensor, D}; use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; @@ -82,7 +84,7 @@ impl Module for MlpBlock { #[derive(Parser)] struct Args { #[arg(long)] - model: Option<String>, + model: String, #[arg(long)] image: String, @@ -95,10 +97,15 @@ struct Args { pub fn main() -> anyhow::Result<()> { let args = Args::parse(); - let _device = candle_examples::device(args.cpu)?; + let device = candle_examples::device(args.cpu)?; - let image = candle_examples::imagenet::load_image224(args.image)?; + let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device); println!("loaded image {image:?}"); + let weights = unsafe { candle::safetensors::MmapedFile::new(args.model)? }; + let weights = weights.deserialize()?; + let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device); + let _sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b + Ok(()) } |