summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-examples/examples/segment-anything/main.rs16
-rw-r--r--candle-examples/examples/segment-anything/model_mask_decoder.rs2
-rw-r--r--candle-examples/examples/segment-anything/model_sam.rs103
-rw-r--r--candle-examples/examples/segment-anything/model_transformer.rs6
-rw-r--r--candle-examples/examples/yolo-v3/main.rs4
-rw-r--r--candle-examples/examples/yolo-v8/main.rs12
-rw-r--r--candle-examples/src/object_detection.rs8
7 files changed, 125 insertions, 26 deletions
diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs
index a749ba2a..4627248c 100644
--- a/candle-examples/examples/segment-anything/main.rs
+++ b/candle-examples/examples/segment-anything/main.rs
@@ -188,13 +188,25 @@ pub fn main() -> anyhow::Result<()> {
if args.generate_masks {
// Default options similar to the Python version.
- sam.generate_masks(
+ let bboxes = sam.generate_masks(
&image,
/* points_per_side */ 32,
/* crop_n_layer */ 0,
/* crop_overlap_ratio */ 512. / 1500.,
/* crop_n_points_downscale_factor */ 1,
- )?
+ )?;
+ for (idx, bbox) in bboxes.iter().enumerate() {
+ println!("{bbox:?}");
+ let mask = (&bbox.data.to_dtype(DType::U8)? * 255.)?;
+ let (h, w) = mask.dims2()?;
+ let mask = mask.broadcast_as((3, h, w))?;
+ candle_examples::save_image_resize(
+ &mask,
+ format!("sam_mask{idx}.png"),
+ initial_h,
+ initial_w,
+ )?;
+ }
} else {
let point = Some((args.point_x, args.point_y));
let (mask, iou_predictions) = sam.forward(&image, point, false)?;
diff --git a/candle-examples/examples/segment-anything/model_mask_decoder.rs b/candle-examples/examples/segment-anything/model_mask_decoder.rs
index 1f6d62a4..c02b44a7 100644
--- a/candle-examples/examples/segment-anything/model_mask_decoder.rs
+++ b/candle-examples/examples/segment-anything/model_mask_decoder.rs
@@ -219,7 +219,7 @@ impl MaskDecoder {
let h = mlp.forward(&mask_tokens_out.i((.., i))?)?;
hyper_in_list.push(h)
}
- let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?;
+ let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?.contiguous()?;
let (b, c, h, w) = upscaled_embedding.dims4()?;
let masks = hyper_in.matmul(&upscaled_embedding.reshape((b, c, h * w))?)?;
let masks = masks.reshape((b, (), h, w))?;
diff --git a/candle-examples/examples/segment-anything/model_sam.rs b/candle-examples/examples/segment-anything/model_sam.rs
index 884559af..ade976c1 100644
--- a/candle-examples/examples/segment-anything/model_sam.rs
+++ b/candle-examples/examples/segment-anything/model_sam.rs
@@ -8,6 +8,11 @@ use crate::model_prompt_encoder::PromptEncoder;
const PROMPT_EMBED_DIM: usize = 256;
pub const IMAGE_SIZE: usize = 1024;
const VIT_PATCH_SIZE: usize = 16;
+const PRED_IOU_THRESH: f32 = 0.88;
+const STABILITY_SCORE_OFFSET: f32 = 1.0;
+const STABILITY_SCORE_THRESHOLD: f32 = 0.95;
+const MODEL_MASK_THRESHOLD: f32 = 0.0;
+const CROP_NMS_THRESH: f32 = 0.7;
#[derive(Debug)]
pub struct Sam {
@@ -129,7 +134,12 @@ impl Sam {
img.pad_with_zeros(2, 0, IMAGE_SIZE - w)
}
- fn process_crop(&self, img: &Tensor, cb: CropBox, point_grids: &[(f64, f64)]) -> Result<()> {
+ fn process_crop(
+ &self,
+ img: &Tensor,
+ cb: CropBox,
+ point_grids: &[(f64, f64)],
+ ) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> {
// Crop the image and calculate embeddings.
let img = img.i((.., cb.y0..cb.y1, cb.x0..cb.x1))?;
let img = self.preprocess(&img)?.unsqueeze(0)?;
@@ -144,28 +154,86 @@ impl Sam {
.iter()
.map(|&(x, y)| vec![x as f32 * crop_w as f32, y as f32 * crop_h as f32])
.collect::<Vec<_>>();
+
+ let mut bboxes = Vec::new();
for points in points.chunks(64) {
+ // Run the model on this batch.
let points_len = points.len();
let in_points = Tensor::new(points.to_vec(), img.device())?.unsqueeze(1)?;
let in_labels = Tensor::ones((points_len, 1), DType::F32, img.device())?;
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
self.prompt_encoder
.forward(Some((&in_points, &in_labels)), None, None)?;
- let (_low_res_mask, iou_predictions) = self.mask_decoder.forward(
+
+ let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
&img_embeddings,
&image_pe,
&sparse_prompt_embeddings,
&dense_prompt_embeddings,
/* multimask_output */ true,
)?;
+ let low_res_mask = low_res_mask.flatten(0, 1)?;
+ let iou_predictions = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?;
+ let dev = low_res_mask.device();
+
+ for (i, iou) in iou_predictions.iter().enumerate() {
+ // Filter by predicted IoU.
+ if *iou < PRED_IOU_THRESH {
+ continue;
+ }
+ let low_res_mask = low_res_mask.get(i)?;
+
+ // Calculate stability score.
+ let bound = Tensor::new(MODEL_MASK_THRESHOLD + STABILITY_SCORE_OFFSET, dev)?
+ .broadcast_as(low_res_mask.shape())?;
+ let intersections = low_res_mask
+ .ge(&bound)?
+ .to_dtype(DType::F32)?
+ .sum_all()?
+ .to_vec0::<f32>()?;
+ let bound = Tensor::new(MODEL_MASK_THRESHOLD - STABILITY_SCORE_OFFSET, dev)?
+ .broadcast_as(low_res_mask.shape())?;
+ let unions = low_res_mask
+ .ge(&bound)?
+ .to_dtype(DType::F32)?
+ .sum_all()?
+ .to_vec0::<f32>()?;
+ let stability_score = intersections / unions;
+ if stability_score < STABILITY_SCORE_THRESHOLD {
+ continue;
+ }
- println!("{cb:?} {iou_predictions}");
+ // Threshold masks and calculate boxes.
+ let low_res_mask = low_res_mask
+ .ge(&Tensor::new(0f32, dev)?.broadcast_as(low_res_mask.shape())?)?
+ .to_dtype(DType::U32)?;
+ let low_res_mask_per_x = low_res_mask.sum(0)?.to_vec1::<u32>()?;
+ let low_res_mask_per_y = low_res_mask.sum(1)?.to_vec1::<u32>()?;
+ let min_max_x = min_max_indexes(&low_res_mask_per_x);
+ let min_max_y = min_max_indexes(&low_res_mask_per_y);
+ if let Some(((x0, x1), (y0, y1))) = min_max_x.zip(min_max_y) {
+ let bbox = candle_examples::object_detection::Bbox {
+ xmin: x0 as f32,
+ ymin: y0 as f32,
+ xmax: x1 as f32,
+ ymax: y1 as f32,
+ confidence: *iou,
+ data: low_res_mask,
+ };
+ bboxes.push(bbox);
+ }
+ // TODO:
+ // Filter boxes that touch crop boundaries
+ // Compress to RLE.
+ }
}
+ let mut bboxes = vec![bboxes];
// Remove duplicates within this crop.
+ candle_examples::object_detection::non_maximum_suppression(&mut bboxes, CROP_NMS_THRESH);
- // Return to the original image frame.
- Ok(())
+ // TODO: Return to the original image frame.
+ Ok(bboxes.remove(0))
}
pub fn generate_masks(
@@ -175,7 +243,7 @@ impl Sam {
crop_n_layer: usize,
crop_overlap_ratio: f64,
crop_n_points_downscale_factor: usize,
- ) -> Result<()> {
+ ) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> {
let (_c, h, w) = img.dims3()?;
let point_grids = build_all_layer_point_grids(
points_per_side,
@@ -183,12 +251,31 @@ impl Sam {
crop_n_points_downscale_factor,
);
let crop_boxes = generate_crop_boxes((h, w), crop_n_layer, crop_overlap_ratio);
+ let mut bboxes = Vec::new();
for crop_box in crop_boxes.into_iter() {
let layer_idx = crop_box.layer_idx;
- self.process_crop(img, crop_box, &point_grids[layer_idx])?
+ let b = self.process_crop(img, crop_box, &point_grids[layer_idx])?;
+ bboxes.extend(b)
}
// TODO: remove duplicates
- Ok(())
+ Ok(bboxes)
+ }
+}
+
+// Return the first and last indexes i for which values[i] > 0
+fn min_max_indexes(values: &[u32]) -> Option<(usize, usize)> {
+ let (mut min_i, mut max_i) = (usize::MAX, usize::MIN);
+ for (i, &s) in values.iter().enumerate() {
+ if s == 0 {
+ continue;
+ }
+ min_i = usize::min(i, min_i);
+ max_i = usize::max(i, max_i);
+ }
+ if max_i < min_i {
+ None
+ } else {
+ Some((min_i, max_i))
}
}
diff --git a/candle-examples/examples/segment-anything/model_transformer.rs b/candle-examples/examples/segment-anything/model_transformer.rs
index e4de27cb..e12aac08 100644
--- a/candle-examples/examples/segment-anything/model_transformer.rs
+++ b/candle-examples/examples/segment-anything/model_transformer.rs
@@ -45,9 +45,9 @@ impl Attention {
}
fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
- let q = self.q_proj.forward(q)?;
- let k = self.k_proj.forward(k)?;
- let v = self.v_proj.forward(v)?;
+ let q = self.q_proj.forward(&q.contiguous()?)?;
+ let k = self.k_proj.forward(&k.contiguous()?)?;
+ let v = self.v_proj.forward(&v.contiguous()?)?;
let q = self.separate_heads(&q)?;
let k = self.separate_heads(&k)?;
diff --git a/candle-examples/examples/yolo-v3/main.rs b/candle-examples/examples/yolo-v3/main.rs
index 5e388921..20021b45 100644
--- a/candle-examples/examples/yolo-v3/main.rs
+++ b/candle-examples/examples/yolo-v3/main.rs
@@ -46,7 +46,7 @@ pub fn report(
let (npreds, pred_size) = pred.dims2()?;
let nclasses = pred_size - 5;
// The bounding boxes grouped by (maximum) class index.
- let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
+ let mut bboxes: Vec<Vec<Bbox<()>>> = (0..nclasses).map(|_| vec![]).collect();
// Extract the bounding boxes for which confidence is above the threshold.
for index in 0..npreds {
let pred = Vec::<f32>::try_from(pred.get(index)?)?;
@@ -65,7 +65,7 @@ pub fn report(
xmax: pred[0] + pred[2] / 2.,
ymax: pred[1] + pred[3] / 2.,
confidence,
- keypoints: vec![],
+ data: (),
};
bboxes[class_index].push(bbox)
}
diff --git a/candle-examples/examples/yolo-v8/main.rs b/candle-examples/examples/yolo-v8/main.rs
index d5c5ac1c..2017b5be 100644
--- a/candle-examples/examples/yolo-v8/main.rs
+++ b/candle-examples/examples/yolo-v8/main.rs
@@ -64,7 +64,7 @@ pub fn report_detect(
let (pred_size, npreds) = pred.dims2()?;
let nclasses = pred_size - 4;
// The bounding boxes grouped by (maximum) class index.
- let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
+ let mut bboxes: Vec<Vec<Bbox<Vec<KeyPoint>>>> = (0..nclasses).map(|_| vec![]).collect();
// Extract the bounding boxes for which confidence is above the threshold.
for index in 0..npreds {
let pred = Vec::<f32>::try_from(pred.i((.., index))?)?;
@@ -83,7 +83,7 @@ pub fn report_detect(
xmax: pred[0] + pred[2] / 2.,
ymax: pred[1] + pred[3] / 2.,
confidence,
- keypoints: vec![],
+ data: vec![],
};
bboxes[class_index].push(bbox)
}
@@ -176,7 +176,7 @@ pub fn report_pose(
xmax: pred[0] + pred[2] / 2.,
ymax: pred[1] + pred[3] / 2.,
confidence,
- keypoints,
+ data: keypoints,
};
bboxes.push(bbox)
}
@@ -204,7 +204,7 @@ pub fn report_pose(
image::Rgb([255, 0, 0]),
);
}
- for kp in b.keypoints.iter() {
+ for kp in b.data.iter() {
if kp.mask < 0.6 {
continue;
}
@@ -219,8 +219,8 @@ pub fn report_pose(
}
for &(idx1, idx2) in KP_CONNECTIONS.iter() {
- let kp1 = &b.keypoints[idx1];
- let kp2 = &b.keypoints[idx2];
+ let kp1 = &b.data[idx1];
+ let kp2 = &b.data[idx2];
if kp1.mask < 0.6 || kp2.mask < 0.6 {
continue;
}
diff --git a/candle-examples/src/object_detection.rs b/candle-examples/src/object_detection.rs
index c7c60136..ce579316 100644
--- a/candle-examples/src/object_detection.rs
+++ b/candle-examples/src/object_detection.rs
@@ -1,12 +1,12 @@
/// A bounding box around an object.
#[derive(Debug, Clone)]
-pub struct Bbox {
+pub struct Bbox<D> {
pub xmin: f32,
pub ymin: f32,
pub xmax: f32,
pub ymax: f32,
pub confidence: f32,
- pub keypoints: Vec<KeyPoint>,
+ pub data: D,
}
#[derive(Debug, Clone, Copy, PartialEq)]
@@ -17,7 +17,7 @@ pub struct KeyPoint {
}
/// Intersection over union of two bounding boxes.
-pub fn iou(b1: &Bbox, b2: &Bbox) -> f32 {
+pub fn iou<D>(b1: &Bbox<D>, b2: &Bbox<D>) -> f32 {
let b1_area = (b1.xmax - b1.xmin + 1.) * (b1.ymax - b1.ymin + 1.);
let b2_area = (b2.xmax - b2.xmin + 1.) * (b2.ymax - b2.ymin + 1.);
let i_xmin = b1.xmin.max(b2.xmin);
@@ -28,7 +28,7 @@ pub fn iou(b1: &Bbox, b2: &Bbox) -> f32 {
i_area / (b1_area + b2_area - i_area)
}
-pub fn non_maximum_suppression(bboxes: &mut [Vec<Bbox>], threshold: f32) {
+pub fn non_maximum_suppression<D>(bboxes: &mut [Vec<Bbox<D>>], threshold: f32) {
// Perform non-maximum suppression.
for bboxes_for_class in bboxes.iter_mut() {
bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());