summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIonut Mihalcea <ionut.mihalcea@gmail.com>2024-11-26 23:10:09 +0100
committerGitHub <noreply@github.com>2024-11-26 23:10:09 +0100
commit21c686387cead049aad32e6d1cc494d6c79e46e3 (patch)
treeb055d64d2bb5ac994f4bb26c67f19d3c258a5035
parentb4deb5c5a9fc6287f7521e6bc2b7f3c2d56510dd (diff)
downloadcandle-21c686387cead049aad32e6d1cc494d6c79e46e3.tar.gz
candle-21c686387cead049aad32e6d1cc494d6c79e46e3.tar.bz2
candle-21c686387cead049aad32e6d1cc494d6c79e46e3.zip
Onnx Support for Sign operation #2641 (#2642)
* Support for Sign operation #2641 * Apply rustfmt. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
-rw-r--r--candle-onnx/src/eval.rs6
-rw-r--r--candle-onnx/tests/ops.rs41
2 files changed, 47 insertions, 0 deletions
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index 358af7ac..2c60ed2f 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -1944,6 +1944,12 @@ fn simple_eval_(
values.insert(node.output[0].clone(), out);
}
+ // https://onnx.ai/onnx/operators/onnx__Sign.html
+ "Sign" => {
+ let input = get(&node.input[0])?;
+ let output = input.sign()?;
+ values.insert(node.output[0].clone(), output);
+ }
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
}
}
diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs
index a84ba481..3586bfbd 100644
--- a/candle-onnx/tests/ops.rs
+++ b/candle-onnx/tests/ops.rs
@@ -5869,3 +5869,44 @@ fn test_xor() -> Result<()> {
}
Ok(())
}
+
+#[test]
+fn test_sign_operation() -> Result<()> {
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "Sign".to_string(),
+ domain: "".to_string(),
+ attribute: vec![],
+ input: vec![INPUT_X.to_string()],
+ output: vec![OUTPUT_Z.to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ name: "".to_string(),
+ initializer: vec![],
+ input: vec![],
+ output: vec![ValueInfoProto {
+ name: OUTPUT_Z.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ value_info: vec![],
+ doc_string: "".to_string(),
+ sparse_initializer: vec![],
+ quantization_annotation: vec![],
+ }));
+
+ let mut inputs: HashMap<String, Tensor> = HashMap::new();
+ inputs.insert(
+ INPUT_X.to_string(),
+ Tensor::new(vec![-2f32, -1., 0., 1., 2.], &Device::Cpu)?,
+ );
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+
+ let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+ assert_eq!(
+ z.to_dtype(candle::DType::I64)?.to_vec1::<i64>()?.to_vec(),
+ vec![-1, -1, 0, 1, 1]
+ );
+ Ok(())
+}