diff options
-rw-r--r-- | Cargo.toml | 1 | ||||
-rw-r--r-- | candle-transformers/src/models/segment_anything/sam.rs | 28 | ||||
-rw-r--r-- | candle-wasm-examples/segment-anything/Cargo.toml | 29 | ||||
-rw-r--r-- | candle-wasm-examples/segment-anything/build-lib.sh | 2 | ||||
-rw-r--r-- | candle-wasm-examples/segment-anything/src/bin/m.rs | 113 | ||||
-rw-r--r-- | candle-wasm-examples/segment-anything/src/lib.rs | 19 |
6 files changed, 189 insertions, 3 deletions
@@ -8,6 +8,7 @@ members = [
     "candle-pyo3",
     "candle-transformers",
     "candle-wasm-examples/llama2-c",
+    "candle-wasm-examples/segment-anything",
     "candle-wasm-examples/whisper",
     "candle-wasm-examples/yolo",
 ]
diff --git a/candle-transformers/src/models/segment_anything/sam.rs b/candle-transformers/src/models/segment_anything/sam.rs
index c40473e3..92756591 100644
--- a/candle-transformers/src/models/segment_anything/sam.rs
+++ b/candle-transformers/src/models/segment_anything/sam.rs
@@ -122,6 +122,11 @@ impl Sam {
         })
     }
 
