diff options
author | lichin-lin <vic20087cjimlin@gmail.com> | 2023-10-01 18:25:22 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-01 18:25:22 +0100 |
commit | 41143db1afd47b1037b29d44d9a8eebe44e5b508 (patch) | |
tree | c1d3e0195f2ffdc1f427d3a896c01981ab59bc2b /candle-wasm-examples/segment-anything/src | |
parent | 096dee7073e960f4b845a430b889a9fb2f2f0c78 (diff) | |
download | candle-41143db1afd47b1037b29d44d9a8eebe44e5b508.tar.gz candle-41143db1afd47b1037b29d44d9a8eebe44e5b508.tar.bz2 candle-41143db1afd47b1037b29d44d9a8eebe44e5b508.zip |
[segment-anything] add multi point logic for demo site (#1002)
* [segment-anything] add multi point logic for demo site
* [segment-anything] remove libs and update functions
Diffstat (limited to 'candle-wasm-examples/segment-anything/src')
-rw-r--r-- | candle-wasm-examples/segment-anything/src/bin/m.rs | 36 |
1 files changed, 24 insertions, 12 deletions
diff --git a/candle-wasm-examples/segment-anything/src/bin/m.rs b/candle-wasm-examples/segment-anything/src/bin/m.rs index 12349493..2be59adc 100644 --- a/candle-wasm-examples/segment-anything/src/bin/m.rs +++ b/candle-wasm-examples/segment-anything/src/bin/m.rs @@ -74,17 +74,24 @@ impl Model { Ok(()) } - // x and y have to be between 0 and 1 - pub fn mask_for_point(&self, x: f64, y: f64) -> Result<JsValue, JsError> { - if !(0. ..=1.).contains(&x) { - Err(JsError::new(&format!( - "x has to be between 0 and 1, got {x}" - )))? - } - if !(0. ..=1.).contains(&y) { - Err(JsError::new(&format!( - "y has to be between 0 and 1, got {y}" - )))? + pub fn mask_for_point(&self, input: JsValue) -> Result<JsValue, JsError> { + let input: PointsInput = + serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?; + let transformed_points = input.points; + + for &(x, y, _bool) in &transformed_points { + if !(0.0..=1.0).contains(&x) { + return Err(JsError::new(&format!( + "x has to be between 0 and 1, got {}", + x + ))); + } + if !(0.0..=1.0).contains(&y) { + return Err(JsError::new(&format!( + "y has to be between 0 and 1, got {}", + y + ))); + } } let embeddings = match &self.embeddings { None => Err(JsError::new("image embeddings have not been set"))?, @@ -94,7 +101,7 @@ impl Model { &embeddings.data, embeddings.height as usize, embeddings.width as usize, - &[(x, y, true)], + &transformed_points, false, )?; let iou = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?[0]; @@ -134,6 +141,11 @@ struct MaskImage { image: Image, } +#[derive(serde::Serialize, serde::Deserialize)] +struct PointsInput { + points: Vec<(f64, f64, bool)>, +} + fn main() { console_error_panic_hook::set_once(); } |