author    B1rtek <53182944+B1rtek@users.noreply.github.com>  2024-06-04 22:49:02 +0200
committer GitHub <noreply@github.com>                        2024-06-04 22:49:02 +0200
commit    cb180eb23a6f563a241834ceac03f19f12108545 (patch)
tree      49376bb707c8785db7b1027cc80015503241a472 /candle-onnx/tests
parent    9182c828e6c727d149075d7cee8dbcb6d5a5f884 (diff)
download  candle-cb180eb23a6f563a241834ceac03f19f12108545.tar.gz
          candle-cb180eb23a6f563a241834ceac03f19f12108545.tar.bz2
          candle-cb180eb23a6f563a241834ceac03f19f12108545.zip
ONNX: add ArgMin, ArgMax and LeakyRelu (#2246)
* Add basic RandomUniform implementation
* Use is_some to check if seed is present
* Added Exp operator implementation
* Added ArgMin operator implementation
* Added tests for ArgMin
* ArgMin now returns a tensor with i64
* Added tests from pytorch examples
* Added ArgMax operator implementation
* Added tests for ArgMax
* Added LeakyRelu implementation
* Added a test for LeakyRelu
* Typo fix
* Fix a weird automatic RustRover change

Co-authored-by: Mateusz Okulus <mmokulus@gmail.com>
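For reference, the ONNX defaults these tests rely on are axis = 0, keepdims = 1 and select_last_index = 0 (and alpha = 0.01 for LeakyRelu). A worked example from the test data below: for the input [[2, 1], [3, 10]], ArgMin along axis 0 gives [[0, 0]] (column minima, reduced axis kept), ArgMin along axis 1 gives [[1], [0]], and ArgMax along axis 0 gives [[1, 1]].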
Diffstat (limited to 'candle-onnx/tests')
-rw-r--r--  candle-onnx/tests/ops.rs | 470 ++++++++++++++++++++++++++++++++++
1 file changed, 470 insertions(+), 0 deletions(-)
diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs
index f58aeccf..ffafd7a7 100644
--- a/candle-onnx/tests/ops.rs
+++ b/candle-onnx/tests/ops.rs
@@ -2708,3 +2708,473 @@ fn test_ceil() -> Result<()> {
Ok(())
}
+
+// "ArgMin"
+#[test]
+fn test_argmin() -> Result<()> {
+ // tests from https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-7
+ // default_axes_keepdims
+ test(
+ &[
+ [2u32, 1u32],
+ [3u32, 10u32]
+ ],
+ None,
+ Some(1),
+ None,
+ &[
+ [0i64, 0i64],
+ ],
+ )?;
+ // keepdims
+ test(
+ &[
+ [2u32, 1u32],
+ [3u32, 10u32]
+ ],
+ Some(1),
+ Some(1),
+ None,
+ &[
+ [1i64],
+ [0i64]
+ ],
+ )?;
+ // negative_axis_keepdims
+ test(
+ &[
+ [2u32, 1u32],
+ [3u32, 10u32]
+ ],
+ Some(-1),
+ Some(1),
+ None,
+ &[
+ [1i64],
+ [0i64]
+ ],
+ )?;
+ // no_keepdims
+ test(
+ &[
+ [2u32, 1u32],
+ [3u32, 10u32]
+ ],
+ None,
+ Some(0),
+ None,
+ &[0i64, 0i64],
+ )?;
+ // tests from https://pytorch.org/docs/stable/generated/torch.argmin.html#torch.argmin
+ test(
+ &[
+ [0.1139, 0.2254, -0.1381, 0.3687],
+ [1.0100, -1.1975, -0.0102, -0.4732],
+ [-0.9240, 0.1207, -0.7506, -1.0213],
+ [1.7809, -1.2960, 0.9384, 0.1438]
+ ],
+ Some(1),
+ Some(0),
+ None,
+ &[2i64, 1i64, 3i64, 1i64],
+ )?;
+ test(
+ &[
+ [0.1139, 0.2254, -0.1381, 0.3687],
+ [1.0100, -1.1975, -0.0102, -0.4732],
+ [-0.9240, 0.1207, -0.7506, -1.0213],
+ [1.7809, -1.2960, 0.9384, 0.1438]
+ ],
+ Some(1),
+ None,
+ None,
+ &[[2i64], [1i64], [3i64], [1i64]],
+ )?;
+ fn test(
+     data: impl NdArray,
+     axis: Option<i64>,
+     keepdims: Option<i64>,
+     select_last_index: Option<i64>,
+     expected: impl NdArray,
+ ) -> Result<()> {
+ let att_axis = AttributeProto {
+ name: "axis".to_string(),
+ ref_attr_name: "axis".to_string(),
+ i: axis.unwrap_or(0),
+ doc_string: "axis".to_string(),
+ r#type: 2, // INT
+ f: 0.0,
+ s: vec![],
+ t: None,
+ g: None,
+ sparse_tensor: None,
+ tp: None,
+ floats: vec![],
+ ints: vec![],
+ strings: vec![],
+ tensors: vec![],
+ graphs: vec![],
+ sparse_tensors: vec![],
+ type_protos: vec![],
+ };
+ let att_keepdims = AttributeProto {
+ name: "keepdims".to_string(),
+ ref_attr_name: "keepdims".to_string(),
+ i: keepdims.unwrap_or(1),
+ doc_string: "keepdims".to_string(),
+ r#type: 2, // INT
+ f: 0.0,
+ s: vec![],
+ t: None,
+ g: None,
+ sparse_tensor: None,
+ tp: None,
+ floats: vec![],
+ ints: vec![],
+ strings: vec![],
+ tensors: vec![],
+ graphs: vec![],
+ sparse_tensors: vec![],
+ type_protos: vec![],
+ };
+ let att_select_last_index = AttributeProto {
+ name: "select_last_index".to_string(),
+ ref_attr_name: "select_last_index".to_string(),
+ i: select_last_index.unwrap_or(0),
+ doc_string: "select_last_index".to_string(),
+ r#type: 2, // INT
+ f: 0.0,
+ s: vec![],
+ t: None,
+ g: None,
+ sparse_tensor: None,
+ tp: None,
+ floats: vec![],
+ ints: vec![],
+ strings: vec![],
+ tensors: vec![],
+ graphs: vec![],
+ sparse_tensors: vec![],
+ type_protos: vec![],
+ };
+ let attrs = {
+ let mut mut_attrs = vec![];
+ if axis.is_some() {
+ mut_attrs.push(att_axis);
+ }
+ if keepdims.is_some() {
+ mut_attrs.push(att_keepdims);
+ }
+ if select_last_index.is_some() {
+ mut_attrs.push(att_select_last_index);
+ }
+ mut_attrs
+ };
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "ArgMin".to_string(),
+ domain: "".to_string(),
+ attribute: attrs,
+ input: vec![INPUT_X.to_string()],
+ output: vec![OUTPUT_Z.to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ name: "".to_string(),
+ initializer: vec![],
+ input: vec![],
+ output: vec![ValueInfoProto {
+ name: OUTPUT_Z.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ value_info: vec![],
+ doc_string: "".to_string(),
+ sparse_initializer: vec![],
+ quantization_annotation: vec![],
+ }));
+ let mut inputs: HashMap<String, Tensor> = HashMap::new();
+ inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+
+ let expected = Tensor::new(expected, &Device::Cpu)?;
+ match expected.dims().len() {
+ 1 => assert_eq!(z.to_vec1::<i64>()?, expected.to_vec1::<i64>()?),
+ 2 => assert_eq!(z.to_vec2::<i64>()?, expected.to_vec2::<i64>()?),
+ _ => unreachable!(),
+ };
+
+ Ok(())
+ }
+
+ Ok(())
+}
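+
+// A minimal reference sketch of the ArgMin semantics checked above, assuming
+// a 2D input with axis = 1 and keepdims = 0 (illustration only, not used by
+// the tests): the index of the smallest element in each row.
+#[allow(dead_code)]
+fn argmin_axis1(data: &[Vec<f64>]) -> Vec<i64> {
+    data.iter()
+        .map(|row| {
+            row.iter()
+                .enumerate()
+                // strict `<` keeps the first occurrence, matching the
+                // select_last_index = 0 default exercised above
+                .fold((0usize, f64::INFINITY), |(bi, bv), (i, &v)| {
+                    if v < bv { (i, v) } else { (bi, bv) }
+                })
+                .0 as i64
+        })
+        .collect()
+}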
+
+// "ArgMax"
+#[test]
+fn test_argmax() -> Result<()> {
+ // tests from https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-6
+ // default_axes_keepdims
+ test(
+ &[
+ [2u32, 1u32],
+ [3u32, 10u32]
+ ],
+ None,
+ Some(1),
+ None,
+ &[
+ [1i64, 1i64],
+ ],
+ )?;
+ // keepdims
+ test(
+ &[
+ [2u32, 1u32],
+ [3u32, 10u32]
+ ],
+ Some(1),
+ Some(1),
+ None,
+ &[
+ [0i64],
+ [1i64]
+ ],
+ )?;
+ // negative_axis_keepdims
+ test(
+ &[
+ [2u32, 1u32],
+ [3u32, 10u32]
+ ],
+ Some(-1),
+ Some(1),
+ None,
+ &[
+ [0i64],
+ [1i64]
+ ],
+ )?;
+ // no_keepdims
+ test(
+ &[
+ [2u32, 1u32],
+ [3u32, 10u32]
+ ],
+ None,
+ Some(0),
+ None,
+ &[1i64, 1i64],
+ )?;
+ // tests from https://pytorch.org/docs/stable/generated/torch.argmax.html
+ test(
+ &[
+ [1.3398, 0.2663, -0.2686, 0.2450],
+ [-0.7401, -0.8805, -0.3402, -1.1936],
+ [0.4907, -1.3948, -1.0691, -0.3132],
+ [-1.6092, 0.5419, -0.2993, 0.3195]
+ ],
+ Some(1),
+ Some(0),
+ None,
+ &[0i64, 2i64, 0i64, 1i64],
+ )?;
+ test(
+ &[
+ [1.3398, 0.2663, -0.2686, 0.2450],
+ [-0.7401, -0.8805, -0.3402, -1.1936],
+ [0.4907, -1.3948, -1.0691, -0.3132],
+ [-1.6092, 0.5419, -0.2993, 0.3195]
+ ],
+ Some(1),
+ None,
+ None,
+ &[[0i64], [2i64], [0i64], [1i64]],
+ )?;
+ fn test(
+     data: impl NdArray,
+     axis: Option<i64>,
+     keepdims: Option<i64>,
+     select_last_index: Option<i64>,
+     expected: impl NdArray,
+ ) -> Result<()> {
+ let att_axis = AttributeProto {
+ name: "axis".to_string(),
+ ref_attr_name: "axis".to_string(),
+ i: axis.unwrap_or(0),
+ doc_string: "axis".to_string(),
+ r#type: 2, // INT
+ f: 0.0,
+ s: vec![],
+ t: None,
+ g: None,
+ sparse_tensor: None,
+ tp: None,
+ floats: vec![],
+ ints: vec![],
+ strings: vec![],
+ tensors: vec![],
+ graphs: vec![],
+ sparse_tensors: vec![],
+ type_protos: vec![],
+ };
+ let att_keepdims = AttributeProto {
+ name: "keepdims".to_string(),
+ ref_attr_name: "keepdims".to_string(),
+ i: keepdims.unwrap_or(1),
+ doc_string: "keepdims".to_string(),
+ r#type: 2, // INT
+ f: 0.0,
+ s: vec![],
+ t: None,
+ g: None,
+ sparse_tensor: None,
+ tp: None,
+ floats: vec![],
+ ints: vec![],
+ strings: vec![],
+ tensors: vec![],
+ graphs: vec![],
+ sparse_tensors: vec![],
+ type_protos: vec![],
+ };
+ let att_select_last_index = AttributeProto {
+ name: "select_last_index".to_string(),
+ ref_attr_name: "select_last_index".to_string(),
+ i: select_last_index.unwrap_or(0),
+ doc_string: "select_last_index".to_string(),
+ r#type: 2, // INT
+ f: 0.0,
+ s: vec![],
+ t: None,
+ g: None,
+ sparse_tensor: None,
+ tp: None,
+ floats: vec![],
+ ints: vec![],
+ strings: vec![],
+ tensors: vec![],
+ graphs: vec![],
+ sparse_tensors: vec![],
+ type_protos: vec![],
+ };
+ let attrs = {
+ let mut mut_attrs = vec![];
+ if axis.is_some() {
+ mut_attrs.push(att_axis);
+ }
+ if keepdims.is_some() {
+ mut_attrs.push(att_keepdims);
+ }
+ if select_last_index.is_some() {
+ mut_attrs.push(att_select_last_index);
+ }
+ mut_attrs
+ };
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "ArgMax".to_string(),
+ domain: "".to_string(),
+ attribute: attrs,
+ input: vec![INPUT_X.to_string()],
+ output: vec![OUTPUT_Z.to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ name: "".to_string(),
+ initializer: vec![],
+ input: vec![],
+ output: vec![ValueInfoProto {
+ name: OUTPUT_Z.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ value_info: vec![],
+ doc_string: "".to_string(),
+ sparse_initializer: vec![],
+ quantization_annotation: vec![],
+ }));
+ let mut inputs: HashMap<String, Tensor> = HashMap::new();
+ inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+
+ let expected = Tensor::new(expected, &Device::Cpu)?;
+ match expected.dims().len() {
+ 1 => assert_eq!(z.to_vec1::<i64>()?, expected.to_vec1::<i64>()?),
+ 2 => assert_eq!(z.to_vec2::<i64>()?, expected.to_vec2::<i64>()?),
+ _ => unreachable!(),
+ };
+
+ Ok(())
+ }
+
+ Ok(())
+}
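+
+// A minimal sketch of the tie-breaking rule that select_last_index controls
+// (every call above passes None, i.e. the spec default of 0): with 0 the
+// first maximal index wins; with 1 the last one would. Hypothetical helper,
+// illustration only.
+#[allow(dead_code)]
+fn argmax_row(row: &[f64], select_last_index: bool) -> i64 {
+    let mut best = 0usize;
+    for (i, &v) in row.iter().enumerate() {
+        // `>` keeps the first maximum, `>=` moves on to later ties
+        let better = if select_last_index { v >= row[best] } else { v > row[best] };
+        if better {
+            best = i;
+        }
+    }
+    best as i64
+}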
+
+// "LeakyRelu"
+#[test]
+fn test_leakyrelu() -> Result<()> {
+ // tests from https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-80
+ // leakyrelu
+ test(
+ &[-1.0, 0.0, 1.0],
+ Some(0.1),
+ &[-0.1, 0.0, 1.0]
+ )?;
+ fn test(data: impl NdArray, alpha: Option<f32>, expected: impl NdArray) -> Result<()> {
+ let att_alpha = AttributeProto {
+ name: "alpha".to_string(),
+ ref_attr_name: "alpha".to_string(),
+ i: 0,
+ doc_string: "alpha".to_string(),
+ r#type: 1, // FLOAT
+ f: alpha.unwrap_or(0.01),
+ s: vec![],
+ t: None,
+ g: None,
+ sparse_tensor: None,
+ tp: None,
+ floats: vec![],
+ ints: vec![],
+ strings: vec![],
+ tensors: vec![],
+ graphs: vec![],
+ sparse_tensors: vec![],
+ type_protos: vec![],
+ };
+ let attrs = {
+ let mut mut_attrs = vec![];
+ if alpha.is_some() {
+ mut_attrs.push(att_alpha);
+ }
+ mut_attrs
+ };
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "LeakyRelu".to_string(),
+ domain: "".to_string(),
+ attribute: attrs,
+ input: vec![INPUT_X.to_string()],
+ output: vec![OUTPUT_Z.to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ name: "".to_string(),
+ initializer: vec![],
+ input: vec![],
+ output: vec![ValueInfoProto {
+ name: OUTPUT_Z.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ value_info: vec![],
+ doc_string: "".to_string(),
+ sparse_initializer: vec![],
+ quantization_annotation: vec![],
+ }));
+ let mut inputs: HashMap<String, Tensor> = HashMap::new();
+ inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+
+ let expected = Tensor::new(expected, &Device::Cpu)?;
+ let (zs, es) = (z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?);
+ for (act, exp) in zs.iter().zip(es.iter()) {
+ assert!(f64::abs(act - exp) < f32::EPSILON.into());
+ }
+
+ Ok(())
+ }
+
+ Ok(())
+}
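+
+// A minimal sketch (illustration only) of the LeakyRelu definition the test
+// above checks: f(x) = x for x >= 0 and alpha * x otherwise, with alpha
+// defaulting to 0.01 per the ONNX spec.
+#[allow(dead_code)]
+fn leaky_relu_ref(xs: &[f64], alpha: f64) -> Vec<f64> {
+    xs.iter()
+        .map(|&x| if x >= 0.0 { x } else { alpha * x })
+        .collect()
+}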