summaryrefslogtreecommitdiff
path: root/candle-wasm-examples
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-10 12:29:37 +0100
committerGitHub <noreply@github.com>2023-09-10 12:29:37 +0100
commit584171cae18923450e029eb04245f70c7e7e5fc4 (patch)
tree84109b5f03ac96839967cdc5d5c21ce6f25fcb3c /candle-wasm-examples
parent6c58fc59fd828492021cfd0f4518ae5ae3b03f56 (diff)
downloadcandle-584171cae18923450e029eb04245f70c7e7e5fc4.tar.gz
candle-584171cae18923450e029eb04245f70c7e7e5fc4.tar.bz2
candle-584171cae18923450e029eb04245f70c7e7e5fc4.zip
Add a wasm module for the segment anything example. (#797)
Diffstat (limited to 'candle-wasm-examples')
-rw-r--r--candle-wasm-examples/segment-anything/Cargo.toml29
-rw-r--r--candle-wasm-examples/segment-anything/build-lib.sh2
-rw-r--r--candle-wasm-examples/segment-anything/src/bin/m.rs113
-rw-r--r--candle-wasm-examples/segment-anything/src/lib.rs19
4 files changed, 163 insertions, 0 deletions
diff --git a/candle-wasm-examples/segment-anything/Cargo.toml b/candle-wasm-examples/segment-anything/Cargo.toml
new file mode 100644
index 00000000..ab82ab1f
--- /dev/null
+++ b/candle-wasm-examples/segment-anything/Cargo.toml
@@ -0,0 +1,29 @@
+[package]
+name = "candle-wasm-example-sam"
+version.workspace = true
+edition.workspace = true
+description.workspace = true
+repository.workspace = true
+keywords.workspace = true
+categories.workspace = true
+license.workspace = true
+
+[dependencies]
+candle = { path = "../../candle-core", version = "0.2.1", package = "candle-core" }
+candle-nn = { path = "../../candle-nn", version = "0.2.1" }
+candle-transformers = { path = "../../candle-transformers", version = "0.2.1" }
+num-traits = { workspace = true }
+
+# App crates.
+anyhow = { workspace = true }
+byteorder = { workspace = true }
+getrandom = { version = "0.2", features = ["js"] }
+image = { workspace = true }
+log = { workspace = true }
+safetensors = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }
+
+# Wasm specific crates.
+console_error_panic_hook = "0.1.7"
+wasm-bindgen = "0.2.87"
diff --git a/candle-wasm-examples/segment-anything/build-lib.sh b/candle-wasm-examples/segment-anything/build-lib.sh
new file mode 100644
index 00000000..b0ebb182
--- /dev/null
+++ b/candle-wasm-examples/segment-anything/build-lib.sh
@@ -0,0 +1,2 @@
+cargo build --target wasm32-unknown-unknown --release
+wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web
diff --git a/candle-wasm-examples/segment-anything/src/bin/m.rs b/candle-wasm-examples/segment-anything/src/bin/m.rs
new file mode 100644
index 00000000..c4c79fe0
--- /dev/null
+++ b/candle-wasm-examples/segment-anything/src/bin/m.rs
@@ -0,0 +1,113 @@
+use candle::{DType, Device, Tensor};
+use candle_nn::VarBuilder;
+use candle_wasm_example_sam as sam;
+use wasm_bindgen::prelude::*;
+
+#[allow(unused)]
+struct Embeddings {
+ original_width: u32,
+ original_height: u32,
+ width: u32,
+ height: u32,
+ data: Tensor,
+}
+
+#[wasm_bindgen]
+pub struct Model {
+ sam: sam::Sam,
+ embeddings: Option<Embeddings>,
+}
+
+#[wasm_bindgen]
+impl Model {
+ #[wasm_bindgen(constructor)]
+ pub fn new(weights: &[u8], use_tiny: bool) -> Result<Model, JsError> {
+ console_error_panic_hook::set_once();
+ let dev = &Device::Cpu;
+ let weights = safetensors::tensor::SafeTensors::deserialize(weights)?;
+ let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, dev);
+ let sam = if use_tiny {
+ sam::Sam::new_tiny(vb)? // tiny vit_t
+ } else {
+ sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b
+ };
+ Ok(Self {
+ sam,
+ embeddings: None,
+ })
+ }
+
+ pub fn set_image_embeddings(&mut self, image_data: Vec<u8>) -> Result<(), JsError> {
+ sam::console_log!("image data: {}", image_data.len());
+ let image_data = std::io::Cursor::new(image_data);
+ let image = image::io::Reader::new(image_data)
+ .with_guessed_format()?
+ .decode()
+ .map_err(candle::Error::wrap)?;
+ let (original_height, original_width) = (image.height(), image.width());
+ let (height, width) = (original_height, original_width);
+ let resize_longest = sam::IMAGE_SIZE as u32;
+ let (height, width) = if height < width {
+ let h = (resize_longest * height) / width;
+ (h, resize_longest)
+ } else {
+ let w = (resize_longest * width) / height;
+ (resize_longest, w)
+ };
+ let image_t = {
+ let img = image.resize_exact(width, height, image::imageops::FilterType::CatmullRom);
+ let data = img.to_rgb8().into_raw();
+ Tensor::from_vec(
+ data,
+ (img.height() as usize, img.width() as usize, 3),
+ &Device::Cpu,
+ )?
+ .permute((2, 0, 1))?
+ };
+ let data = self.sam.embeddings(&image_t)?;
+ self.embeddings = Some(Embeddings {
+ original_width,
+ original_height,
+ width,
+ height,
+ data,
+ });
+ Ok(())
+ }
+
+ // x and y have to be between 0 and 1
+ pub fn mask_for_point(&self, x: f64, y: f64) -> Result<String, JsError> {
+ let embeddings = match &self.embeddings {
+ None => todo!(),
+ Some(embeddings) => embeddings,
+ };
+ let (mask, iou_predictions) = self.sam.forward_for_embeddings(
+ &embeddings.data,
+ embeddings.height as usize,
+ embeddings.width as usize,
+ Some((x, y)),
+ false,
+ )?;
+ let iou = iou_predictions.to_vec1::<f32>()?[0];
+ let mask_shape = mask.dims().to_vec();
+ let mask_data = mask.ge(0f32)?.flatten_all()?.to_vec1::<u8>()?;
+ let mask = Mask {
+ iou,
+ mask_shape,
+ mask_data,
+ };
+ let json = serde_json::to_string(&mask)?;
+ Ok(json)
+ }
+}
+
+#[derive(serde::Serialize, serde::Deserialize)]
+struct Mask {
+ iou: f32,
+ mask_shape: Vec<usize>,
+ mask_data: Vec<u8>,
+}
+
+fn main() {
+ console_error_panic_hook::set_once();
+}
diff --git a/candle-wasm-examples/segment-anything/src/lib.rs b/candle-wasm-examples/segment-anything/src/lib.rs
new file mode 100644
index 00000000..0f4f96fd
--- /dev/null
+++ b/candle-wasm-examples/segment-anything/src/lib.rs
@@ -0,0 +1,19 @@
+use candle_transformers::models::segment_anything::sam;
+use wasm_bindgen::prelude::*;
+
+pub use sam::{Sam, IMAGE_SIZE};
+
+#[wasm_bindgen]
+extern "C" {
+ // Use `js_namespace` here to bind `console.log(..)` instead of just
+ // `log(..)`
+ #[wasm_bindgen(js_namespace = console)]
+ pub fn log(s: &str);
+}
+
+#[macro_export]
+macro_rules! console_log {
+ // Note that this is using the `log` function imported above during
+ // `bare_bones`
+ ($($t:tt)*) => ($crate::log(&format_args!($($t)*).to_string()))
+}