summaryrefslogtreecommitdiff
path: root/candle-wasm-examples/segment-anything/src
diff options
context:
space:
mode:
authorlichin-lin <vic20087cjimlin@gmail.com>2023-10-01 18:25:22 +0100
committerGitHub <noreply@github.com>2023-10-01 18:25:22 +0100
commit41143db1afd47b1037b29d44d9a8eebe44e5b508 (patch)
treec1d3e0195f2ffdc1f427d3a896c01981ab59bc2b /candle-wasm-examples/segment-anything/src
parent096dee7073e960f4b845a430b889a9fb2f2f0c78 (diff)
downloadcandle-41143db1afd47b1037b29d44d9a8eebe44e5b508.tar.gz
candle-41143db1afd47b1037b29d44d9a8eebe44e5b508.tar.bz2
candle-41143db1afd47b1037b29d44d9a8eebe44e5b508.zip
[segment-anything] add multi point logic for demo site (#1002)
* [segment-anything] add multi point logic for demo site * [segment-anything] remove libs and update functions
Diffstat (limited to 'candle-wasm-examples/segment-anything/src')
-rw-r--r--candle-wasm-examples/segment-anything/src/bin/m.rs36
1 files changed, 24 insertions, 12 deletions
diff --git a/candle-wasm-examples/segment-anything/src/bin/m.rs b/candle-wasm-examples/segment-anything/src/bin/m.rs
index 12349493..2be59adc 100644
--- a/candle-wasm-examples/segment-anything/src/bin/m.rs
+++ b/candle-wasm-examples/segment-anything/src/bin/m.rs
@@ -74,17 +74,24 @@ impl Model {
Ok(())
}
- // x and y have to be between 0 and 1
- pub fn mask_for_point(&self, x: f64, y: f64) -> Result<JsValue, JsError> {
- if !(0. ..=1.).contains(&x) {
- Err(JsError::new(&format!(
- "x has to be between 0 and 1, got {x}"
- )))?
- }
- if !(0. ..=1.).contains(&y) {
- Err(JsError::new(&format!(
- "y has to be between 0 and 1, got {y}"
- )))?
+ pub fn mask_for_point(&self, input: JsValue) -> Result<JsValue, JsError> {
+ let input: PointsInput =
+ serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
+ let transformed_points = input.points;
+
+ for &(x, y, _bool) in &transformed_points {
+ if !(0.0..=1.0).contains(&x) {
+ return Err(JsError::new(&format!(
+ "x has to be between 0 and 1, got {}",
+ x
+ )));
+ }
+ if !(0.0..=1.0).contains(&y) {
+ return Err(JsError::new(&format!(
+ "y has to be between 0 and 1, got {}",
+ y
+ )));
+ }
}
let embeddings = match &self.embeddings {
None => Err(JsError::new("image embeddings have not been set"))?,
@@ -94,7 +101,7 @@ impl Model {
&embeddings.data,
embeddings.height as usize,
embeddings.width as usize,
- &[(x, y, true)],
+ &transformed_points,
false,
)?;
let iou = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?[0];
@@ -134,6 +141,11 @@ struct MaskImage {
image: Image,
}
+#[derive(serde::Serialize, serde::Deserialize)]
+struct PointsInput {
+ points: Vec<(f64, f64, bool)>,
+}
+
fn main() {
console_error_panic_hook::set_once();
}