summaryrefslogtreecommitdiff
path: root/candle-onnx/tests/ops.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-onnx/tests/ops.rs')
-rw-r--r--candle-onnx/tests/ops.rs329
1 files changed, 329 insertions, 0 deletions
diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs
index 51ee037e..55d6fb86 100644
--- a/candle-onnx/tests/ops.rs
+++ b/candle-onnx/tests/ops.rs
@@ -3980,3 +3980,332 @@ fn test_lstm() -> Result<()> {
Ok(())
}
+
+#[test]
+fn test_expand_dim_changed() -> Result<()> {
+ // Create a manual graph for the Expand operation
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "Expand".to_string(),
+ domain: "".to_string(),
+ attribute: vec![],
+ input: vec!["data".to_string(), "new_shape".to_string()],
+ output: vec!["expanded".to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ input: vec![
+ ValueInfoProto {
+ name: "data".to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ },
+ ValueInfoProto {
+ name: "new_shape".to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ },
+ ],
+ output: vec![ValueInfoProto {
+ name: "expanded".to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ ..GraphProto::default()
+ }));
+
+ // Input tensor with shape [3, 1]
+ let data = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32], (3, 1), &Device::Cpu)?;
+
+ // New shape tensor: [2, 1, 6]
+ let new_shape = Tensor::from_vec(vec![2i64, 1, 6], (3,), &Device::Cpu)?;
+
+ // Expected output after expansion
+ let expected = Tensor::from_vec(
+ vec![
+ 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32,
+ 2.0f32, 3.0f32, 3.0f32, 3.0f32, 3.0f32, 3.0f32, 3.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32,
+ 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 3.0f32, 3.0f32, 3.0f32,
+ 3.0f32, 3.0f32, 3.0f32,
+ ],
+ (2, 3, 6),
+ &Device::Cpu,
+ )?;
+
+ // Execute the model evaluation
+ let inputs = HashMap::from_iter([
+ ("data".to_string(), data),
+ ("new_shape".to_string(), new_shape),
+ ]);
+ let result = candle_onnx::simple_eval(&manual_graph, inputs)?;
+
+ // Retrieve and compare the result
+ let expanded = result.get("expanded").expect("Output 'expanded' not found");
+
+ assert_eq!(expanded.to_vec3::<f32>()?, expected.to_vec3::<f32>()?);
+
+ Ok(())
+}
+
+fn make_graph_helper(
+ op_name: &str,
+ inputs: &[&str],
+ outputs: &[&str],
+ attribs: Vec<AttributeProto>,
+) -> ModelProto {
+ create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: op_name.to_string(),
+ domain: "".to_string(),
+ attribute: attribs,
+ input: inputs.iter().map(|s| s.to_string()).collect(),
+ output: outputs.iter().map(|s| s.to_string()).collect(),
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ input: inputs
+ .into_iter()
+ .map(|name| ValueInfoProto {
+ name: name.to_string(),
+ ..ValueInfoProto::default()
+ })
+ .collect(),
+ output: outputs
+ .into_iter()
+ .map(|name| ValueInfoProto {
+ name: name.to_string(),
+ ..ValueInfoProto::default()
+ })
+ .collect(),
+ ..GraphProto::default()
+ }))
+}
+
+#[test]
+fn test_expand_dim_unchanged() -> Result<()> {
+ // Create a manual graph for the Expand operation
+ let manual_graph = make_graph_helper("Expand", &["data", "new_shape"], &["expanded"], vec![]);
+
+ // Input tensor with shape [3, 1] and dtype f32
+ let data = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32], (3, 1), &Device::Cpu)?;
+
+ // New shape tensor: [3, 4]
+ let new_shape = Tensor::from_vec(vec![3i64, 4], (2,), &Device::Cpu)?;
+
+ // Expected output after expansion, dtype f32
+ let expected = Tensor::from_vec(
+ vec![
+ 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 3.0f32, 3.0f32, 3.0f32,
+ 3.0f32,
+ ],
+ (3, 4),
+ &Device::Cpu,
+ )?;
+
+ // Execute the model evaluation
+ let inputs = HashMap::from_iter([
+ ("data".to_string(), data),
+ ("new_shape".to_string(), new_shape),
+ ]);
+ let result = candle_onnx::simple_eval(&manual_graph, inputs)?;
+
+ // Retrieve and compare the result
+ let expanded = result.get("expanded").expect("Output 'expanded' not found");
+ assert_eq!(expanded.to_vec2::<f32>()?, expected.to_vec2::<f32>()?);
+
+ Ok(())
+}
+
+fn make_split_graph_helper(inputs: &[&str], outputs: &[&str], axis: i64) -> ModelProto {
+ let attribs = vec![AttributeProto {
+ name: "axis".to_string(),
+ r#type: AttributeType::Int.into(),
+ i: axis,
+ ..AttributeProto::default()
+ }];
+
+ make_graph_helper("Split", inputs, outputs, attribs)
+}
+
+#[test]
+fn test_split_equal_parts_1d_opset13() -> Result<()> {
+ let input = Tensor::from_vec(
+ vec![1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32],
+ (6,),
+ &Device::Cpu,
+ )?;
+ let mut inputs = HashMap::new();
+ inputs.insert("input".to_string(), input);
+
+ {
+ let manual_graph =
+ make_split_graph_helper(&["input"], &["output_1", "output_2", "output_3"], 0);
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs.clone())?;
+ assert_eq!(eval.len(), 3);
+
+ let out1 = eval.get("output_1").expect("Output 'output_1' not found");
+ let out2 = eval.get("output_2").expect("Output 'output_2' not found");
+ let out3 = eval.get("output_3").expect("Output 'output_3' not found");
+
+ assert_eq!(out1.to_vec1::<f32>()?, vec![1.0f32, 2.0f32]);
+ assert_eq!(out2.to_vec1::<f32>()?, vec![3.0f32, 4.0f32]);
+ assert_eq!(out3.to_vec1::<f32>()?, vec![5.0f32, 6.0f32]);
+ }
+
+ {
+ let splits = Tensor::from_vec(vec![2i64, 4], (2,), &Device::Cpu)?;
+ inputs.insert("split".to_string(), splits);
+
+ let manual_graph =
+ make_split_graph_helper(&["input", "split"], &["output_1", "output_2"], 0);
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 2);
+
+ let out1 = eval.get("output_1").expect("Output 'output_1' not found");
+ let out2 = eval.get("output_2").expect("Output 'output_2' not found");
+
+ assert_eq!(out1.to_vec1::<f32>()?, vec![1.0f32, 2.0f32]);
+ assert_eq!(out2.to_vec1::<f32>()?, vec![3.0f32, 4.0f32, 5.0f32, 6.0f32]);
+ }
+ Ok(())
+}
+
+fn make_reduce_sum_graph_helper(
+ inputs: &[&str],
+ outputs: &[&str],
+ keepdims: Option<i64>,
+ noop_with_empty_axes: Option<i64>,
+) -> ModelProto {
+ let mut attribs = vec![];
+ if let Some(keepdims) = keepdims {
+ attribs.push(AttributeProto {
+ name: "keepdims".to_string(),
+ r#type: AttributeType::Int.into(),
+ i: keepdims,
+ ..AttributeProto::default()
+ });
+ }
+ if let Some(noop_with_empty_axes) = noop_with_empty_axes {
+ attribs.push(AttributeProto {
+ name: "noop_with_empty_axes".to_string(),
+ r#type: AttributeType::Ints.into(),
+ i: noop_with_empty_axes,
+ ..AttributeProto::default()
+ });
+ }
+ make_graph_helper("ReduceSum", inputs, outputs, attribs)
+}
+
+#[test]
+fn test_reduce_sum_default_axes_keepdims() -> Result<()> {
+ let manual_graph = make_reduce_sum_graph_helper(&["data", "axes"], &["reduced"], Some(1), None);
+
+ // Test with example data
+ {
+ let data = Tensor::from_vec(
+ vec![
+ 1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
+ ],
+ (3, 2, 2),
+ &Device::Cpu,
+ )?;
+ // let axes = Tensor::from_vec(Vec::<i64>::new(), (0,), &Device::Cpu)?;
+
+ let mut inputs = HashMap::new();
+ inputs.insert("data".to_string(), data);
+ // inputs.insert("axes".to_string(), axes);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let reduced = eval.get("reduced").expect("Output 'reduced' not found");
+ let expected = Tensor::from_vec(vec![78.0f32], (1, 1, 1), &Device::Cpu)?;
+
+ assert_eq!(reduced.to_vec3::<f32>()?, expected.to_vec3::<f32>()?);
+ }
+
+ {
+ let data = Tensor::from_vec(
+ vec![
+ -5.2f32, 7.8, -3.1, 9.4, 2.6, -8.7, 4.3, -1.9, 6.5, -0.8, -7.2, 3.6,
+ ],
+ (3, 2, 2),
+ &Device::Cpu,
+ )?;
+
+ let mut inputs = HashMap::new();
+ inputs.insert("data".to_string(), data.clone());
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let reduced = eval.get("reduced").expect("Output 'reduced' not found");
+ let expected = data.sum_all()?.reshape((1, 1, 1))?;
+
+ assert_eq!(reduced.to_vec3::<f32>()?, expected.to_vec3::<f32>()?);
+ }
+
+ Ok(())
+}
+
+#[test]
+fn test_reduce_sum_do_not_keep_dims() -> Result<()> {
+ let manual_graph = make_reduce_sum_graph_helper(&["data", "axes"], &["reduced"], Some(0), None);
+
+ // Test with example data
+ {
+ let data = Tensor::from_vec(
+ vec![
+ 1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
+ ],
+ (3, 2, 2),
+ &Device::Cpu,
+ )?;
+ let axes = Tensor::from_vec(vec![1i64], (1,), &Device::Cpu)?;
+
+ let mut inputs = HashMap::new();
+ inputs.insert("data".to_string(), data);
+ inputs.insert("axes".to_string(), axes);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let reduced = eval.get("reduced").expect("Output 'reduced' not found");
+ let expected = Tensor::from_vec(
+ vec![4.0f32, 6.0, 12.0, 14.0, 20.0, 22.0],
+ (3, 2),
+ &Device::Cpu,
+ )?;
+
+ assert_eq!(reduced.to_vec2::<f32>()?, expected.to_vec2::<f32>()?);
+ }
+
+ // Test with random data
+ {
+ let shape = (3, 2, 2);
+ let data = Tensor::from_vec(
+ vec![
+ -5.2f32, 7.8, -3.1, 9.4, 2.6, -8.7, 4.3, -1.9, 6.5, -0.8, -7.2, 3.6,
+ ],
+ (3, 2, 2),
+ &Device::Cpu,
+ )?;
+ let axes = Tensor::from_vec(vec![1i64], (1,), &Device::Cpu)?;
+
+ let mut inputs = HashMap::new();
+ inputs.insert("data".to_string(), data.clone());
+ inputs.insert("axes".to_string(), axes);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let reduced = eval.get("reduced").expect("Output 'reduced' not found");
+
+ // Calculate expected result
+ let expected = data.sum(1)?;
+
+ assert_eq!(reduced.to_vec2::<f32>()?, expected.to_vec2::<f32>()?);
+ }
+
+ Ok(())
+}