diff options
author | Ionut Mihalcea <ionut.mihalcea@gmail.com> | 2024-11-26 23:10:09 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-11-26 23:10:09 +0100 |
commit | 21c686387cead049aad32e6d1cc494d6c79e46e3 (patch) | |
tree | b055d64d2bb5ac994f4bb26c67f19d3c258a5035 | |
parent | b4deb5c5a9fc6287f7521e6bc2b7f3c2d56510dd (diff) | |
download | candle-21c686387cead049aad32e6d1cc494d6c79e46e3.tar.gz candle-21c686387cead049aad32e6d1cc494d6c79e46e3.tar.bz2 candle-21c686387cead049aad32e6d1cc494d6c79e46e3.zip |
Onnx Support for Sign operation #2641 (#2642)
* Support for Sign operation #2641
* Apply rustfmt.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
-rw-r--r-- | candle-onnx/src/eval.rs | 6 | ||||
-rw-r--r-- | candle-onnx/tests/ops.rs | 41 |
2 files changed, 47 insertions, 0 deletions
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 358af7ac..2c60ed2f 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -1944,6 +1944,12 @@ fn simple_eval_( values.insert(node.output[0].clone(), out); } + // https://onnx.ai/onnx/operators/onnx__Sign.html + "Sign" => { + let input = get(&node.input[0])?; + let output = input.sign()?; + values.insert(node.output[0].clone(), output); + } op_type => bail!("unsupported op_type {op_type} for op {node:?}"), } } diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs index a84ba481..3586bfbd 100644 --- a/candle-onnx/tests/ops.rs +++ b/candle-onnx/tests/ops.rs @@ -5869,3 +5869,44 @@ fn test_xor() -> Result<()> { } Ok(()) } + +#[test] +fn test_sign_operation() -> Result<()> { + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Sign".to_string(), + domain: "".to_string(), + attribute: vec![], + input: vec![INPUT_X.to_string()], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + + let mut inputs: HashMap<String, Tensor> = HashMap::new(); + inputs.insert( + INPUT_X.to_string(), + Tensor::new(vec![-2f32, -1., 0., 1., 2.], &Device::Cpu)?, + ); + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + + let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); + assert_eq!( + z.to_dtype(candle::DType::I64)?.to_vec1::<i64>()?.to_vec(), + vec![-1, -1, 0, 1, 1] + ); + Ok(()) +} |