summaryrefslogtreecommitdiff
path: root/candle-transformers
diff options
context:
space:
mode:
authorMatthew O'Malley-Nichols <91226873+onichmath@users.noreply.github.com>2024-08-09 22:57:52 -0700
committerGitHub <noreply@github.com>2024-08-10 07:57:52 +0200
commit14db029494171268600914995369e0962f48c29a (patch)
tree1ba49badf6f4ae8a5a57f205315232476f3577de /candle-transformers
parent6e6c1c99b09de707d1f0aa6839be70202b278a57 (diff)
downloadcandle-14db029494171268600914995369e0962f48c29a.tar.gz
candle-14db029494171268600914995369e0962f48c29a.tar.bz2
candle-14db029494171268600914995369e0962f48c29a.zip
Soft Non-Maximum Suppression (#2400)
* Soft NMS with thresholds * NMS Test * Soft nms w/ boxes removed below threshold * Soft nms test * No longer removing bounding boxes to fit Soft-NMS focus * Initialize confidence * Added comments * Refactored out updating based on IOU/sigma * Score_threshold -> confidence_threshold for clarity * Remove bboxes below confidence threshold * Softnms basic functionality test * Softnms confidence decay test * Softnms confidence threshold test * Softnms no overlapping bbox test * Testing confidence after no overlap test * Single bbox and no bbox tests * Signify test completion * Handling result of test functions * Checking all pairs of bboxes instead of a forward pass * Equal confidence overlap test * Clarified tests for implementation * No longer dropping boxes, just setting to 0.0 * Formatted w/ cargo
Diffstat (limited to 'candle-transformers')
-rw-r--r--candle-transformers/src/object_detection.rs58
-rw-r--r--candle-transformers/tests/nms_tests.rs222
2 files changed, 280 insertions, 0 deletions
diff --git a/candle-transformers/src/object_detection.rs b/candle-transformers/src/object_detection.rs
index ce579316..e922075f 100644
--- a/candle-transformers/src/object_detection.rs
+++ b/candle-transformers/src/object_detection.rs
@@ -50,3 +50,61 @@ pub fn non_maximum_suppression<D>(bboxes: &mut [Vec<Bbox<D>>], threshold: f32) {
bboxes_for_class.truncate(current_index);
}
}
+
+// Updates confidences starting at highest and comparing subsequent boxes.
+fn update_confidences<D>(
+ bboxes_for_class: &[Bbox<D>],
+ updated_confidences: &mut [f32],
+ iou_threshold: f32,
+ sigma: f32,
+) {
+ let len = bboxes_for_class.len();
+ for current_index in 0..len {
+ let current_bbox = &bboxes_for_class[current_index];
+ for index in (current_index + 1)..len {
+ let iou_val = iou(current_bbox, &bboxes_for_class[index]);
+ if iou_val > iou_threshold {
+ // Decay calculation from page 4 of: https://arxiv.org/pdf/1704.04503
+ let decay = (-iou_val * iou_val / sigma).exp();
+ let updated_confidence = bboxes_for_class[index].confidence * decay;
+ updated_confidences[index] = updated_confidence;
+ }
+ }
+ }
+}
+
+// Sorts the bounding boxes by confidence and applies soft non-maximum suppression.
+// This function is based on the algorithm described in https://arxiv.org/pdf/1704.04503
+pub fn soft_non_maximum_suppression<D>(
+ bboxes: &mut [Vec<Bbox<D>>],
+ iou_threshold: Option<f32>,
+ confidence_threshold: Option<f32>,
+ sigma: Option<f32>,
+) {
+ let iou_threshold = iou_threshold.unwrap_or(0.5);
+ let confidence_threshold = confidence_threshold.unwrap_or(0.1);
+ let sigma = sigma.unwrap_or(0.5);
+
+ for bboxes_for_class in bboxes.iter_mut() {
+ // Sort boxes by confidence in descending order
+ bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());
+ let mut updated_confidences = bboxes_for_class
+ .iter()
+ .map(|bbox| bbox.confidence)
+ .collect::<Vec<_>>();
+ update_confidences(
+ bboxes_for_class,
+ &mut updated_confidences,
+ iou_threshold,
+ sigma,
+ );
+ // Update confidences, set to 0.0 if below threshold
+ for (i, &confidence) in updated_confidences.iter().enumerate() {
+ bboxes_for_class[i].confidence = if confidence < confidence_threshold {
+ 0.0
+ } else {
+ confidence
+ };
+ }
+ }
+}
diff --git a/candle-transformers/tests/nms_tests.rs b/candle-transformers/tests/nms_tests.rs
new file mode 100644
index 00000000..d70f6fdf
--- /dev/null
+++ b/candle-transformers/tests/nms_tests.rs
@@ -0,0 +1,222 @@
+use candle::Result;
+use candle_transformers::object_detection::{
+ non_maximum_suppression, soft_non_maximum_suppression, Bbox,
+};
+
+#[test]
+fn nms_basic() -> Result<()> {
+ // Boxes based upon https://thepythoncode.com/article/non-maximum-suppression-using-opencv-in-python
+ let mut bboxes = vec![vec![
+ Bbox {
+ xmin: 245.0,
+ ymin: 305.0,
+ xmax: 575.0,
+ ymax: 490.0,
+ confidence: 0.9,
+ data: (),
+ }, // Box 1
+ Bbox {
+ xmin: 235.0,
+ ymin: 300.0,
+ xmax: 485.0,
+ ymax: 515.0,
+ confidence: 0.8,
+ data: (),
+ }, // Box 2
+ Bbox {
+ xmin: 305.0,
+ ymin: 270.0,
+ xmax: 540.0,
+ ymax: 500.0,
+ confidence: 0.6,
+ data: (),
+ }, // Box 3
+ ]];
+
+ non_maximum_suppression(&mut bboxes, 0.5);
+ let bboxes = bboxes.into_iter().next().unwrap();
+ assert_eq!(bboxes.len(), 1);
+ assert_eq!(bboxes[0].confidence, 0.9);
+
+ Ok(())
+}
+
+#[test]
+fn softnms_basic_functionality() -> Result<()> {
+ let mut bboxes = vec![vec![
+ Bbox {
+ xmin: 0.0,
+ ymin: 0.0,
+ xmax: 1.0,
+ ymax: 1.0,
+ confidence: 0.5,
+ data: (),
+ },
+ Bbox {
+ xmin: 0.1,
+ ymin: 0.1,
+ xmax: 1.1,
+ ymax: 1.1,
+ confidence: 0.9,
+ data: (),
+ },
+ Bbox {
+ xmin: 0.2,
+ ymin: 0.2,
+ xmax: 1.2,
+ ymax: 1.2,
+ confidence: 0.6,
+ data: (),
+ },
+ ]];
+
+ soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));
+
+ // Should decay boxes following highest confidence box
+ assert!(bboxes[0][0].confidence == 0.9);
+ assert!(bboxes[0][1].confidence < 0.5);
+ assert!(bboxes[0][2].confidence < 0.6);
+ Ok(())
+}
+
+#[test]
+fn softnms_confidence_decay() -> Result<()> {
+ let mut bboxes = vec![vec![
+ Bbox {
+ xmin: 0.0,
+ ymin: 0.0,
+ xmax: 1.0,
+ ymax: 1.0,
+ confidence: 0.9,
+ data: (),
+ }, // Reference box
+ Bbox {
+ xmin: 0.1,
+ ymin: 0.1,
+ xmax: 1.1,
+ ymax: 1.1,
+ confidence: 0.8,
+ data: (),
+ }, // Overlapping box
+ ]];
+
+ soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));
+
+ // Check that confidence of the overlapping box is decayed
+ assert!(bboxes[0][0].confidence == 0.9);
+ assert!(bboxes[0][1].confidence < 0.8);
+ Ok(())
+}
+
+#[test]
+fn softnms_confidence_threshold() -> Result<()> {
+ let mut bboxes = vec![vec![
+ Bbox {
+ xmin: 0.0,
+ ymin: 0.0,
+ xmax: 1.0,
+ ymax: 1.0,
+ confidence: 0.9,
+ data: (),
+ },
+ Bbox {
+ xmin: 0.1,
+ ymin: 0.1,
+ xmax: 1.1,
+ ymax: 1.1,
+ confidence: 0.05,
+ data: (),
+ },
+ ]];
+
+ soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));
+
+ // Box with confidence below the threshold should be removed
+ assert_eq!(bboxes[0].len(), 2);
+ assert_eq!(bboxes[0][0].confidence, 0.9);
+ assert_eq!(bboxes[0][1].confidence, 0.00);
+ Ok(())
+}
+
+#[test]
+fn softnms_no_overlap() -> Result<()> {
+ let mut bboxes = vec![vec![
+ Bbox {
+ xmin: 0.0,
+ ymin: 0.0,
+ xmax: 1.0,
+ ymax: 1.0,
+ confidence: 0.9,
+ data: (),
+ },
+ Bbox {
+ xmin: 2.0,
+ ymin: 2.0,
+ xmax: 3.0,
+ ymax: 3.0,
+ confidence: 0.8,
+ data: (),
+ },
+ ]];
+
+ soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));
+
+ // Both boxes should remain as they do not significantly overlap
+ assert_eq!(bboxes[0].len(), 2);
+ assert_eq!(bboxes[0][0].confidence, 0.9);
+ assert_eq!(bboxes[0][1].confidence, 0.8);
+ Ok(())
+}
+#[test]
+fn softnms_no_bbox() -> Result<()> {
+ let mut bboxes: Vec<Vec<Bbox<()>>> = vec![];
+ soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));
+ assert!(bboxes.is_empty());
+ Ok(())
+}
+
+#[test]
+fn softnms_single_bbox() -> Result<()> {
+ let mut bboxes = vec![vec![Bbox {
+ xmin: 0.0,
+ ymin: 0.0,
+ xmax: 1.0,
+ ymax: 1.0,
+ confidence: 0.9,
+ data: (),
+ }]];
+ soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));
+ assert_eq!(bboxes[0].len(), 1);
+ Ok(())
+}
+
+#[test]
+fn softnms_equal_confidence_overlap() -> Result<()> {
+ let mut bboxes = vec![vec![
+ Bbox {
+ xmin: 0.0,
+ ymin: 0.0,
+ xmax: 1.0,
+ ymax: 1.0,
+ confidence: 0.5,
+ data: (),
+ },
+ Bbox {
+ xmin: 0.1,
+ ymin: 0.1,
+ xmax: 1.1,
+ ymax: 1.1,
+ confidence: 0.5,
+ data: (),
+ },
+ ]];
+
+ soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));
+
+ // First box will be reference box, second box should be decayed
+ // Implementation must change to have both be decayed
+ assert_eq!(bboxes[0].len(), 2);
+ assert!(bboxes[0][0].confidence == 0.5);
+ assert!(bboxes[0][1].confidence < 0.5);
+ Ok(())
+}