diff options
Diffstat (limited to 'candle-onnx/tests/ops.rs')
-rw-r--r-- | candle-onnx/tests/ops.rs | 329 |
1 files changed, 329 insertions, 0 deletions
diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs index 51ee037e..55d6fb86 100644 --- a/candle-onnx/tests/ops.rs +++ b/candle-onnx/tests/ops.rs @@ -3980,3 +3980,332 @@ fn test_lstm() -> Result<()> { Ok(()) } + +#[test] +fn test_expand_dim_changed() -> Result<()> { + // Create a manual graph for the Expand operation + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Expand".to_string(), + domain: "".to_string(), + attribute: vec![], + input: vec!["data".to_string(), "new_shape".to_string()], + output: vec!["expanded".to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + input: vec![ + ValueInfoProto { + name: "data".to_string(), + doc_string: "".to_string(), + r#type: None, + }, + ValueInfoProto { + name: "new_shape".to_string(), + doc_string: "".to_string(), + r#type: None, + }, + ], + output: vec![ValueInfoProto { + name: "expanded".to_string(), + doc_string: "".to_string(), + r#type: None, + }], + ..GraphProto::default() + })); + + // Input tensor with shape [3, 1] + let data = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32], (3, 1), &Device::Cpu)?; + + // New shape tensor: [2, 1, 6] + let new_shape = Tensor::from_vec(vec![2i64, 1, 6], (3,), &Device::Cpu)?; + + // Expected output after expansion + let expected = Tensor::from_vec( + vec![ + 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, + 2.0f32, 3.0f32, 3.0f32, 3.0f32, 3.0f32, 3.0f32, 3.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, + 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 3.0f32, 3.0f32, 3.0f32, + 3.0f32, 3.0f32, 3.0f32, + ], + (2, 3, 6), + &Device::Cpu, + )?; + + // Execute the model evaluation + let inputs = HashMap::from_iter([ + ("data".to_string(), data), + ("new_shape".to_string(), new_shape), + ]); + let result = candle_onnx::simple_eval(&manual_graph, inputs)?; + + // Retrieve and compare the result + let expanded = result.get("expanded").expect("Output 'expanded' not found"); + + assert_eq!(expanded.to_vec3::<f32>()?, expected.to_vec3::<f32>()?); + + Ok(()) +} + +fn make_graph_helper( + op_name: &str, + inputs: &[&str], + outputs: &[&str], + attribs: Vec<AttributeProto>, +) -> ModelProto { + create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: op_name.to_string(), + domain: "".to_string(), + attribute: attribs, + input: inputs.iter().map(|s| s.to_string()).collect(), + output: outputs.iter().map(|s| s.to_string()).collect(), + name: "".to_string(), + doc_string: "".to_string(), + }], + input: inputs + .into_iter() + .map(|name| ValueInfoProto { + name: name.to_string(), + ..ValueInfoProto::default() + }) + .collect(), + output: outputs + .into_iter() + .map(|name| ValueInfoProto { + name: name.to_string(), + ..ValueInfoProto::default() + }) + .collect(), + ..GraphProto::default() + })) +} + +#[test] +fn test_expand_dim_unchanged() -> Result<()> { + // Create a manual graph for the Expand operation + let manual_graph = make_graph_helper("Expand", &["data", "new_shape"], &["expanded"], vec![]); + + // Input tensor with shape [3, 1] and dtype f32 + let data = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32], (3, 1), &Device::Cpu)?; + + // New shape tensor: [3, 4] + let new_shape = Tensor::from_vec(vec![3i64, 4], (2,), &Device::Cpu)?; + + // Expected output after expansion, dtype f32 + let expected = Tensor::from_vec( + vec![ + 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 3.0f32, 3.0f32, 3.0f32, + 3.0f32, + ], + (3, 4), + &Device::Cpu, + )?; + + // Execute the model evaluation + let inputs = HashMap::from_iter([ + ("data".to_string(), data), + ("new_shape".to_string(), new_shape), + ]); + let result = candle_onnx::simple_eval(&manual_graph, inputs)?; + + // Retrieve and compare the result + let expanded = result.get("expanded").expect("Output 'expanded' not found"); + assert_eq!(expanded.to_vec2::<f32>()?, expected.to_vec2::<f32>()?); + + Ok(()) +} + +fn make_split_graph_helper(inputs: &[&str], outputs: &[&str], axis: i64) -> ModelProto { + let attribs = vec![AttributeProto { + name: "axis".to_string(), + r#type: AttributeType::Int.into(), + i: axis, + ..AttributeProto::default() + }]; + + make_graph_helper("Split", inputs, outputs, attribs) +} + +#[test] +fn test_split_equal_parts_1d_opset13() -> Result<()> { + let input = Tensor::from_vec( + vec![1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32], + (6,), + &Device::Cpu, + )?; + let mut inputs = HashMap::new(); + inputs.insert("input".to_string(), input); + + { + let manual_graph = + make_split_graph_helper(&["input"], &["output_1", "output_2", "output_3"], 0); + let eval = candle_onnx::simple_eval(&manual_graph, inputs.clone())?; + assert_eq!(eval.len(), 3); + + let out1 = eval.get("output_1").expect("Output 'output_1' not found"); + let out2 = eval.get("output_2").expect("Output 'output_2' not found"); + let out3 = eval.get("output_3").expect("Output 'output_3' not found"); + + assert_eq!(out1.to_vec1::<f32>()?, vec![1.0f32, 2.0f32]); + assert_eq!(out2.to_vec1::<f32>()?, vec![3.0f32, 4.0f32]); + assert_eq!(out3.to_vec1::<f32>()?, vec![5.0f32, 6.0f32]); + } + + { + let splits = Tensor::from_vec(vec![2i64, 4], (2,), &Device::Cpu)?; + inputs.insert("split".to_string(), splits); + + let manual_graph = + make_split_graph_helper(&["input", "split"], &["output_1", "output_2"], 0); + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 2); + + let out1 = eval.get("output_1").expect("Output 'output_1' not found"); + let out2 = eval.get("output_2").expect("Output 'output_2' not found"); + + assert_eq!(out1.to_vec1::<f32>()?, vec![1.0f32, 2.0f32]); + assert_eq!(out2.to_vec1::<f32>()?, vec![3.0f32, 4.0f32, 5.0f32, 6.0f32]); + } + Ok(()) +} + +fn make_reduce_sum_graph_helper( + inputs: &[&str], + outputs: &[&str], + keepdims: Option<i64>, + noop_with_empty_axes: Option<i64>, +) -> ModelProto { + let mut attribs = vec![]; + if let Some(keepdims) = keepdims { + attribs.push(AttributeProto { + name: "keepdims".to_string(), + r#type: AttributeType::Int.into(), + i: keepdims, + ..AttributeProto::default() + }); + } + if let Some(noop_with_empty_axes) = noop_with_empty_axes { + attribs.push(AttributeProto { + name: "noop_with_empty_axes".to_string(), + r#type: AttributeType::Ints.into(), + i: noop_with_empty_axes, + ..AttributeProto::default() + }); + } + make_graph_helper("ReduceSum", inputs, outputs, attribs) +} + +#[test] +fn test_reduce_sum_default_axes_keepdims() -> Result<()> { + let manual_graph = make_reduce_sum_graph_helper(&["data", "axes"], &["reduced"], Some(1), None); + + // Test with example data + { + let data = Tensor::from_vec( + vec![ + 1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + ], + (3, 2, 2), + &Device::Cpu, + )?; + // let axes = Tensor::from_vec(Vec::<i64>::new(), (0,), &Device::Cpu)?; + + let mut inputs = HashMap::new(); + inputs.insert("data".to_string(), data); + // inputs.insert("axes".to_string(), axes); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let reduced = eval.get("reduced").expect("Output 'reduced' not found"); + let expected = Tensor::from_vec(vec![78.0f32], (1, 1, 1), &Device::Cpu)?; + + assert_eq!(reduced.to_vec3::<f32>()?, expected.to_vec3::<f32>()?); + } + + { + let data = Tensor::from_vec( + vec![ + -5.2f32, 7.8, -3.1, 9.4, 2.6, -8.7, 4.3, -1.9, 6.5, -0.8, -7.2, 3.6, + ], + (3, 2, 2), + &Device::Cpu, + )?; + + let mut inputs = HashMap::new(); + inputs.insert("data".to_string(), data.clone()); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let reduced = eval.get("reduced").expect("Output 'reduced' not found"); + let expected = data.sum_all()?.reshape((1, 1, 1))?; + + assert_eq!(reduced.to_vec3::<f32>()?, expected.to_vec3::<f32>()?); + } + + Ok(()) +} + +#[test] +fn test_reduce_sum_do_not_keep_dims() -> Result<()> { + let manual_graph = make_reduce_sum_graph_helper(&["data", "axes"], &["reduced"], Some(0), None); + + // Test with example data + { + let data = Tensor::from_vec( + vec![ + 1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + ], + (3, 2, 2), + &Device::Cpu, + )?; + let axes = Tensor::from_vec(vec![1i64], (1,), &Device::Cpu)?; + + let mut inputs = HashMap::new(); + inputs.insert("data".to_string(), data); + inputs.insert("axes".to_string(), axes); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let reduced = eval.get("reduced").expect("Output 'reduced' not found"); + let expected = Tensor::from_vec( + vec![4.0f32, 6.0, 12.0, 14.0, 20.0, 22.0], + (3, 2), + &Device::Cpu, + )?; + + assert_eq!(reduced.to_vec2::<f32>()?, expected.to_vec2::<f32>()?); + } + + // Test with random data + { + let shape = (3, 2, 2); + let data = Tensor::from_vec( + vec![ + -5.2f32, 7.8, -3.1, 9.4, 2.6, -8.7, 4.3, -1.9, 6.5, -0.8, -7.2, 3.6, + ], + (3, 2, 2), + &Device::Cpu, + )?; + let axes = Tensor::from_vec(vec![1i64], (1,), &Device::Cpu)?; + + let mut inputs = HashMap::new(); + inputs.insert("data".to_string(), data.clone()); + inputs.insert("axes".to_string(), axes); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let reduced = eval.get("reduced").expect("Output 'reduced' not found"); + + // Calculate expected result + let expected = data.sum(1)?; + + assert_eq!(reduced.to_vec2::<f32>()?, expected.to_vec2::<f32>()?); + } + + Ok(()) +} |