summaryrefslogtreecommitdiff
path: root/candle-transformers/tests
diff options
context:
space:
mode:
authorMatthew O'Malley-Nichols <91226873+onichmath@users.noreply.github.com>2024-08-09 22:57:52 -0700
committerGitHub <noreply@github.com>2024-08-10 07:57:52 +0200
commit14db029494171268600914995369e0962f48c29a (patch)
tree1ba49badf6f4ae8a5a57f205315232476f3577de /candle-transformers/tests
parent6e6c1c99b09de707d1f0aa6839be70202b278a57 (diff)
downloadcandle-14db029494171268600914995369e0962f48c29a.tar.gz
candle-14db029494171268600914995369e0962f48c29a.tar.bz2
candle-14db029494171268600914995369e0962f48c29a.zip
Soft Non-Maximum Suppression (#2400)
* Soft NMS with thresholds * NMS Test * Soft nms w/ boxes removed below threshold * Soft nms test * No longer removing bounding boxes to fit Soft-NMS focus * Initialize confidence * Added comments * Refactored out updating based on IOU/sigma * Score_threshold -> confidence_threshold for clarity * Remove bboxes below confidence threshold * Softnms basic functionality test * Softnms confidence decay test * Softnms confidence threshold test * Softnms no overlapping bbox test * Testing confidence after no overlap test * Single bbox and no bbox tests * Signify test completion * Handling result of test functions * Checking all pairs of bboxes instead of a forward pass * Equal confidence overlap test * Clarified tests for implementation * No longer dropping boxes, just setting to 0.0 * Formatted w/ cargo
Diffstat (limited to 'candle-transformers/tests')
-rw-r--r--candle-transformers/tests/nms_tests.rs222
1 files changed, 222 insertions, 0 deletions
diff --git a/candle-transformers/tests/nms_tests.rs b/candle-transformers/tests/nms_tests.rs
new file mode 100644
index 00000000..d70f6fdf
--- /dev/null
+++ b/candle-transformers/tests/nms_tests.rs
@@ -0,0 +1,222 @@
+use candle::Result;
+use candle_transformers::object_detection::{
+ non_maximum_suppression, soft_non_maximum_suppression, Bbox,
+};
+
+#[test]
+fn nms_basic() -> Result<()> {
+ // Boxes based upon https://thepythoncode.com/article/non-maximum-suppression-using-opencv-in-python
+ let mut bboxes = vec![vec![
+ Bbox {
+ xmin: 245.0,
+ ymin: 305.0,
+ xmax: 575.0,
+ ymax: 490.0,
+ confidence: 0.9,
+ data: (),
+ }, // Box 1
+ Bbox {
+ xmin: 235.0,
+ ymin: 300.0,
+ xmax: 485.0,
+ ymax: 515.0,
+ confidence: 0.8,
+ data: (),
+ }, // Box 2
+ Bbox {
+ xmin: 305.0,
+ ymin: 270.0,
+ xmax: 540.0,
+ ymax: 500.0,
+ confidence: 0.6,
+ data: (),
+ }, // Box 3
+ ]];
+
+ non_maximum_suppression(&mut bboxes, 0.5);
+ let bboxes = bboxes.into_iter().next().unwrap();
+ assert_eq!(bboxes.len(), 1);
+ assert_eq!(bboxes[0].confidence, 0.9);
+
+ Ok(())
+}
+
+#[test]
+fn softnms_basic_functionality() -> Result<()> {
+ let mut bboxes = vec![vec![
+ Bbox {
+ xmin: 0.0,
+ ymin: 0.0,
+ xmax: 1.0,
+ ymax: 1.0,
+ confidence: 0.5,
+ data: (),
+ },
+ Bbox {
+ xmin: 0.1,
+ ymin: 0.1,
+ xmax: 1.1,
+ ymax: 1.1,
+ confidence: 0.9,
+ data: (),
+ },
+ Bbox {
+ xmin: 0.2,
+ ymin: 0.2,
+ xmax: 1.2,
+ ymax: 1.2,
+ confidence: 0.6,
+ data: (),
+ },
+ ]];
+
+ soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));
+
+ // Should decay boxes following highest confidence box
+ assert!(bboxes[0][0].confidence == 0.9);
+ assert!(bboxes[0][1].confidence < 0.5);
+ assert!(bboxes[0][2].confidence < 0.6);
+ Ok(())
+}
+
+#[test]
+fn softnms_confidence_decay() -> Result<()> {
+ let mut bboxes = vec![vec![
+ Bbox {
+ xmin: 0.0,
+ ymin: 0.0,
+ xmax: 1.0,
+ ymax: 1.0,
+ confidence: 0.9,
+ data: (),
+ }, // Reference box
+ Bbox {
+ xmin: 0.1,
+ ymin: 0.1,
+ xmax: 1.1,
+ ymax: 1.1,
+ confidence: 0.8,
+ data: (),
+ }, // Overlapping box
+ ]];
+
+ soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));
+
+ // Check that confidence of the overlapping box is decayed
+ assert!(bboxes[0][0].confidence == 0.9);
+ assert!(bboxes[0][1].confidence < 0.8);
+ Ok(())
+}
+
+#[test]
+fn softnms_confidence_threshold() -> Result<()> {
+ let mut bboxes = vec![vec![
+ Bbox {
+ xmin: 0.0,
+ ymin: 0.0,
+ xmax: 1.0,
+ ymax: 1.0,
+ confidence: 0.9,
+ data: (),
+ },
+ Bbox {
+ xmin: 0.1,
+ ymin: 0.1,
+ xmax: 1.1,
+ ymax: 1.1,
+ confidence: 0.05,
+ data: (),
+ },
+ ]];
+
+ soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));
+
+ // Box with confidence below the threshold should be removed
+ assert_eq!(bboxes[0].len(), 2);
+ assert_eq!(bboxes[0][0].confidence, 0.9);
+ assert_eq!(bboxes[0][1].confidence, 0.00);
+ Ok(())
+}
+
+#[test]
+fn softnms_no_overlap() -> Result<()> {
+ let mut bboxes = vec![vec![
+ Bbox {
+ xmin: 0.0,
+ ymin: 0.0,
+ xmax: 1.0,
+ ymax: 1.0,
+ confidence: 0.9,
+ data: (),
+ },
+ Bbox {
+ xmin: 2.0,
+ ymin: 2.0,
+ xmax: 3.0,
+ ymax: 3.0,
+ confidence: 0.8,
+ data: (),
+ },
+ ]];
+
+ soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));
+
+ // Both boxes should remain as they do not significantly overlap
+ assert_eq!(bboxes[0].len(), 2);
+ assert_eq!(bboxes[0][0].confidence, 0.9);
+ assert_eq!(bboxes[0][1].confidence, 0.8);
+ Ok(())
+}
+#[test]
+fn softnms_no_bbox() -> Result<()> {
+ let mut bboxes: Vec<Vec<Bbox<()>>> = vec![];
+ soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));
+ assert!(bboxes.is_empty());
+ Ok(())
+}
+
+#[test]
+fn softnms_single_bbox() -> Result<()> {
+ let mut bboxes = vec![vec![Bbox {
+ xmin: 0.0,
+ ymin: 0.0,
+ xmax: 1.0,
+ ymax: 1.0,
+ confidence: 0.9,
+ data: (),
+ }]];
+ soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));
+ assert_eq!(bboxes[0].len(), 1);
+ Ok(())
+}
+
+#[test]
+fn softnms_equal_confidence_overlap() -> Result<()> {
+ let mut bboxes = vec![vec![
+ Bbox {
+ xmin: 0.0,
+ ymin: 0.0,
+ xmax: 1.0,
+ ymax: 1.0,
+ confidence: 0.5,
+ data: (),
+ },
+ Bbox {
+ xmin: 0.1,
+ ymin: 0.1,
+ xmax: 1.1,
+ ymax: 1.1,
+ confidence: 0.5,
+ data: (),
+ },
+ ]];
+
+ soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));
+
+ // First box will be reference box, second box should be decayed
+ // Implementation must change to have both be decayed
+ assert_eq!(bboxes[0].len(), 2);
+ assert!(bboxes[0][0].confidence == 0.5);
+ assert!(bboxes[0][1].confidence < 0.5);
+ Ok(())
+}