summaryrefslogtreecommitdiff
path: root/candle-onnx/tests/ops.rs
diff options
context:
space:
mode:
authorSacha Arbonel <sacha.arbonel@hotmail.fr>2024-02-23 11:05:46 +0530
committerGitHub <noreply@github.com>2024-02-23 06:35:46 +0100
commit11ea7aac4d6eb81c6e4d998f58252b7868a34e63 (patch)
tree30009c7a10b646657401f90f0a4be3c711ea0038 /candle-onnx/tests/ops.rs
parent32eb56d6b318a3bdbb516237cf89522cb60bbcf2 (diff)
downloadcandle-11ea7aac4d6eb81c6e4d998f58252b7868a34e63.tar.gz
candle-11ea7aac4d6eb81c6e4d998f58252b7868a34e63.tar.bz2
candle-11ea7aac4d6eb81c6e4d998f58252b7868a34e63.zip
tests (#1724)
Diffstat (limited to 'candle-onnx/tests/ops.rs')
-rw-r--r--candle-onnx/tests/ops.rs483
1 file changed, 473 insertions, 10 deletions
diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs
index 74b5aad2..3d77071e 100644
--- a/candle-onnx/tests/ops.rs
+++ b/candle-onnx/tests/ops.rs
@@ -832,7 +832,53 @@ fn test_flatten_operation() -> Result<()> {
// #[test]
// "Shape"
-// #[test]
+// Checks the ONNX "Shape" op through candle_onnx::simple_eval: for a 2x2 input
+// tensor the op must return its dimensions as an i64 vector [2, 2].
#[test]
fn test_shape_operation() -> Result<()> {
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "Shape".to_string(),
+ domain: "".to_string(),
+ attribute: vec![],
+ input: vec![INPUT_X.to_string()],
+ output: vec![OUTPUT_Z.to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ name: "".to_string(),
+ initializer: vec![],
+ input: vec![],
+ output: vec![ValueInfoProto {
+ name: OUTPUT_Z.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ // Unlike the other tests in this file, INPUT_X is declared via value_info
+ // rather than the graph's input list.
+ value_info: vec![ValueInfoProto {
+ name: INPUT_X.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ doc_string: "".to_string(),
+ sparse_initializer: vec![],
+ quantization_annotation: vec![],
+ }));
+ let x = Tensor::from_vec(
+ vec![1.0f32, 2.0f32, 3.0f32, 4.0f32],
+ &[2, 2],
+ &Device::Cpu,
+ )?;
+
+ let mut inputs: HashMap<String, Tensor> = HashMap::new();
+ inputs.insert(INPUT_X.to_string(), x);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ // Shape emits i64 per the ONNX spec, hence to_vec1::<i64>.
+ let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+ let results = z.to_vec1::<i64>()?;
+ assert_eq!(results, vec![2, 2]);
+
+ Ok(())
+}
// "Conv"
// #[test]
@@ -841,34 +887,451 @@ fn test_flatten_operation() -> Result<()> {
// #[test]
// "Abs"
-// #[test]
+// Checks the ONNX "Abs" op: element-wise absolute value of a 2x2 f32 tensor.
#[test]
fn test_abs_operation() -> Result<()> {
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "Abs".to_string(),
+ domain: "".to_string(),
+ attribute: vec![],
+ input: vec![INPUT_X.to_string()],
+ output: vec![OUTPUT_Z.to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ name: "".to_string(),
+ initializer: vec![],
+ input: vec![
+ ValueInfoProto {
+ name: INPUT_X.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ },
+ // NOTE(review): INPUT_Y is declared as a graph input but never inserted
+ // into `inputs`, and Abs is unary — presumably copy-paste from a binary-op
+ // test; confirm simple_eval intentionally tolerates unused declared inputs.
+ ValueInfoProto {
+ name: INPUT_Y.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ },
+ ],
+ output: vec![ValueInfoProto {
+ name: OUTPUT_Z.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ value_info: vec![],
+ doc_string: "".to_string(),
+ sparse_initializer: vec![],
+ quantization_annotation: vec![],
+ }));
+ let x = Tensor::from_vec(
+ vec![-1.0f32, 2.0f32, -3.0f32, 4.0f32],
+ &[2, 2],
+ &Device::Cpu,
+ )?;
+
+ let mut inputs: HashMap<String, Tensor> = HashMap::new();
+ inputs.insert(INPUT_X.to_string(), x);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+
+ let results = z.to_vec2::<f32>()?;
+
+ assert_eq!(results, vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
+
+ Ok(())
+}
// "Cos"
-// #[test]
+// Checks the ONNX "Cos" op: element-wise cosine of [0, 1, 2, 3] reshaped 2x2;
+// the expected values are the f32 roundings of cos(0..3).
#[test]
fn test_cos_operation() -> Result<()> {
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "Cos".to_string(),
+ domain: "".to_string(),
+ attribute: vec![],
+ input: vec![INPUT_X.to_string()],
+ output: vec![OUTPUT_Z.to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ name: "".to_string(),
+ initializer: vec![],
+ input: vec![
+ ValueInfoProto {
+ name: INPUT_X.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ },
+ // NOTE(review): INPUT_Y is declared but never supplied; Cos is unary —
+ // presumably copy-paste, confirm it is intentional.
+ ValueInfoProto {
+ name: INPUT_Y.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ },
+ ],
+ output: vec![ValueInfoProto {
+ name: OUTPUT_Z.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ value_info: vec![],
+ doc_string: "".to_string(),
+ sparse_initializer: vec![],
+ quantization_annotation: vec![],
+ }));
+ let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?;
+
+ let mut inputs: HashMap<String, Tensor> = HashMap::new();
+ inputs.insert(INPUT_X.to_string(), x);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+
+ let results = z.to_vec2::<f32>()?;
+
+ assert_eq!(
+ results,
+ vec![vec![1.0, 0.54030234], vec![-0.41614684, -0.9899925]]
+ );
+
+ Ok(())
+}
// "Sin"
-// #[test]
+// Checks the ONNX "Sin" op: element-wise sine of [0, 1, 2, 3] reshaped 2x2;
+// the expected values are the f32 roundings of sin(0..3).
#[test]
fn test_sin_operation() -> Result<()> {
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "Sin".to_string(),
+ domain: "".to_string(),
+ attribute: vec![],
+ input: vec![INPUT_X.to_string()],
+ output: vec![OUTPUT_Z.to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ name: "".to_string(),
+ initializer: vec![],
+ input: vec![
+ ValueInfoProto {
+ name: INPUT_X.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ },
+ // NOTE(review): INPUT_Y is declared but never supplied; Sin is unary —
+ // presumably copy-paste, confirm it is intentional.
+ ValueInfoProto {
+ name: INPUT_Y.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ },
+ ],
+ output: vec![ValueInfoProto {
+ name: OUTPUT_Z.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ value_info: vec![],
+ doc_string: "".to_string(),
+ sparse_initializer: vec![],
+ quantization_annotation: vec![],
+ }));
+ let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?;
+
+ let mut inputs: HashMap<String, Tensor> = HashMap::new();
+ inputs.insert(INPUT_X.to_string(), x);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+
+ let results = z.to_vec2::<f32>()?;
+
+ assert_eq!(results, vec![vec![0.0, 0.841471], vec![0.9092974, 0.14112]]);
+
+ Ok(())
+}
// "Neg"
-// #[test]
+// Checks the ONNX "Neg" op: element-wise negation of a 2x2 f32 tensor.
#[test]
fn test_neg_operation() -> Result<()> {
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "Neg".to_string(),
+ domain: "".to_string(),
+ attribute: vec![],
+ input: vec![INPUT_X.to_string()],
+ output: vec![OUTPUT_Z.to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ name: "".to_string(),
+ initializer: vec![],
+ input: vec![
+ ValueInfoProto {
+ name: INPUT_X.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ },
+ // NOTE(review): INPUT_Y is declared but never supplied; Neg is unary —
+ // presumably copy-paste, confirm it is intentional.
+ ValueInfoProto {
+ name: INPUT_Y.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ },
+ ],
+ output: vec![ValueInfoProto {
+ name: OUTPUT_Z.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ value_info: vec![],
+ doc_string: "".to_string(),
+ sparse_initializer: vec![],
+ quantization_annotation: vec![],
+ }));
+ let x = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu)?;
+
+ let mut inputs: HashMap<String, Tensor> = HashMap::new();
+ inputs.insert(INPUT_X.to_string(), x);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+
+ let results = z.to_vec2::<f32>()?;
+
+ assert_eq!(results, vec![vec![-1.0, -2.0], vec![-3.0, -4.0]]);
+
+ Ok(())
+}
// "Erf"
// #[test]
// "Tanh"
-// #[test]
+// Checks the ONNX "Tanh" op: element-wise tanh of [0, 1, 2, 3] reshaped 2x2;
+// the expected values are the f32 roundings of tanh(0..3).
#[test]
fn test_tanh_operation() -> Result<()> {
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "Tanh".to_string(),
+ domain: "".to_string(),
+ attribute: vec![],
+ input: vec![INPUT_X.to_string()],
+ output: vec![OUTPUT_Z.to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ name: "".to_string(),
+ initializer: vec![],
+ input: vec![
+ ValueInfoProto {
+ name: INPUT_X.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ },
+ // NOTE(review): INPUT_Y is declared but never supplied; Tanh is unary —
+ // presumably copy-paste, confirm it is intentional.
+ ValueInfoProto {
+ name: INPUT_Y.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ },
+ ],
+ output: vec![ValueInfoProto {
+ name: OUTPUT_Z.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ value_info: vec![],
+ doc_string: "".to_string(),
+ sparse_initializer: vec![],
+ quantization_annotation: vec![],
+ }));
+ let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?;
+
+ let mut inputs: HashMap<String, Tensor> = HashMap::new();
+ inputs.insert(INPUT_X.to_string(), x);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+
+ let results = z.to_vec2::<f32>()?;
+
+ assert_eq!(
+ results,
+ vec![vec![0.0, 0.7615942], vec![0.9640276, 0.9950548]]
+ );
+
+ Ok(())
+}
// "Sigmoid"
-// #[test]
+// Checks the ONNX "Sigmoid" op: element-wise logistic sigmoid of [0, 1, 2, 3]
+// reshaped 2x2; the expected values are the f32 roundings of 1/(1+e^-x).
#[test]
fn test_sigmoid_operation() -> Result<()> {
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "Sigmoid".to_string(),
+ domain: "".to_string(),
+ attribute: vec![],
+ input: vec![INPUT_X.to_string()],
+ output: vec![OUTPUT_Z.to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ name: "".to_string(),
+ initializer: vec![],
+ input: vec![
+ ValueInfoProto {
+ name: INPUT_X.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ },
+ // NOTE(review): INPUT_Y is declared but never supplied; Sigmoid is unary —
+ // presumably copy-paste, confirm it is intentional.
+ ValueInfoProto {
+ name: INPUT_Y.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ },
+ ],
+ output: vec![ValueInfoProto {
+ name: OUTPUT_Z.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ value_info: vec![],
+ doc_string: "".to_string(),
+ sparse_initializer: vec![],
+ quantization_annotation: vec![],
+ }));
+ let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?;
+
+ let mut inputs: HashMap<String, Tensor> = HashMap::new();
+ inputs.insert(INPUT_X.to_string(), x);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+
+ let results = z.to_vec2::<f32>()?;
+
+ assert_eq!(
+ results,
+ vec![vec![0.5, 0.7310586], vec![0.880797, 0.95257413]]
+ );
+
+ Ok(())
+}
// "Gelu"
-// #[test]
+// Checks the ONNX "Gelu" op on [0, 1, 2, 3] reshaped 2x2. The expected values
+// match the erf-based GELU, x * 0.5 * (1 + erf(x / sqrt(2))) — e.g.
+// gelu(1) ≈ 0.8413448 — rather than the tanh approximation.
#[test]
fn test_gelu_operation() -> Result<()> {
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "Gelu".to_string(),
+ domain: "".to_string(),
+ attribute: vec![],
+ input: vec![INPUT_X.to_string()],
+ output: vec![OUTPUT_Z.to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ name: "".to_string(),
+ initializer: vec![],
+ input: vec![
+ ValueInfoProto {
+ name: INPUT_X.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ },
+ // NOTE(review): INPUT_Y is declared but never supplied; Gelu is unary —
+ // presumably copy-paste, confirm it is intentional.
+ ValueInfoProto {
+ name: INPUT_Y.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ },
+ ],
+ output: vec![ValueInfoProto {
+ name: OUTPUT_Z.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ value_info: vec![],
+ doc_string: "".to_string(),
+ sparse_initializer: vec![],
+ quantization_annotation: vec![],
+ }));
+ let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?;
+
+ let mut inputs: HashMap<String, Tensor> = HashMap::new();
+ inputs.insert(INPUT_X.to_string(), x);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+
+ let results = z.to_vec2::<f32>()?;
+
+ assert_eq!(
+ results,
+ vec![vec![0.0, 0.8413448], vec![1.9544997, 2.9959502]]
+ );
+
+ Ok(())
+}
// "Relu"
-// #[test]
+// Checks the ONNX "Relu" op: element-wise max(x, 0) on a 2x2 f32 tensor with
+// mixed-sign entries; negatives clamp to 0, positives pass through.
#[test]
fn test_relu_operation() -> Result<()> {
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "Relu".to_string(),
+ domain: "".to_string(),
+ attribute: vec![],
+ input: vec![INPUT_X.to_string()],
+ output: vec![OUTPUT_Z.to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ name: "".to_string(),
+ initializer: vec![],
+ input: vec![ValueInfoProto {
+ name: INPUT_X.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ output: vec![ValueInfoProto {
+ name: OUTPUT_Z.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ value_info: vec![],
+ doc_string: "".to_string(),
+ sparse_initializer: vec![],
+ quantization_annotation: vec![],
+ }));
+ let x = Tensor::from_vec(vec![-1.0f32, 1.0f32, -2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?;
+
+ let mut inputs: HashMap<String, Tensor> = HashMap::new();
+ inputs.insert(INPUT_X.to_string(), x);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+
+ let results = z.to_vec2::<f32>()?;
+
+ assert_eq!(results, vec![vec![0.0, 1.0], vec![0.0, 3.0]]);
+
+ Ok(())
+}
// "Constant"
// #[test]
// "Cast"
-// #[test]
// #[test]
\ No newline at end of file