diff options
author | Sacha Arbonel <sacha.arbonel@hotmail.fr> | 2024-02-23 11:05:46 +0530 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-23 06:35:46 +0100 |
commit | 11ea7aac4d6eb81c6e4d998f58252b7868a34e63 (patch) | |
tree | 30009c7a10b646657401f90f0a4be3c711ea0038 /candle-onnx/tests/ops.rs | |
parent | 32eb56d6b318a3bdbb516237cf89522cb60bbcf2 (diff) | |
download | candle-11ea7aac4d6eb81c6e4d998f58252b7868a34e63.tar.gz candle-11ea7aac4d6eb81c6e4d998f58252b7868a34e63.tar.bz2 candle-11ea7aac4d6eb81c6e4d998f58252b7868a34e63.zip |
tests (#1724)
Diffstat (limited to 'candle-onnx/tests/ops.rs')
-rw-r--r-- | candle-onnx/tests/ops.rs | 483 |
1 files changed, 473 insertions, 10 deletions
diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs index 74b5aad2..3d77071e 100644 --- a/candle-onnx/tests/ops.rs +++ b/candle-onnx/tests/ops.rs @@ -832,7 +832,53 @@ fn test_flatten_operation() -> Result<()> { // #[test] // "Shape" -// #[test] +#[test] +fn test_shape_operation() -> Result<()> { + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Shape".to_string(), + domain: "".to_string(), + attribute: vec![], + input: vec![INPUT_X.to_string()], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![ValueInfoProto { + name: INPUT_X.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + let x = Tensor::from_vec( + vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], + &[2, 2], + &Device::Cpu, + )?; + + let mut inputs: HashMap<String, Tensor> = HashMap::new(); + inputs.insert(INPUT_X.to_string(), x); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); + let results = z.to_vec1::<i64>()?; + assert_eq!(results, vec![2, 2]); + + Ok(()) +} // "Conv" // #[test] @@ -841,34 +887,451 @@ fn test_flatten_operation() -> Result<()> { // #[test] // "Abs" -// #[test] +#[test] +fn test_abs_operation() -> Result<()> { + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Abs".to_string(), + domain: "".to_string(), + attribute: vec![], + input: vec![INPUT_X.to_string()], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![ + ValueInfoProto { + name: INPUT_X.to_string(), + doc_string: "".to_string(), + r#type: None, + }, + ValueInfoProto { + name: INPUT_Y.to_string(), + doc_string: "".to_string(), + r#type: None, + }, + ], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + let x = Tensor::from_vec( + vec![-1.0f32, 2.0f32, -3.0f32, 4.0f32], + &[2, 2], + &Device::Cpu, + )?; + + let mut inputs: HashMap<String, Tensor> = HashMap::new(); + inputs.insert(INPUT_X.to_string(), x); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); + + let results = z.to_vec2::<f32>()?; + + assert_eq!(results, vec![vec![1.0, 2.0], vec![3.0, 4.0]]); + + Ok(()) +} // "Cos" -// #[test] +#[test] +fn test_cos_operation() -> Result<()> { + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Cos".to_string(), + domain: "".to_string(), + attribute: vec![], + input: vec![INPUT_X.to_string()], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![ + ValueInfoProto { + name: INPUT_X.to_string(), + doc_string: "".to_string(), + r#type: None, + }, + ValueInfoProto { + name: INPUT_Y.to_string(), + doc_string: "".to_string(), + r#type: None, + }, + ], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?; + + let mut inputs: HashMap<String, Tensor> = HashMap::new(); + inputs.insert(INPUT_X.to_string(), x); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); + + let results = z.to_vec2::<f32>()?; + + assert_eq!( + results, + vec![vec![1.0, 0.54030234], vec![-0.41614684, -0.9899925]] + ); + + Ok(()) +} // "Sin" -// #[test] +#[test] +fn test_sin_operation() -> Result<()> { + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Sin".to_string(), + domain: "".to_string(), + attribute: vec![], + input: vec![INPUT_X.to_string()], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![ + ValueInfoProto { + name: INPUT_X.to_string(), + doc_string: "".to_string(), + r#type: None, + }, + ValueInfoProto { + name: INPUT_Y.to_string(), + doc_string: "".to_string(), + r#type: None, + }, + ], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?; + + let mut inputs: HashMap<String, Tensor> = HashMap::new(); + inputs.insert(INPUT_X.to_string(), x); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); + + let results = z.to_vec2::<f32>()?; + + assert_eq!(results, vec![vec![0.0, 0.841471], vec![0.9092974, 0.14112]]); + + Ok(()) +} // "Neg" -// #[test] +#[test] +fn test_neg_operation() -> Result<()> { + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Neg".to_string(), + domain: "".to_string(), + attribute: vec![], + input: vec![INPUT_X.to_string()], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![ + ValueInfoProto { + name: INPUT_X.to_string(), + doc_string: "".to_string(), + r#type: None, + }, + ValueInfoProto { + name: INPUT_Y.to_string(), + doc_string: "".to_string(), + r#type: None, + }, + ], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + let x = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu)?; + + let mut inputs: HashMap<String, Tensor> = HashMap::new(); + inputs.insert(INPUT_X.to_string(), x); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); + + let results = z.to_vec2::<f32>()?; + + assert_eq!(results, vec![vec![-1.0, -2.0], vec![-3.0, -4.0]]); + + Ok(()) +} // "Erf" // #[test] // "Tanh" -// #[test] +#[test] +fn test_tanh_operation() -> Result<()> { + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Tanh".to_string(), + domain: "".to_string(), + attribute: vec![], + input: vec![INPUT_X.to_string()], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![ + ValueInfoProto { + name: INPUT_X.to_string(), + doc_string: "".to_string(), + r#type: None, + }, + ValueInfoProto { + name: INPUT_Y.to_string(), + doc_string: "".to_string(), + r#type: None, + }, + ], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?; + + let mut inputs: HashMap<String, Tensor> = HashMap::new(); + inputs.insert(INPUT_X.to_string(), x); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); + + let results = z.to_vec2::<f32>()?; + + assert_eq!( + results, + vec![vec![0.0, 0.7615942], vec![0.9640276, 0.9950548]] + ); + + Ok(()) +} // "Sigmoid" -// #[test] +#[test] +fn test_sigmoid_operation() -> Result<()> { + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Sigmoid".to_string(), + domain: "".to_string(), + attribute: vec![], + input: vec![INPUT_X.to_string()], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![ + ValueInfoProto { + name: INPUT_X.to_string(), + doc_string: "".to_string(), + r#type: None, + }, + ValueInfoProto { + name: INPUT_Y.to_string(), + doc_string: "".to_string(), + r#type: None, + }, + ], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?; + + let mut inputs: HashMap<String, Tensor> = HashMap::new(); + inputs.insert(INPUT_X.to_string(), x); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); + + let results = z.to_vec2::<f32>()?; + + assert_eq!( + results, + vec![vec![0.5, 0.7310586], vec![0.880797, 0.95257413]] + ); + + Ok(()) +} // "Gelu" -// #[test] +#[test] +fn test_gelu_operation() -> Result<()> { + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Gelu".to_string(), + domain: "".to_string(), + attribute: vec![], + input: vec![INPUT_X.to_string()], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![ + ValueInfoProto { + name: INPUT_X.to_string(), + doc_string: "".to_string(), + r#type: None, + }, + ValueInfoProto { + name: INPUT_Y.to_string(), + doc_string: "".to_string(), + r#type: None, + }, + ], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?; + + let mut inputs: HashMap<String, Tensor> = HashMap::new(); + inputs.insert(INPUT_X.to_string(), x); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); + + let results = z.to_vec2::<f32>()?; + + assert_eq!( + results, + vec![vec![0.0, 0.8413448], vec![1.9544997, 2.9959502]] + ); + + Ok(()) +} // "Relu" -// #[test] +#[test] +fn test_relu_operation() -> Result<()> { + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Relu".to_string(), + domain: "".to_string(), + attribute: vec![], + input: vec![INPUT_X.to_string()], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![ValueInfoProto { + name: INPUT_X.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + let x = Tensor::from_vec(vec![-1.0f32, 1.0f32, -2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?; + + let mut inputs: HashMap<String, Tensor> = HashMap::new(); + inputs.insert(INPUT_X.to_string(), x); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); + + let results = z.to_vec2::<f32>()?; + + assert_eq!(results, vec![vec![0.0, 1.0], vec![0.0, 3.0]]); + + Ok(()) +} // "Constant" // #[test] // "Cast" -// #[test] +// #[test]
\ No newline at end of file |