summaryrefslogtreecommitdiff
path: root/candle-onnx
diff options
context:
space:
mode:
authorGabriel <45515538+gabotechs@users.noreply.github.com>2024-04-20 18:44:22 +0200
committerGitHub <noreply@github.com>2024-04-20 18:44:22 +0200
commit9215e9ce8c3fbe2e2850065557fc7e37b8e1c948 (patch)
tree918cb90a70adc90fc45c1b3f6d69128a543071ea /candle-onnx
parent52ae33291060bb57ea2b7913179747040eed02b9 (diff)
downloadcandle-9215e9ce8c3fbe2e2850065557fc7e37b8e1c948.tar.gz
candle-9215e9ce8c3fbe2e2850065557fc7e37b8e1c948.tar.bz2
candle-9215e9ce8c3fbe2e2850065557fc7e37b8e1c948.zip
Add missing onnx operations (#2096)
* Add missing onnx operations * Add tests and fix errors * Run rustfmt
Diffstat (limited to 'candle-onnx')
-rw-r--r--candle-onnx/src/eval.rs160
-rw-r--r--candle-onnx/tests/ops.rs585
2 files changed, 736 insertions, 9 deletions
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index 75927822..417216d7 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -23,6 +23,11 @@ trait Attr {
fn get(attr: &onnx::AttributeProto) -> Result<&Self>;
}
+trait AttrOwned: Sized {
+ const TYPE: AttributeType;
+ fn get(attr: &onnx::AttributeProto) -> Result<Self>;
+}
+
impl Attr for i64 {
const TYPE: AttributeType = AttributeType::Int;
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
@@ -51,6 +56,50 @@ impl Attr for str {
}
}
+impl AttrOwned for Tensor {
+ const TYPE: AttributeType = AttributeType::Tensor;
+ fn get(attr: &onnx::AttributeProto) -> Result<Self> {
+ let tensor_proto = match &attr.t {
+ Some(value) => value,
+ None => bail!(
+ "attribute {} was of type TENSOR, but no tensor was found",
+ attr.name
+ ),
+ };
+
+ let data_type = match DataType::try_from(tensor_proto.data_type) {
+ Ok(value) => value,
+ Err(_) => bail!(
+ "attribute {} of type TENSOR was an invalid data_type number {}",
+ attr.name,
+ tensor_proto.data_type
+ ),
+ };
+
+ let dtype = match dtype(data_type) {
+ Some(value) => value,
+ None => bail!(
+ "attribute {} of type TENSOR has an unsupported data_type {}",
+ attr.name,
+ data_type.as_str_name()
+ ),
+ };
+
+ let mut dims = Vec::with_capacity(tensor_proto.dims.len());
+ for dim in &tensor_proto.dims {
+ if dim < &0 {
+ bail!(
+ "attribute {} of type TENSOR has a negative dimension, which is unsupported",
+ attr.name
+ )
+ }
+ dims.push(*dim as usize)
+ }
+
+ Tensor::from_raw_buffer(&tensor_proto.raw_data, dtype, &dims, &Device::Cpu)
+ }
+}
+
fn get_attr_<'a>(node: &'a onnx::NodeProto, name: &str) -> Result<&'a onnx::AttributeProto> {
match node.attribute.iter().find(|attr| attr.name == name) {
None => {
@@ -98,6 +147,24 @@ fn get_attr_opt<'a, T: Attr + ?Sized>(
}
}
+fn get_attr_opt_owned<T: AttrOwned>(node: &onnx::NodeProto, name: &str) -> Result<Option<T>> {
+ match node.attribute.iter().find(|attr| attr.name == name) {
+ None => Ok(None),
+ Some(attr) => {
+ if attr.r#type() != T::TYPE {
+ bail!(
+ "unsupported type {:?} for '{name}' attribute in '{}' for {}",
+ attr.r#type,
+ node.op_type,
+ node.name
+ )
+ }
+ let val = T::get(attr)?;
+ Ok(Some(val))
+ }
+ }
+}
+
pub fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
match DataType::try_from(t.data_type) {
@@ -458,14 +525,17 @@ pub fn simple_eval(
}
values.insert(node.output[0].clone(), xs);
}
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConstantOfShape
"ConstantOfShape" => {
- let dims = get(&node.input[0])?;
- let shape = dims
- .to_vec1::<i64>()?
- .into_iter()
- .map(|v| v as usize)
- .collect::<Vec<_>>();
- let xs = Tensor::zeros(shape, DType::F32, dims.device())?;
+ let input = get(&node.input[0])?;
+ let value = get_attr_opt_owned::<Tensor>(node, "value")?.unwrap_or(Tensor::zeros(
+ (),
+ DType::F32,
+ &Device::Cpu,
+ )?);
+
+ let xs = Tensor::ones(input.shape(), value.dtype(), input.device())?
+ .broadcast_mul(&value)?;
values.insert(node.output[0].clone(), xs);
}
"Unsqueeze" => {
@@ -552,6 +622,82 @@ pub fn simple_eval(
let dims = Tensor::from_vec(dims, xs.rank(), xs.device())?;
values.insert(node.output[0].clone(), dims);
}
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sqrt
+ "Sqrt" => {
+ let xs = get(&node.input[0])?;
+ let output = xs.sqrt()?;
+ values.insert(node.output[0].clone(), output);
+ }
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Range
+ "Range" => {
+ let start = get(&node.input[0])?;
+ let limit = get(&node.input[1])?;
+ let delta = get(&node.input[2])?;
+
+ macro_rules! arange_step {
+ ($t: ty) => {
+ Tensor::arange_step(
+ start.to_vec0::<$t>()?,
+ limit.to_vec0::<$t>()?,
+ delta.to_vec0::<$t>()?,
+ &Device::Cpu,
+ )?
+ };
+ }
+
+ let output = match start.dtype() {
+ DType::U8 => arange_step!(u8),
+ DType::U32 => arange_step!(u32),
+ DType::I64 => arange_step!(i64),
+ DType::BF16 => arange_step!(f32),
+ DType::F16 => arange_step!(f32),
+ DType::F32 => arange_step!(f32),
+ DType::F64 => arange_step!(f64),
+ };
+
+ values.insert(node.output[0].clone(), output);
+ }
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Greater
+ "Greater" => {
+ let a = get(&node.input[0])?;
+ let b = get(&node.input[1])?;
+
+ let output = a.broadcast_gt(b)?;
+ values.insert(node.output[0].clone(), output);
+ }
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Less
+ "Less" => {
+ let a = get(&node.input[0])?;
+ let b = get(&node.input[1])?;
+
+ let output = a.broadcast_lt(b)?;
+ values.insert(node.output[0].clone(), output);
+ }
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Log
+ "Log" => {
+ let a = get(&node.input[0])?;
+
+ let output = a.log()?;
+ values.insert(node.output[0].clone(), output);
+ }
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Min
+ "Min" => {
+ let mut output = get(&node.input[0])?.clone();
+ for input in node.input.iter() {
+ let input = get(input)?;
+ output = output.broadcast_minimum(input)?
+ }
+
+ values.insert(node.output[0].clone(), output);
+ }
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Where
+ "Where" => {
+ let cond = get(&node.input[0])?;
+ let a = get(&node.input[1])?;
+ let b = get(&node.input[2])?;
+ let output = cond.where_cond(a, b)?;
+ values.insert(node.output[0].clone(), output);
+ }
"Conv" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs
index fda76ec2..9b18170a 100644
--- a/candle-onnx/tests/ops.rs
+++ b/candle-onnx/tests/ops.rs
@@ -4,12 +4,16 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
-use candle::{Device, NdArray, Result, Tensor};
+use candle::{DType, Device, NdArray, Result, Tensor};
+use candle_onnx::onnx;
+use candle_onnx::onnx::attribute_proto::AttributeType;
+use candle_onnx::onnx::tensor_proto::DataType;
use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
use std::collections::HashMap;
const INPUT_X: &str = "x";
const INPUT_Y: &str = "y";
+const INPUT_A: &str = "a";
const OUTPUT_Z: &str = "z";
fn create_model_proto_with_graph(graph: Option<GraphProto>) -> ModelProto {
@@ -820,7 +824,137 @@ fn test_flatten_operation() -> Result<()> {
// #[test]
// "ConstantOfShape"
-// #[test]
+#[test]
+fn test_constant_of_shape() -> Result<()> {
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31
+ test(&[4i64, 3, 2], Some(1.), &[1., 1., 1.])?;
+
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31
+ test(&[0.], Some(0i64), &[0i64])?;
+
+ // "value" defaults to 0 f32
+ test(&[1i64, 2, 3, 4], None as Option<i64>, &[0., 0., 0., 0.])?;
+
+ fn test(
+ input: impl NdArray,
+ value: Option<impl NdArray>,
+ expected: impl NdArray,
+ ) -> Result<()> {
+ let mut attribute = vec![];
+
+ if let Some(value) = value {
+ let tensor = Tensor::new(value, &Device::Cpu)?;
+
+ let (value, data_type) = match tensor.dtype() {
+ DType::U8 => (
+ tensor.to_vec0::<u8>()?.to_le_bytes().to_vec(),
+ DataType::Uint8,
+ ),
+ DType::U32 => (
+ tensor.to_vec0::<u32>()?.to_le_bytes().to_vec(),
+ DataType::Uint32,
+ ),
+ DType::I64 => (
+ tensor.to_vec0::<i64>()?.to_le_bytes().to_vec(),
+ DataType::Int64,
+ ),
+ DType::F32 => (
+ tensor.to_vec0::<f32>()?.to_le_bytes().to_vec(),
+ DataType::Float,
+ ),
+ DType::F64 => (
+ tensor.to_vec0::<f64>()?.to_le_bytes().to_vec(),
+ DataType::Double,
+ ),
+ _ => panic!("unsupported DType in test"),
+ };
+ let tensor = onnx::TensorProto {
+ data_type: data_type.into(),
+ dims: tensor.dims().iter().map(|v| *v as i64).collect(),
+ raw_data: value,
+ segment: None,
+ float_data: vec![],
+ int32_data: vec![],
+ string_data: vec![],
+ int64_data: vec![],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ external_data: vec![],
+ data_location: 0,
+ double_data: vec![],
+ uint64_data: vec![],
+ };
+
+ attribute.push(AttributeProto {
+ name: "value".to_string(),
+ ref_attr_name: "value".to_string(),
+ i: 0,
+ doc_string: "value".to_string(),
+ r#type: AttributeType::Tensor.into(),
+ f: 0.0,
+ s: vec![],
+ t: Some(tensor),
+ g: None,
+ sparse_tensor: None,
+ tp: None,
+ floats: vec![],
+ ints: vec![],
+ strings: vec![],
+ tensors: vec![],
+ graphs: vec![],
+ sparse_tensors: vec![],
+ type_protos: vec![],
+ })
+ }
+
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "ConstantOfShape".to_string(),
+ domain: "".to_string(),
+ attribute,
+ input: vec![INPUT_X.to_string()],
+ output: vec![OUTPUT_Z.to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ name: "".to_string(),
+ initializer: vec![],
+ input: vec![],
+ output: vec![ValueInfoProto {
+ name: OUTPUT_Z.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ value_info: vec![],
+ doc_string: "".to_string(),
+ sparse_initializer: vec![],
+ quantization_annotation: vec![],
+ }));
+
+ let mut inputs: HashMap<String, Tensor> = HashMap::new();
+ inputs.insert(INPUT_X.to_string(), Tensor::new(input, &Device::Cpu)?);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let z = eval
+ .get(OUTPUT_Z)
+ .expect("Output 'z' not found")
+ .to_dtype(DType::F64)?;
+
+ let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?;
+ match expected.dims().len() {
+ 0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
+ 1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
+ 2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
+ 3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
+ _ => unreachable!(),
+ };
+
+ Ok(())
+ }
+ Ok(())
+}
// "Unsqueeze"
// #[test]
@@ -1639,3 +1773,450 @@ fn test_reduce_mean() -> Result<()> {
Ok(())
}
+
+// "Sqrt"
+#[test]
+fn test_sqrt() -> Result<()> {
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-155
+ test(&[1., 4., 9.], &[1., 2., 3.])?;
+
+ fn test(data: impl NdArray, expected: impl NdArray) -> Result<()> {
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "Sqrt".to_string(),
+ domain: "".to_string(),
+ attribute: vec![],
+ input: vec![INPUT_X.to_string()],
+ output: vec![OUTPUT_Z.to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ name: "".to_string(),
+ initializer: vec![],
+ input: vec![],
+ output: vec![ValueInfoProto {
+ name: OUTPUT_Z.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ value_info: vec![],
+ doc_string: "".to_string(),
+ sparse_initializer: vec![],
+ quantization_annotation: vec![],
+ }));
+
+ let mut inputs: HashMap<String, Tensor> = HashMap::new();
+ inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+
+ let expected = Tensor::new(expected, &Device::Cpu)?;
+ match expected.dims().len() {
+ 0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
+ 1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
+ 2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
+ 3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
+ _ => unreachable!(),
+ };
+
+ Ok(())
+ }
+
+ Ok(())
+}
+
+// "Range"
+#[test]
+fn test_range() -> Result<()> {
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-113
+ test(1., 5., 2., &[1., 3.])?;
+
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-113
+ test(10i64, 6i64, -3i64, &[10i64, 7i64])?;
+
+ fn test(
+ start: impl NdArray,
+ limit: impl NdArray,
+ delta: impl NdArray,
+ expected: impl NdArray,
+ ) -> Result<()> {
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "Range".to_string(),
+ domain: "".to_string(),
+ attribute: vec![],
+ input: vec![
+ INPUT_X.to_string(),
+ INPUT_Y.to_string(),
+ INPUT_A.to_string(),
+ ],
+ output: vec![OUTPUT_Z.to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ name: "".to_string(),
+ initializer: vec![],
+ input: vec![],
+ output: vec![ValueInfoProto {
+ name: OUTPUT_Z.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ value_info: vec![],
+ doc_string: "".to_string(),
+ sparse_initializer: vec![],
+ quantization_annotation: vec![],
+ }));
+
+ let mut inputs: HashMap<String, Tensor> = HashMap::new();
+ inputs.insert(INPUT_X.to_string(), Tensor::new(start, &Device::Cpu)?);
+ inputs.insert(INPUT_Y.to_string(), Tensor::new(limit, &Device::Cpu)?);
+ inputs.insert(INPUT_A.to_string(), Tensor::new(delta, &Device::Cpu)?);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let z = eval
+ .get(OUTPUT_Z)
+ .expect("Output 'z' not found")
+ .to_dtype(DType::F64)?;
+
+ let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?;
+ match expected.dims().len() {
+ 0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
+ 1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
+ 2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
+ 3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
+ _ => unreachable!(),
+ };
+
+ Ok(())
+ }
+
+ Ok(())
+}
+
+// "Greater"
+#[test]
+fn test_greater() -> Result<()> {
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-63
+ test(&[1., 2., 3.], &[3., 2., 1.], &[0u8, 0, 1])?;
+
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-63
+ test(&[1., 2., 3.], 2., &[0u8, 0, 1])?;
+
+ fn test(a: impl NdArray, b: impl NdArray, expected: impl NdArray) -> Result<()> {
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "Greater".to_string(),
+ domain: "".to_string(),
+ attribute: vec![],
+ input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
+ output: vec![OUTPUT_Z.to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ name: "".to_string(),
+ initializer: vec![],
+ input: vec![],
+ output: vec![ValueInfoProto {
+ name: OUTPUT_Z.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ value_info: vec![],
+ doc_string: "".to_string(),
+ sparse_initializer: vec![],
+ quantization_annotation: vec![],
+ }));
+
+ let mut inputs: HashMap<String, Tensor> = HashMap::new();
+ inputs.insert(INPUT_X.to_string(), Tensor::new(a, &Device::Cpu)?);
+ inputs.insert(INPUT_Y.to_string(), Tensor::new(b, &Device::Cpu)?);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let z = eval
+ .get(OUTPUT_Z)
+ .expect("Output 'z' not found")
+ .to_dtype(DType::F64)?;
+
+ let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?;
+ match expected.dims().len() {
+ 0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
+ 1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
+ 2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
+ 3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
+ _ => unreachable!(),
+ };
+
+ Ok(())
+ }
+
+ Ok(())
+}
+
+// "Less"
+#[test]
+fn test_less() -> Result<()> {
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-81
+ test(&[1., 2., 3.], &[3., 2., 1.], &[1u8, 0, 0])?;
+
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-81
+ test(&[1., 2., 3.], 2., &[1u8, 0, 0])?;
+
+ fn test(a: impl NdArray, b: impl NdArray, expected: impl NdArray) -> Result<()> {
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "Less".to_string(),
+ domain: "".to_string(),
+ attribute: vec![],
+ input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
+ output: vec![OUTPUT_Z.to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ name: "".to_string(),
+ initializer: vec![],
+ input: vec![],
+ output: vec![ValueInfoProto {
+ name: OUTPUT_Z.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ value_info: vec![],
+ doc_string: "".to_string(),
+ sparse_initializer: vec![],
+ quantization_annotation: vec![],
+ }));
+
+ let mut inputs: HashMap<String, Tensor> = HashMap::new();
+ inputs.insert(INPUT_X.to_string(), Tensor::new(a, &Device::Cpu)?);
+ inputs.insert(INPUT_Y.to_string(), Tensor::new(b, &Device::Cpu)?);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let z = eval
+ .get(OUTPUT_Z)
+ .expect("Output 'z' not found")
+ .to_dtype(DType::F64)?;
+
+ let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?;
+ match expected.dims().len() {
+ 0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
+ 1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
+ 2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
+ 3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
+ _ => unreachable!(),
+ };
+
+ Ok(())
+ }
+
+ Ok(())
+}
+
+// "Log"
+#[test]
+fn test_log() -> Result<()> {
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-82
+ test(&[1., 10.], &[0., std::f64::consts::LN_10])?;
+
+ fn test(data: impl NdArray, expected: impl NdArray) -> Result<()> {
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "Log".to_string(),
+ domain: "".to_string(),
+ attribute: vec![],
+ input: vec![INPUT_X.to_string()],
+ output: vec![OUTPUT_Z.to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ name: "".to_string(),
+ initializer: vec![],
+ input: vec![],
+ output: vec![ValueInfoProto {
+ name: OUTPUT_Z.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ value_info: vec![],
+ doc_string: "".to_string(),
+ sparse_initializer: vec![],
+ quantization_annotation: vec![],
+ }));
+
+ let mut inputs: HashMap<String, Tensor> = HashMap::new();
+ inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+
+ let expected = Tensor::new(expected, &Device::Cpu)?;
+ match expected.dims().len() {
+ 0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
+ 1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
+ 2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
+ 3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
+ _ => unreachable!(),
+ };
+
+ Ok(())
+ }
+
+ Ok(())
+}
+
+// "Min"
+#[test]
+fn test_min() -> Result<()> {
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-94
+ test(&[3., 2., 1.], &[1., 4., 4.], &[2., 5., 0.], &[1., 2., 0.])?;
+
+ fn test(
+ a: impl NdArray,
+ b: impl NdArray,
+ c: impl NdArray,
+ expected: impl NdArray,
+ ) -> Result<()> {
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "Min".to_string(),
+ domain: "".to_string(),
+ attribute: vec![],
+ input: vec![
+ INPUT_X.to_string(),
+ INPUT_Y.to_string(),
+ INPUT_A.to_string(),
+ ],
+ output: vec![OUTPUT_Z.to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ name: "".to_string(),
+ initializer: vec![],
+ input: vec![],
+ output: vec![ValueInfoProto {
+ name: OUTPUT_Z.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ value_info: vec![],
+ doc_string: "".to_string(),
+ sparse_initializer: vec![],
+ quantization_annotation: vec![],
+ }));
+
+ let mut inputs: HashMap<String, Tensor> = HashMap::new();
+ inputs.insert(INPUT_X.to_string(), Tensor::new(a, &Device::Cpu)?);
+ inputs.insert(INPUT_Y.to_string(), Tensor::new(b, &Device::Cpu)?);
+ inputs.insert(INPUT_A.to_string(), Tensor::new(c, &Device::Cpu)?);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+
+ let expected = Tensor::new(expected, &Device::Cpu)?;
+ match expected.dims().len() {
+ 0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
+ 1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
+ 2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
+ 3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
+ _ => unreachable!(),
+ };
+
+ Ok(())
+ }
+
+ Ok(())
+}
+
+// "Where"
+#[test]
+fn test_where() -> Result<()> {
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-173
+ test(
+ &[[1u8, 0], [1, 1]],
+ &[[1i64, 2], [3, 4]],
+ &[[9i64, 8], [7, 6]],
+ &[[1i64, 8], [3, 4]],
+ )?;
+
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-173
+ test(
+ &[[1u8, 0], [1, 1]],
+ &[[1., 2.], [3., 4.]],
+ &[[9., 8.], [7., 6.]],
+ &[[1., 8.], [3., 4.]],
+ )?;
+
+ fn test(
+ condition: impl NdArray,
+ x: impl NdArray,
+ y: impl NdArray,
+ expected: impl NdArray,
+ ) -> Result<()> {
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "Where".to_string(),
+ domain: "".to_string(),
+ attribute: vec![],
+ input: vec![
+ INPUT_X.to_string(),
+ INPUT_Y.to_string(),
+ INPUT_A.to_string(),
+ ],
+ output: vec![OUTPUT_Z.to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ name: "".to_string(),
+ initializer: vec![],
+ input: vec![],
+ output: vec![ValueInfoProto {
+ name: OUTPUT_Z.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ value_info: vec![],
+ doc_string: "".to_string(),
+ sparse_initializer: vec![],
+ quantization_annotation: vec![],
+ }));
+
+ let mut inputs: HashMap<String, Tensor> = HashMap::new();
+ inputs.insert(INPUT_X.to_string(), Tensor::new(condition, &Device::Cpu)?);
+ inputs.insert(INPUT_Y.to_string(), Tensor::new(x, &Device::Cpu)?);
+ inputs.insert(INPUT_A.to_string(), Tensor::new(y, &Device::Cpu)?);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let z = eval
+ .get(OUTPUT_Z)
+ .expect("Output 'z' not found")
+ .to_dtype(DType::F64)?;
+
+ let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?;
+ match expected.dims().len() {
+ 0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
+ 1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
+ 2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
+ 3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
+ _ => unreachable!(),
+ };
+
+ Ok(())
+ }
+
+ Ok(())
+}