diff options
author | Matthew O'Malley-Nichols <91226873+onichmath@users.noreply.github.com> | 2024-08-09 22:57:52 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-08-10 07:57:52 +0200 |
commit | 14db029494171268600914995369e0962f48c29a (patch) | |
tree | 1ba49badf6f4ae8a5a57f205315232476f3577de /candle-transformers/tests | |
parent | 6e6c1c99b09de707d1f0aa6839be70202b278a57 (diff) | |
download | candle-14db029494171268600914995369e0962f48c29a.tar.gz candle-14db029494171268600914995369e0962f48c29a.tar.bz2 candle-14db029494171268600914995369e0962f48c29a.zip |
Soft Non-Maximum Suppression (#2400)
* Soft NMS with thresholds
* NMS Test
* Soft nms w/ boxes removed below threshold
* Soft nms test
* No longer removing bounding boxes to fit Soft-NMS focus
* Initialize confidence
* Added comments
* Refactored out updating based on IOU/sigma
* Score_threshold -> confidence_threshold for clarity
* Remove bboxes below confidence threshold
* Softnms basic functionality test
* Softnms confidence decay test
* Softnms confidence threshold test
* Softnms no overlapping bbox test
* Testing confidence after no overlap test
* Single bbox and no bbox tests
* Signify test completion
* Handling result of test functions
* Checking all pairs of bboxes instead of a forward pass
* Equal confidence overlap test
* Clarified tests for implementation
* No longer dropping boxes, just setting to 0.0
* Formatted w/ cargo
Diffstat (limited to 'candle-transformers/tests')
-rw-r--r-- | candle-transformers/tests/nms_tests.rs | 222 |
1 files changed, 222 insertions, 0 deletions
diff --git a/candle-transformers/tests/nms_tests.rs b/candle-transformers/tests/nms_tests.rs new file mode 100644 index 00000000..d70f6fdf --- /dev/null +++ b/candle-transformers/tests/nms_tests.rs @@ -0,0 +1,222 @@ +use candle::Result; +use candle_transformers::object_detection::{ + non_maximum_suppression, soft_non_maximum_suppression, Bbox, +}; + +#[test] +fn nms_basic() -> Result<()> { + // Boxes based upon https://thepythoncode.com/article/non-maximum-suppression-using-opencv-in-python + let mut bboxes = vec![vec![ + Bbox { + xmin: 245.0, + ymin: 305.0, + xmax: 575.0, + ymax: 490.0, + confidence: 0.9, + data: (), + }, // Box 1 + Bbox { + xmin: 235.0, + ymin: 300.0, + xmax: 485.0, + ymax: 515.0, + confidence: 0.8, + data: (), + }, // Box 2 + Bbox { + xmin: 305.0, + ymin: 270.0, + xmax: 540.0, + ymax: 500.0, + confidence: 0.6, + data: (), + }, // Box 3 + ]]; + + non_maximum_suppression(&mut bboxes, 0.5); + let bboxes = bboxes.into_iter().next().unwrap(); + assert_eq!(bboxes.len(), 1); + assert_eq!(bboxes[0].confidence, 0.9); + + Ok(()) +} + +#[test] +fn softnms_basic_functionality() -> Result<()> { + let mut bboxes = vec![vec![ + Bbox { + xmin: 0.0, + ymin: 0.0, + xmax: 1.0, + ymax: 1.0, + confidence: 0.5, + data: (), + }, + Bbox { + xmin: 0.1, + ymin: 0.1, + xmax: 1.1, + ymax: 1.1, + confidence: 0.9, + data: (), + }, + Bbox { + xmin: 0.2, + ymin: 0.2, + xmax: 1.2, + ymax: 1.2, + confidence: 0.6, + data: (), + }, + ]]; + + soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5)); + + // Should decay boxes following highest confidence box + assert!(bboxes[0][0].confidence == 0.9); + assert!(bboxes[0][1].confidence < 0.5); + assert!(bboxes[0][2].confidence < 0.6); + Ok(()) +} + +#[test] +fn softnms_confidence_decay() -> Result<()> { + let mut bboxes = vec![vec![ + Bbox { + xmin: 0.0, + ymin: 0.0, + xmax: 1.0, + ymax: 1.0, + confidence: 0.9, + data: (), + }, // Reference box + Bbox { + xmin: 0.1, + ymin: 0.1, + xmax: 1.1, + ymax: 1.1, + confidence: 0.8, + data: (), + }, // Overlapping box + ]]; + + soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5)); + + // Check that confidence of the overlapping box is decayed + assert!(bboxes[0][0].confidence == 0.9); + assert!(bboxes[0][1].confidence < 0.8); + Ok(()) +} + +#[test] +fn softnms_confidence_threshold() -> Result<()> { + let mut bboxes = vec![vec![ + Bbox { + xmin: 0.0, + ymin: 0.0, + xmax: 1.0, + ymax: 1.0, + confidence: 0.9, + data: (), + }, + Bbox { + xmin: 0.1, + ymin: 0.1, + xmax: 1.1, + ymax: 1.1, + confidence: 0.05, + data: (), + }, + ]]; + + soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5)); + + // Box with confidence below the threshold should be removed + assert_eq!(bboxes[0].len(), 2); + assert_eq!(bboxes[0][0].confidence, 0.9); + assert_eq!(bboxes[0][1].confidence, 0.00); + Ok(()) +} + +#[test] +fn softnms_no_overlap() -> Result<()> { + let mut bboxes = vec![vec![ + Bbox { + xmin: 0.0, + ymin: 0.0, + xmax: 1.0, + ymax: 1.0, + confidence: 0.9, + data: (), + }, + Bbox { + xmin: 2.0, + ymin: 2.0, + xmax: 3.0, + ymax: 3.0, + confidence: 0.8, + data: (), + }, + ]]; + + soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5)); + + // Both boxes should remain as they do not significantly overlap + assert_eq!(bboxes[0].len(), 2); + assert_eq!(bboxes[0][0].confidence, 0.9); + assert_eq!(bboxes[0][1].confidence, 0.8); + Ok(()) +} +#[test] +fn softnms_no_bbox() -> Result<()> { + let mut bboxes: Vec<Vec<Bbox<()>>> = vec![]; + soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5)); + assert!(bboxes.is_empty()); + Ok(()) +} + +#[test] +fn softnms_single_bbox() -> Result<()> { + let mut bboxes = vec![vec![Bbox { + xmin: 0.0, + ymin: 0.0, + xmax: 1.0, + ymax: 1.0, + confidence: 0.9, + data: (), + }]]; + soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5)); + assert_eq!(bboxes[0].len(), 1); + Ok(()) +} + +#[test] +fn softnms_equal_confidence_overlap() -> Result<()> { + let mut bboxes = vec![vec![ + Bbox { + xmin: 0.0, + ymin: 0.0, + xmax: 1.0, + ymax: 1.0, + confidence: 0.5, + data: (), + }, + Bbox { + xmin: 0.1, + ymin: 0.1, + xmax: 1.1, + ymax: 1.1, + confidence: 0.5, + data: (), + }, + ]]; + + soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5)); + + // First box will be reference box, second box should be decayed + // Implementation must change to have both be decayed + assert_eq!(bboxes[0].len(), 2); + assert!(bboxes[0][0].confidence == 0.5); + assert!(bboxes[0][1].confidence < 0.5); + Ok(()) +} |