summaryrefslogtreecommitdiff
path: root/candle-onnx/tests/ops.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-onnx/tests/ops.rs')
-rw-r--r--candle-onnx/tests/ops.rs585
1 files changed, 583 insertions, 2 deletions
diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs
index fda76ec2..9b18170a 100644
--- a/candle-onnx/tests/ops.rs
+++ b/candle-onnx/tests/ops.rs
@@ -4,12 +4,16 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
-use candle::{Device, NdArray, Result, Tensor};
+use candle::{DType, Device, NdArray, Result, Tensor};
+use candle_onnx::onnx;
+use candle_onnx::onnx::attribute_proto::AttributeType;
+use candle_onnx::onnx::tensor_proto::DataType;
use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
use std::collections::HashMap;
const INPUT_X: &str = "x";
const INPUT_Y: &str = "y";
+const INPUT_A: &str = "a";
const OUTPUT_Z: &str = "z";
fn create_model_proto_with_graph(graph: Option<GraphProto>) -> ModelProto {
@@ -820,7 +824,137 @@ fn test_flatten_operation() -> Result<()> {
// #[test]
// "ConstantOfShape"
-// #[test]
+#[test]
+fn test_constant_of_shape() -> Result<()> {
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31
+ test(&[4i64, 3, 2], Some(1.), &[1., 1., 1.])?;
+
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31
+ test(&[0.], Some(0i64), &[0i64])?;
+
+ // "value" defaults to 0 f32
+ test(&[1i64, 2, 3, 4], None as Option<i64>, &[0., 0., 0., 0.])?;
+
+ fn test(
+ input: impl NdArray,
+ value: Option<impl NdArray>,
+ expected: impl NdArray,
+ ) -> Result<()> {
+ let mut attribute = vec![];
+
+ if let Some(value) = value {
+ let tensor = Tensor::new(value, &Device::Cpu)?;
+
+ let (value, data_type) = match tensor.dtype() {
+ DType::U8 => (
+ tensor.to_vec0::<u8>()?.to_le_bytes().to_vec(),
+ DataType::Uint8,
+ ),
+ DType::U32 => (
+ tensor.to_vec0::<u32>()?.to_le_bytes().to_vec(),
+ DataType::Uint32,
+ ),
+ DType::I64 => (
+ tensor.to_vec0::<i64>()?.to_le_bytes().to_vec(),
+ DataType::Int64,
+ ),
+ DType::F32 => (
+ tensor.to_vec0::<f32>()?.to_le_bytes().to_vec(),
+ DataType::Float,
+ ),
+ DType::F64 => (
+ tensor.to_vec0::<f64>()?.to_le_bytes().to_vec(),
+ DataType::Double,
+ ),
+ _ => panic!("unsupported DType in test"),
+ };
+ let tensor = onnx::TensorProto {
+ data_type: data_type.into(),
+ dims: tensor.dims().iter().map(|v| *v as i64).collect(),
+ raw_data: value,
+ segment: None,
+ float_data: vec![],
+ int32_data: vec![],
+ string_data: vec![],
+ int64_data: vec![],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ external_data: vec![],
+ data_location: 0,
+ double_data: vec![],
+ uint64_data: vec![],
+ };
+
+ attribute.push(AttributeProto {
+ name: "value".to_string(),
+ ref_attr_name: "value".to_string(),
+ i: 0,
+ doc_string: "value".to_string(),
+ r#type: AttributeType::Tensor.into(),
+ f: 0.0,
+ s: vec![],
+ t: Some(tensor),
+ g: None,
+ sparse_tensor: None,
+ tp: None,
+ floats: vec![],
+ ints: vec![],
+ strings: vec![],
+ tensors: vec![],
+ graphs: vec![],
+ sparse_tensors: vec![],
+ type_protos: vec![],
+ })
+ }
+
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "ConstantOfShape".to_string(),
+ domain: "".to_string(),
+ attribute,
+ input: vec![INPUT_X.to_string()],
+ output: vec![OUTPUT_Z.to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ name: "".to_string(),
+ initializer: vec![],
+ input: vec![],
+ output: vec![ValueInfoProto {
+ name: OUTPUT_Z.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ value_info: vec![],
+ doc_string: "".to_string(),
+ sparse_initializer: vec![],
+ quantization_annotation: vec![],
+ }));
+
+ let mut inputs: HashMap<String, Tensor> = HashMap::new();
+ inputs.insert(INPUT_X.to_string(), Tensor::new(input, &Device::Cpu)?);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let z = eval
+ .get(OUTPUT_Z)
+ .expect("Output 'z' not found")
+ .to_dtype(DType::F64)?;
+
+ let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?;
+ match expected.dims().len() {
+ 0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
+ 1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
+ 2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
+ 3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
+ _ => unreachable!(),
+ };
+
+ Ok(())
+ }
+ Ok(())
+}
// "Unsqueeze"
// #[test]
@@ -1639,3 +1773,450 @@ fn test_reduce_mean() -> Result<()> {
Ok(())
}
+
+// "Sqrt"
+#[test]
+fn test_sqrt() -> Result<()> {
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-155
+ test(&[1., 4., 9.], &[1., 2., 3.])?;
+
+ fn test(data: impl NdArray, expected: impl NdArray) -> Result<()> {
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "Sqrt".to_string(),
+ domain: "".to_string(),
+ attribute: vec![],
+ input: vec![INPUT_X.to_string()],
+ output: vec![OUTPUT_Z.to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ name: "".to_string(),
+ initializer: vec![],
+ input: vec![],
+ output: vec![ValueInfoProto {
+ name: OUTPUT_Z.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ value_info: vec![],
+ doc_string: "".to_string(),
+ sparse_initializer: vec![],
+ quantization_annotation: vec![],
+ }));
+
+ let mut inputs: HashMap<String, Tensor> = HashMap::new();
+ inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+
+ let expected = Tensor::new(expected, &Device::Cpu)?;
+ match expected.dims().len() {
+ 0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
+ 1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
+ 2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
+ 3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
+ _ => unreachable!(),
+ };
+
+ Ok(())
+ }
+
+ Ok(())
+}
+
+// "Range"
+#[test]
+fn test_range() -> Result<()> {
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-113
+ test(1., 5., 2., &[1., 3.])?;
+
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-113
+ test(10i64, 6i64, -3i64, &[10i64, 7i64])?;
+
+ fn test(
+ start: impl NdArray,
+ limit: impl NdArray,
+ delta: impl NdArray,
+ expected: impl NdArray,
+ ) -> Result<()> {
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "Range".to_string(),
+ domain: "".to_string(),
+ attribute: vec![],
+ input: vec![
+ INPUT_X.to_string(),
+ INPUT_Y.to_string(),
+ INPUT_A.to_string(),
+ ],
+ output: vec![OUTPUT_Z.to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ name: "".to_string(),
+ initializer: vec![],
+ input: vec![],
+ output: vec![ValueInfoProto {
+ name: OUTPUT_Z.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ value_info: vec![],
+ doc_string: "".to_string(),
+ sparse_initializer: vec![],
+ quantization_annotation: vec![],
+ }));
+
+ let mut inputs: HashMap<String, Tensor> = HashMap::new();
+ inputs.insert(INPUT_X.to_string(), Tensor::new(start, &Device::Cpu)?);
+ inputs.insert(INPUT_Y.to_string(), Tensor::new(limit, &Device::Cpu)?);
+ inputs.insert(INPUT_A.to_string(), Tensor::new(delta, &Device::Cpu)?);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let z = eval
+ .get(OUTPUT_Z)
+ .expect("Output 'z' not found")
+ .to_dtype(DType::F64)?;
+
+ let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?;
+ match expected.dims().len() {
+ 0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
+ 1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
+ 2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
+ 3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
+ _ => unreachable!(),
+ };
+
+ Ok(())
+ }
+
+ Ok(())
+}
+
+// "Greater"
+#[test]
+fn test_greater() -> Result<()> {
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-63
+ test(&[1., 2., 3.], &[3., 2., 1.], &[0u8, 0, 1])?;
+
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-63
+ test(&[1., 2., 3.], 2., &[0u8, 0, 1])?;
+
+ fn test(a: impl NdArray, b: impl NdArray, expected: impl NdArray) -> Result<()> {
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "Greater".to_string(),
+ domain: "".to_string(),
+ attribute: vec![],
+ input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
+ output: vec![OUTPUT_Z.to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ name: "".to_string(),
+ initializer: vec![],
+ input: vec![],
+ output: vec![ValueInfoProto {
+ name: OUTPUT_Z.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ value_info: vec![],
+ doc_string: "".to_string(),
+ sparse_initializer: vec![],
+ quantization_annotation: vec![],
+ }));
+
+ let mut inputs: HashMap<String, Tensor> = HashMap::new();
+ inputs.insert(INPUT_X.to_string(), Tensor::new(a, &Device::Cpu)?);
+ inputs.insert(INPUT_Y.to_string(), Tensor::new(b, &Device::Cpu)?);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let z = eval
+ .get(OUTPUT_Z)
+ .expect("Output 'z' not found")
+ .to_dtype(DType::F64)?;
+
+ let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?;
+ match expected.dims().len() {
+ 0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
+ 1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
+ 2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
+ 3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
+ _ => unreachable!(),
+ };
+
+ Ok(())
+ }
+
+ Ok(())
+}
+
+// "Less"
+#[test]
+fn test_less() -> Result<()> {
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-81
+ test(&[1., 2., 3.], &[3., 2., 1.], &[1u8, 0, 0])?;
+
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-81
+ test(&[1., 2., 3.], 2., &[1u8, 0, 0])?;
+
+ fn test(a: impl NdArray, b: impl NdArray, expected: impl NdArray) -> Result<()> {
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "Less".to_string(),
+ domain: "".to_string(),
+ attribute: vec![],
+ input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
+ output: vec![OUTPUT_Z.to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ name: "".to_string(),
+ initializer: vec![],
+ input: vec![],
+ output: vec![ValueInfoProto {
+ name: OUTPUT_Z.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ value_info: vec![],
+ doc_string: "".to_string(),
+ sparse_initializer: vec![],
+ quantization_annotation: vec![],
+ }));
+
+ let mut inputs: HashMap<String, Tensor> = HashMap::new();
+ inputs.insert(INPUT_X.to_string(), Tensor::new(a, &Device::Cpu)?);
+ inputs.insert(INPUT_Y.to_string(), Tensor::new(b, &Device::Cpu)?);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let z = eval
+ .get(OUTPUT_Z)
+ .expect("Output 'z' not found")
+ .to_dtype(DType::F64)?;
+
+ let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?;
+ match expected.dims().len() {
+ 0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
+ 1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
+ 2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
+ 3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
+ _ => unreachable!(),
+ };
+
+ Ok(())
+ }
+
+ Ok(())
+}
+
+// "Log"
+#[test]
+fn test_log() -> Result<()> {
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-82
+ test(&[1., 10.], &[0., std::f64::consts::LN_10])?;
+
+ fn test(data: impl NdArray, expected: impl NdArray) -> Result<()> {
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "Log".to_string(),
+ domain: "".to_string(),
+ attribute: vec![],
+ input: vec![INPUT_X.to_string()],
+ output: vec![OUTPUT_Z.to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ name: "".to_string(),
+ initializer: vec![],
+ input: vec![],
+ output: vec![ValueInfoProto {
+ name: OUTPUT_Z.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ value_info: vec![],
+ doc_string: "".to_string(),
+ sparse_initializer: vec![],
+ quantization_annotation: vec![],
+ }));
+
+ let mut inputs: HashMap<String, Tensor> = HashMap::new();
+ inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+
+ let expected = Tensor::new(expected, &Device::Cpu)?;
+ match expected.dims().len() {
+ 0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
+ 1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
+ 2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
+ 3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
+ _ => unreachable!(),
+ };
+
+ Ok(())
+ }
+
+ Ok(())
+}
+
+// "Min"
+#[test]
+fn test_min() -> Result<()> {
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-94
+ test(&[3., 2., 1.], &[1., 4., 4.], &[2., 5., 0.], &[1., 2., 0.])?;
+
+ fn test(
+ a: impl NdArray,
+ b: impl NdArray,
+ c: impl NdArray,
+ expected: impl NdArray,
+ ) -> Result<()> {
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "Min".to_string(),
+ domain: "".to_string(),
+ attribute: vec![],
+ input: vec![
+ INPUT_X.to_string(),
+ INPUT_Y.to_string(),
+ INPUT_A.to_string(),
+ ],
+ output: vec![OUTPUT_Z.to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ name: "".to_string(),
+ initializer: vec![],
+ input: vec![],
+ output: vec![ValueInfoProto {
+ name: OUTPUT_Z.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ value_info: vec![],
+ doc_string: "".to_string(),
+ sparse_initializer: vec![],
+ quantization_annotation: vec![],
+ }));
+
+ let mut inputs: HashMap<String, Tensor> = HashMap::new();
+ inputs.insert(INPUT_X.to_string(), Tensor::new(a, &Device::Cpu)?);
+ inputs.insert(INPUT_Y.to_string(), Tensor::new(b, &Device::Cpu)?);
+ inputs.insert(INPUT_A.to_string(), Tensor::new(c, &Device::Cpu)?);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+
+ let expected = Tensor::new(expected, &Device::Cpu)?;
+ match expected.dims().len() {
+ 0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
+ 1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
+ 2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
+ 3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
+ _ => unreachable!(),
+ };
+
+ Ok(())
+ }
+
+ Ok(())
+}
+
+// "Where"
+#[test]
+fn test_where() -> Result<()> {
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-173
+ test(
+ &[[1u8, 0], [1, 1]],
+ &[[1i64, 2], [3, 4]],
+ &[[9i64, 8], [7, 6]],
+ &[[1i64, 8], [3, 4]],
+ )?;
+
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-173
+ test(
+ &[[1u8, 0], [1, 1]],
+ &[[1., 2.], [3., 4.]],
+ &[[9., 8.], [7., 6.]],
+ &[[1., 8.], [3., 4.]],
+ )?;
+
+ fn test(
+ condition: impl NdArray,
+ x: impl NdArray,
+ y: impl NdArray,
+ expected: impl NdArray,
+ ) -> Result<()> {
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "Where".to_string(),
+ domain: "".to_string(),
+ attribute: vec![],
+ input: vec![
+ INPUT_X.to_string(),
+ INPUT_Y.to_string(),
+ INPUT_A.to_string(),
+ ],
+ output: vec![OUTPUT_Z.to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ name: "".to_string(),
+ initializer: vec![],
+ input: vec![],
+ output: vec![ValueInfoProto {
+ name: OUTPUT_Z.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ value_info: vec![],
+ doc_string: "".to_string(),
+ sparse_initializer: vec![],
+ quantization_annotation: vec![],
+ }));
+
+ let mut inputs: HashMap<String, Tensor> = HashMap::new();
+ inputs.insert(INPUT_X.to_string(), Tensor::new(condition, &Device::Cpu)?);
+ inputs.insert(INPUT_Y.to_string(), Tensor::new(x, &Device::Cpu)?);
+ inputs.insert(INPUT_A.to_string(), Tensor::new(y, &Device::Cpu)?);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let z = eval
+ .get(OUTPUT_Z)
+ .expect("Output 'z' not found")
+ .to_dtype(DType::F64)?;
+
+ let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?;
+ match expected.dims().len() {
+ 0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
+ 1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
+ 2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
+ 3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
+ _ => unreachable!(),
+ };
+
+ Ok(())
+ }
+
+ Ok(())
+}