summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Cargo.toml1
-rw-r--r--candle-transformers/src/models/segment_anything/sam.rs28
-rw-r--r--candle-wasm-examples/segment-anything/Cargo.toml29
-rw-r--r--candle-wasm-examples/segment-anything/build-lib.sh2
-rw-r--r--candle-wasm-examples/segment-anything/src/bin/m.rs113
-rw-r--r--candle-wasm-examples/segment-anything/src/lib.rs19
6 files changed, 189 insertions, 3 deletions
diff --git a/Cargo.toml b/Cargo.toml
index ce41876a..b45a2ab6 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -8,6 +8,7 @@ members = [
"candle-pyo3",
"candle-transformers",
"candle-wasm-examples/llama2-c",
+ "candle-wasm-examples/segment-anything",
"candle-wasm-examples/whisper",
"candle-wasm-examples/yolo",
]
diff --git a/candle-transformers/src/models/segment_anything/sam.rs b/candle-transformers/src/models/segment_anything/sam.rs
index c40473e3..92756591 100644
--- a/candle-transformers/src/models/segment_anything/sam.rs
+++ b/candle-transformers/src/models/segment_anything/sam.rs
@@ -122,6 +122,11 @@ impl Sam {
})
}
+ pub fn embeddings(&self, img: &Tensor) -> Result<Tensor> {
+ let img = self.preprocess(img)?.unsqueeze(0)?;
+ self.image_encoder.forward(&img)
+ }
+
pub fn forward(
&self,
img: &Tensor,
@@ -131,15 +136,32 @@ impl Sam {
let (_c, original_h, original_w) = img.dims3()?;
let img = self.preprocess(img)?.unsqueeze(0)?;
let img_embeddings = self.image_encoder.forward(&img)?;
+ self.forward_for_embeddings(
+ &img_embeddings,
+ original_h,
+ original_w,
+ point,
+ multimask_output,
+ )
+ }
+
+ pub fn forward_for_embeddings(
+ &self,
+ img_embeddings: &Tensor,
+ original_h: usize,
+ original_w: usize,
+ point: Option<(f64, f64)>,
+ multimask_output: bool,
+ ) -> Result<(Tensor, Tensor)> {
let image_pe = self.prompt_encoder.get_dense_pe()?;
let points = match point {
None => None,
Some((x, y)) => {
let points = Tensor::new(
&[[[x as f32 * original_w as f32, y as f32 * original_h as f32]]],
- img.device(),
+ img_embeddings.device(),
)?;
- let labels = Tensor::ones((1, 1), DType::F32, img.device())?;
+ let labels = Tensor::ones((1, 1), DType::F32, img_embeddings.device())?;
Some((points, labels))
}
};
@@ -147,7 +169,7 @@ impl Sam {
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
self.prompt_encoder.forward(points, None, None)?;
let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
- &img_embeddings,
+ img_embeddings,
&image_pe,
&sparse_prompt_embeddings,
&dense_prompt_embeddings,
diff --git a/candle-wasm-examples/segment-anything/Cargo.toml b/candle-wasm-examples/segment-anything/Cargo.toml
new file mode 100644
index 00000000..ab82ab1f
--- /dev/null
+++ b/candle-wasm-examples/segment-anything/Cargo.toml
@@ -0,0 +1,29 @@
+[package]
+name = "candle-wasm-example-sam"
+version.workspace = true
+edition.workspace = true
+description.workspace = true
+repository.workspace = true
+keywords.workspace = true
+categories.workspace = true
+license.workspace = true
+
+[dependencies]
+candle = { path = "../../candle-core", version = "0.2.1", package = "candle-core" }
+candle-nn = { path = "../../candle-nn", version = "0.2.1" }
+candle-transformers = { path = "../../candle-transformers", version = "0.2.1" }
+num-traits = { workspace = true }
+
+# App crates.
+anyhow = { workspace = true }
+byteorder = { workspace = true }
+getrandom = { version = "0.2", features = ["js"] }
+image = { workspace = true }
+log = { workspace = true }
+safetensors = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }
+
+# Wasm specific crates.
+console_error_panic_hook = "0.1.7"
+wasm-bindgen = "0.2.87"
diff --git a/candle-wasm-examples/segment-anything/build-lib.sh b/candle-wasm-examples/segment-anything/build-lib.sh
new file mode 100644
index 00000000..b0ebb182
--- /dev/null
+++ b/candle-wasm-examples/segment-anything/build-lib.sh
@@ -0,0 +1,2 @@
+cargo build --target wasm32-unknown-unknown --release
+wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web
diff --git a/candle-wasm-examples/segment-anything/src/bin/m.rs b/candle-wasm-examples/segment-anything/src/bin/m.rs
new file mode 100644
index 00000000..c4c79fe0
--- /dev/null
+++ b/candle-wasm-examples/segment-anything/src/bin/m.rs
@@ -0,0 +1,113 @@
+use candle::{DType, Device, Tensor};
+use candle_nn::VarBuilder;
+use candle_wasm_example_sam as sam;
+use wasm_bindgen::prelude::*;
+
+#[allow(unused)]
+struct Embeddings {
+ original_width: u32,
+ original_height: u32,
+ width: u32,
+ height: u32,
+ data: Tensor,
+}
+
+#[wasm_bindgen]
+pub struct Model {
+ sam: sam::Sam,
+ embeddings: Option<Embeddings>,
+}
+
+#[wasm_bindgen]
+impl Model {
+ #[wasm_bindgen(constructor)]
+ pub fn new(weights: &[u8], use_tiny: bool) -> Result<Model, JsError> {
+ console_error_panic_hook::set_once();
+ let dev = &Device::Cpu;
+ let weights = safetensors::tensor::SafeTensors::deserialize(weights)?;
+ let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, dev);
+ let sam = if use_tiny {
+ sam::Sam::new_tiny(vb)? // tiny vit_t
+ } else {
+ sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b
+ };
+ Ok(Self {
+ sam,
+ embeddings: None,
+ })
+ }
+
+ pub fn set_image_embeddings(&mut self, image_data: Vec<u8>) -> Result<(), JsError> {
+ sam::console_log!("image data: {}", image_data.len());
+ let image_data = std::io::Cursor::new(image_data);
+ let image = image::io::Reader::new(image_data)
+ .with_guessed_format()?
+ .decode()
+ .map_err(candle::Error::wrap)?;
+ let (original_height, original_width) = (image.height(), image.width());
+ let (height, width) = (original_height, original_width);
+ let resize_longest = sam::IMAGE_SIZE as u32;
+ let (height, width) = if height < width {
+ let h = (resize_longest * height) / width;
+ (h, resize_longest)
+ } else {
+ let w = (resize_longest * width) / height;
+ (resize_longest, w)
+ };
+ let image_t = {
+ let img = image.resize_exact(width, height, image::imageops::FilterType::CatmullRom);
+ let data = img.to_rgb8().into_raw();
+ Tensor::from_vec(
+ data,
+ (img.height() as usize, img.width() as usize, 3),
+ &Device::Cpu,
+ )?
+ .permute((2, 0, 1))?
+ };
+ let data = self.sam.embeddings(&image_t)?;
+ self.embeddings = Some(Embeddings {
+ original_width,
+ original_height,
+ width,
+ height,
+ data,
+ });
+ Ok(())
+ }
+
+ // x and y have to be between 0 and 1
+ pub fn mask_for_point(&self, x: f64, y: f64) -> Result<String, JsError> {
+ let embeddings = match &self.embeddings {
+ None => todo!(),
+ Some(embeddings) => embeddings,
+ };
+ let (mask, iou_predictions) = self.sam.forward_for_embeddings(
+ &embeddings.data,
+ embeddings.height as usize,
+ embeddings.width as usize,
+ Some((x, y)),
+ false,
+ )?;
+ let iou = iou_predictions.to_vec1::<f32>()?[0];
+ let mask_shape = mask.dims().to_vec();
+ let mask_data = mask.ge(0f32)?.flatten_all()?.to_vec1::<u8>()?;
+ let mask = Mask {
+ iou,
+ mask_shape,
+ mask_data,
+ };
+ let json = serde_json::to_string(&mask)?;
+ Ok(json)
+ }
+}
+
+#[derive(serde::Serialize, serde::Deserialize)]
+struct Mask {
+ iou: f32,
+ mask_shape: Vec<usize>,
+ mask_data: Vec<u8>,
+}
+
+fn main() {
+ console_error_panic_hook::set_once();
+}
diff --git a/candle-wasm-examples/segment-anything/src/lib.rs b/candle-wasm-examples/segment-anything/src/lib.rs
new file mode 100644
index 00000000..0f4f96fd
--- /dev/null
+++ b/candle-wasm-examples/segment-anything/src/lib.rs
@@ -0,0 +1,19 @@
+use candle_transformers::models::segment_anything::sam;
+use wasm_bindgen::prelude::*;
+
+pub use sam::{Sam, IMAGE_SIZE};
+
+#[wasm_bindgen]
+extern "C" {
+ // Use `js_namespace` here to bind `console.log(..)` instead of just
+ // `log(..)`
+ #[wasm_bindgen(js_namespace = console)]
+ pub fn log(s: &str);
+}
+
+#[macro_export]
+macro_rules! console_log {
+ // Note that this is using the `log` function imported above during
+ // `bare_bones`
+ ($($t:tt)*) => ($crate::log(&format_args!($($t)*).to_string()))
+}