+    pub fn embeddings(&self, img: &Tensor) -> Result<Tensor> {
+        let img = self.preprocess(img)?.unsqueeze(0)?;
+        self.image_encoder.forward(&img)
+    }
+
     pub fn forward(
         &self,
         img: &Tensor,
@@ -131,15 +136,32 @@ impl Sam {
         let (_c, original_h, original_w) = img.dims3()?;
         let img = self.preprocess(img)?.unsqueeze(0)?;
         let img_embeddings = self.image_encoder.forward(&img)?;
+        self.forward_for_embeddings(
+            &img_embeddings,
+            original_h,
+            original_w,
+            point,
+            multimask_output,
+        )
+    }
+
+    pub fn forward_for_embeddings(
+        &self,
+        img_embeddings: &Tensor,
+        original_h: usize,
+        original_w: usize,
+        point: Option<(f64, f64)>,
+        multimask_output: bool,
+    ) -> Result<(Tensor, Tensor)> {
         let image_pe = self.prompt_encoder.get_dense_pe()?;
         let points = match point {
             None => None,
             Some((x, y)) => {
                 let points = Tensor::new(
                     &[[[x as f32 * original_w as f32, y as f32 * original_h as f32]]],
-                    img.device(),
+                    img_embeddings.device(),
                 )?;
-                let labels = Tensor::ones((1, 1), DType::F32, img.device())?;
+                let labels = Tensor::ones((1, 1), DType::F32, img_embeddings.device())?;
                 Some((points, labels))
             }
         };
@@ -147,7 +169,7 @@ impl Sam {
         let (sparse_prompt_embeddings, dense_prompt_embeddings) =
             self.prompt_encoder.forward(points, None, None)?;
         let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
-            &img_embeddings,
+            img_embeddings,
             &image_pe,
             &sparse_prompt_embeddings,
             &dense_prompt_embeddings,
diff --git a/candle-wasm-examples/segment-anything/Cargo.toml b/candle-wasm-examples/segment-anything/Cargo.toml
new file mode 100644
index 00000000..ab82ab1f
--- /dev/null
+++ b/candle-wasm-examples/segment-anything/Cargo.toml
@@ -0,0 +1,29 @@
+[package]
+name = "candle-wasm-example-sam"
+version.workspace = true
+edition.workspace = true
+description.workspace = true
+repository.workspace = true
+keywords.workspace = true
+categories.workspace = true
+license.workspace = true
+
+[dependencies]
+candle = { path = "../../candle-core", version = "0.2.1", package = "candle-core" }
+candle-nn = { path = "../../candle-nn", version = "0.2.1" }
+candle-transformers = { path = "../../candle-transformers", version = "0.2.1" }
+num-traits = { workspace = true }
+
+# Application crates.
+anyhow = { workspace = true }
+byteorder = { workspace = true }
+getrandom = { version = "0.2", features = ["js"] }
+image = { workspace = true }
+log = { workspace = true }
+safetensors = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }
+
+# Wasm-specific crates.
+console_error_panic_hook = "0.1.7"
+wasm-bindgen = "0.2.87"
diff --git a/candle-wasm-examples/segment-anything/build-lib.sh b/candle-wasm-examples/segment-anything/build-lib.sh
new file mode 100644
index 00000000..b0ebb182
--- /dev/null
+++ b/candle-wasm-examples/segment-anything/build-lib.sh
@@ -0,0 +1,2 @@
+cargo build --target wasm32-unknown-unknown --release
+wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web
diff --git a/candle-wasm-examples/segment-anything/src/bin/m.rs b/candle-wasm-examples/segment-anything/src/bin/m.rs
new file mode 100644
index 00000000..c4c79fe0
--- /dev/null
+++ b/candle-wasm-examples/segment-anything/src/bin/m.rs
@@ -0,0 +1,113 @@
+use candle::{DType, Device, Tensor};
+use candle_nn::VarBuilder;
+use candle_wasm_example_sam as sam;
+use wasm_bindgen::prelude::*;
+
+#[allow(unused)] // `original_width`/`original_height` are stored but not read yet.
+struct Embeddings {
+    original_width: u32,
+    original_height: u32,
+    width: u32,
+    height: u32,
+    data: Tensor,
+}
+
+#[wasm_bindgen]
+pub struct Model {
+    sam: sam::Sam,
+    embeddings: Option<Embeddings>,
+}
+
+#[wasm_bindgen]
+impl Model {
+    #[wasm_bindgen(constructor)]
+    pub fn new(weights: &[u8], use_tiny: bool) -> Result<Model, JsError> {
+        console_error_panic_hook::set_once();
+        let dev = &Device::Cpu;
+        let weights = safetensors::tensor::SafeTensors::deserialize(weights)?;
+        let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, dev);
+        let sam = if use_tiny {
+            sam::Sam::new_tiny(vb)? // tiny vit_t
+        } else {
+            sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b
+        };
+        Ok(Self {
+            sam,
+            embeddings: None,
+        })
+    }
+
+    pub fn set_image_embeddings(&mut self, image_data: Vec<u8>) -> Result<(), JsError> {
+        sam::console_log!("image data: {}", image_data.len());
+        let image_data = std::io::Cursor::new(image_data);
+        let image = image::io::Reader::new(image_data)
+            .with_guessed_format()?
+            .decode()
+            .map_err(candle::Error::wrap)?;
+        let (original_height, original_width) = (image.height(), image.width());
+        let (height, width) = (original_height, original_width);
+        let resize_longest = sam::IMAGE_SIZE as u32;
+        let (height, width) = if height < width {
+            let h = (resize_longest * height) / width;
+            (h, resize_longest)
+        } else {
+            let w = (resize_longest * width) / height;
+            (resize_longest, w)
+        };
+        let image_t = {
+            let img = image.resize_exact(width, height, image::imageops::FilterType::CatmullRom);
+            let data = img.to_rgb8().into_raw();
+            Tensor::from_vec(
+                data,
+                (img.height() as usize, img.width() as usize, 3),
+                &Device::Cpu,
+            )?
+            .permute((2, 0, 1))?
+        };
+        let data = self.sam.embeddings(&image_t)?;
+        self.embeddings = Some(Embeddings {
+            original_width,
+            original_height,
+            width,
+            height,
+            data,
+        });
+        Ok(())
+    }
+
+    /// `x` and `y` must be between 0 and 1; returns an error if no image embeddings are set.
+    pub fn mask_for_point(&self, x: f64, y: f64) -> Result<String, JsError> {
+        let embeddings = match &self.embeddings {
+            None => return Err(JsError::new("image embeddings have not been set")),
+            Some(embeddings) => embeddings,
+        };
+        let (mask, iou_predictions) = self.sam.forward_for_embeddings(
+            &embeddings.data,
+            embeddings.height as usize,
+            embeddings.width as usize,
+            Some((x, y)),
+            false,
+        )?;
+        let iou = iou_predictions.flatten_all()?.to_vec1::<f32>()?[0];
+        let mask_shape = mask.dims().to_vec();
+        let mask_data = mask.ge(0f32)?.flatten_all()?.to_vec1::<u8>()?;
+        let mask = Mask {
+            iou,
+            mask_shape,
+            mask_data,
+        };
+        let json = serde_json::to_string(&mask)?;
+        Ok(json)
+    }
+}
+
+#[derive(serde::Serialize, serde::Deserialize)]
+struct Mask {
+    iou: f32,
+    mask_shape: Vec<usize>,
+    mask_data: Vec<u8>,
+}
+
+fn main() {
+    console_error_panic_hook::set_once();
+}
diff --git a/candle-wasm-examples/segment-anything/src/lib.rs b/candle-wasm-examples/segment-anything/src/lib.rs
new file mode 100644
index 00000000..0f4f96fd
--- /dev/null
+++ b/candle-wasm-examples/segment-anything/src/lib.rs
@@ -0,0 +1,19 @@
+use candle_transformers::models::segment_anything::sam;
+use wasm_bindgen::prelude::*;
+
+pub use sam::{Sam, IMAGE_SIZE};
+
+#[wasm_bindgen]
+extern "C" {
+    // Use `js_namespace` here to bind `console.log(..)` instead of just
+    // `log(..)`
+    #[wasm_bindgen(js_namespace = console)]
+    pub fn log(s: &str);
+}
+
+#[macro_export]
+macro_rules! console_log {
+    // Formats the arguments with `format_args!` and passes the resulting
+    // string to the `log` binding declared above.
+    ($($t:tt)*) => ($crate::log(&format_args!($($t)*).to_string()))
+}