diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-10 12:29:37 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-10 12:29:37 +0100 |
commit | 584171cae18923450e029eb04245f70c7e7e5fc4 (patch) | |
tree | 84109b5f03ac96839967cdc5d5c21ce6f25fcb3c /candle-wasm-examples | |
parent | 6c58fc59fd828492021cfd0f4518ae5ae3b03f56 (diff) | |
download | candle-584171cae18923450e029eb04245f70c7e7e5fc4.tar.gz candle-584171cae18923450e029eb04245f70c7e7e5fc4.tar.bz2 candle-584171cae18923450e029eb04245f70c7e7e5fc4.zip |
Add a wasm module for the segment anything example. (#797)
Diffstat (limited to 'candle-wasm-examples')
-rw-r--r-- | candle-wasm-examples/segment-anything/Cargo.toml | 29 | ||||
-rw-r--r-- | candle-wasm-examples/segment-anything/build-lib.sh | 2 | ||||
-rw-r--r-- | candle-wasm-examples/segment-anything/src/bin/m.rs | 113 | ||||
-rw-r--r-- | candle-wasm-examples/segment-anything/src/lib.rs | 19 |
4 files changed, 163 insertions, 0 deletions
diff --git a/candle-wasm-examples/segment-anything/Cargo.toml b/candle-wasm-examples/segment-anything/Cargo.toml new file mode 100644 index 00000000..ab82ab1f --- /dev/null +++ b/candle-wasm-examples/segment-anything/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "candle-wasm-example-sam" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true + +[dependencies] +candle = { path = "../../candle-core", version = "0.2.1", package = "candle-core" } +candle-nn = { path = "../../candle-nn", version = "0.2.1" } +candle-transformers = { path = "../../candle-transformers", version = "0.2.1" } +num-traits = { workspace = true } + +# App crates. +anyhow = { workspace = true } +byteorder = { workspace = true } +getrandom = { version = "0.2", features = ["js"] } +image = { workspace = true } +log = { workspace = true } +safetensors = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } + +# Wasm specific crates. +console_error_panic_hook = "0.1.7" +wasm-bindgen = "0.2.87" diff --git a/candle-wasm-examples/segment-anything/build-lib.sh b/candle-wasm-examples/segment-anything/build-lib.sh new file mode 100644 index 00000000..b0ebb182 --- /dev/null +++ b/candle-wasm-examples/segment-anything/build-lib.sh @@ -0,0 +1,2 @@ +cargo build --target wasm32-unknown-unknown --release +wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web diff --git a/candle-wasm-examples/segment-anything/src/bin/m.rs b/candle-wasm-examples/segment-anything/src/bin/m.rs new file mode 100644 index 00000000..c4c79fe0 --- /dev/null +++ b/candle-wasm-examples/segment-anything/src/bin/m.rs @@ -0,0 +1,113 @@ +use candle::{DType, Device, Tensor}; +use candle_nn::VarBuilder; +use candle_wasm_example_sam as sam; +use wasm_bindgen::prelude::*; + +#[allow(unused)] +struct Embeddings { + original_width: u32, + original_height: u32, + width: u32, + height: u32, + data: Tensor, +} + +#[wasm_bindgen] +pub struct Model { + sam: sam::Sam, + embeddings: Option<Embeddings>, +} + +#[wasm_bindgen] +impl Model { + #[wasm_bindgen(constructor)] + pub fn new(weights: &[u8], use_tiny: bool) -> Result<Model, JsError> { + console_error_panic_hook::set_once(); + let dev = &Device::Cpu; + let weights = safetensors::tensor::SafeTensors::deserialize(weights)?; + let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, dev); + let sam = if use_tiny { + sam::Sam::new_tiny(vb)? // tiny vit_t + } else { + sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b + }; + Ok(Self { + sam, + embeddings: None, + }) + } + + pub fn set_image_embeddings(&mut self, image_data: Vec<u8>) -> Result<(), JsError> { + sam::console_log!("image data: {}", image_data.len()); + let image_data = std::io::Cursor::new(image_data); + let image = image::io::Reader::new(image_data) + .with_guessed_format()? + .decode() + .map_err(candle::Error::wrap)?; + let (original_height, original_width) = (image.height(), image.width()); + let (height, width) = (original_height, original_width); + let resize_longest = sam::IMAGE_SIZE as u32; + let (height, width) = if height < width { + let h = (resize_longest * height) / width; + (h, resize_longest) + } else { + let w = (resize_longest * width) / height; + (resize_longest, w) + }; + let image_t = { + let img = image.resize_exact(width, height, image::imageops::FilterType::CatmullRom); + let data = img.to_rgb8().into_raw(); + Tensor::from_vec( + data, + (img.height() as usize, img.width() as usize, 3), + &Device::Cpu, + )? + .permute((2, 0, 1))? + }; + let data = self.sam.embeddings(&image_t)?; + self.embeddings = Some(Embeddings { + original_width, + original_height, + width, + height, + data, + }); + Ok(()) + } + + // x and y have to be between 0 and 1 + pub fn mask_for_point(&self, x: f64, y: f64) -> Result<String, JsError> { + let embeddings = match &self.embeddings { + None => todo!(), + Some(embeddings) => embeddings, + }; + let (mask, iou_predictions) = self.sam.forward_for_embeddings( + &embeddings.data, + embeddings.height as usize, + embeddings.width as usize, + Some((x, y)), + false, + )?; + let iou = iou_predictions.to_vec1::<f32>()?[0]; + let mask_shape = mask.dims().to_vec(); + let mask_data = mask.ge(0f32)?.flatten_all()?.to_vec1::<u8>()?; + let mask = Mask { + iou, + mask_shape, + mask_data, + }; + let json = serde_json::to_string(&mask)?; + Ok(json) + } +} + +#[derive(serde::Serialize, serde::Deserialize)] +struct Mask { + iou: f32, + mask_shape: Vec<usize>, + mask_data: Vec<u8>, +} + +fn main() { + console_error_panic_hook::set_once(); +} diff --git a/candle-wasm-examples/segment-anything/src/lib.rs b/candle-wasm-examples/segment-anything/src/lib.rs new file mode 100644 index 00000000..0f4f96fd --- /dev/null +++ b/candle-wasm-examples/segment-anything/src/lib.rs @@ -0,0 +1,19 @@ +use candle_transformers::models::segment_anything::sam; +use wasm_bindgen::prelude::*; + +pub use sam::{Sam, IMAGE_SIZE}; + +#[wasm_bindgen] +extern "C" { + // Use `js_namespace` here to bind `console.log(..)` instead of just + // `log(..)` + #[wasm_bindgen(js_namespace = console)] + pub fn log(s: &str); +} + +#[macro_export] +macro_rules! console_log { + // Note that this is using the `log` function imported above during + // `bare_bones` + ($($t:tt)*) => ($crate::log(&format_args!($($t)*).to_string())) +} |