diff options
Diffstat (limited to 'candle-wasm-examples/segment-anything/src/bin/m.rs')
-rw-r--r-- | candle-wasm-examples/segment-anything/src/bin/m.rs | 113 |
1 files changed, 113 insertions, 0 deletions
diff --git a/candle-wasm-examples/segment-anything/src/bin/m.rs b/candle-wasm-examples/segment-anything/src/bin/m.rs new file mode 100644 index 00000000..c4c79fe0 --- /dev/null +++ b/candle-wasm-examples/segment-anything/src/bin/m.rs @@ -0,0 +1,113 @@ +use candle::{DType, Device, Tensor}; +use candle_nn::VarBuilder; +use candle_wasm_example_sam as sam; +use wasm_bindgen::prelude::*; + +#[allow(unused)] +struct Embeddings { + original_width: u32, + original_height: u32, + width: u32, + height: u32, + data: Tensor, +} + +#[wasm_bindgen] +pub struct Model { + sam: sam::Sam, + embeddings: Option<Embeddings>, +} + +#[wasm_bindgen] +impl Model { + #[wasm_bindgen(constructor)] + pub fn new(weights: &[u8], use_tiny: bool) -> Result<Model, JsError> { + console_error_panic_hook::set_once(); + let dev = &Device::Cpu; + let weights = safetensors::tensor::SafeTensors::deserialize(weights)?; + let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, dev); + let sam = if use_tiny { + sam::Sam::new_tiny(vb)? // tiny vit_t + } else { + sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b + }; + Ok(Self { + sam, + embeddings: None, + }) + } + + pub fn set_image_embeddings(&mut self, image_data: Vec<u8>) -> Result<(), JsError> { + sam::console_log!("image data: {}", image_data.len()); + let image_data = std::io::Cursor::new(image_data); + let image = image::io::Reader::new(image_data) + .with_guessed_format()? + .decode() + .map_err(candle::Error::wrap)?; + let (original_height, original_width) = (image.height(), image.width()); + let (height, width) = (original_height, original_width); + let resize_longest = sam::IMAGE_SIZE as u32; + let (height, width) = if height < width { + let h = (resize_longest * height) / width; + (h, resize_longest) + } else { + let w = (resize_longest * width) / height; + (resize_longest, w) + }; + let image_t = { + let img = image.resize_exact(width, height, image::imageops::FilterType::CatmullRom); + let data = img.to_rgb8().into_raw(); + Tensor::from_vec( + data, + (img.height() as usize, img.width() as usize, 3), + &Device::Cpu, + )? + .permute((2, 0, 1))? + }; + let data = self.sam.embeddings(&image_t)?; + self.embeddings = Some(Embeddings { + original_width, + original_height, + width, + height, + data, + }); + Ok(()) + } + + // x and y have to be between 0 and 1 + pub fn mask_for_point(&self, x: f64, y: f64) -> Result<String, JsError> { + let embeddings = match &self.embeddings { + None => todo!(), + Some(embeddings) => embeddings, + }; + let (mask, iou_predictions) = self.sam.forward_for_embeddings( + &embeddings.data, + embeddings.height as usize, + embeddings.width as usize, + Some((x, y)), + false, + )?; + let iou = iou_predictions.to_vec1::<f32>()?[0]; + let mask_shape = mask.dims().to_vec(); + let mask_data = mask.ge(0f32)?.flatten_all()?.to_vec1::<u8>()?; + let mask = Mask { + iou, + mask_shape, + mask_data, + }; + let json = serde_json::to_string(&mask)?; + Ok(json) + } +} + +#[derive(serde::Serialize, serde::Deserialize)] +struct Mask { + iou: f32, + mask_shape: Vec<usize>, + mask_data: Vec<u8>, +} + +fn main() { + console_error_panic_hook::set_once(); +} |