diff options
author | Gabriel <45515538+gabotechs@users.noreply.github.com> | 2024-04-20 18:44:22 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-20 18:44:22 +0200 |
commit | 9215e9ce8c3fbe2e2850065557fc7e37b8e1c948 (patch) | |
tree | 918cb90a70adc90fc45c1b3f6d69128a543071ea /candle-onnx | |
parent | 52ae33291060bb57ea2b7913179747040eed02b9 (diff) | |
download | candle-9215e9ce8c3fbe2e2850065557fc7e37b8e1c948.tar.gz candle-9215e9ce8c3fbe2e2850065557fc7e37b8e1c948.tar.bz2 candle-9215e9ce8c3fbe2e2850065557fc7e37b8e1c948.zip |
Add missing onnx operations (#2096)
* Add missing onnx operations
* Add tests and fix errors
* Run rustfmt
Diffstat (limited to 'candle-onnx')
-rw-r--r-- | candle-onnx/src/eval.rs | 160 | ||||
-rw-r--r-- | candle-onnx/tests/ops.rs | 585 |
2 files changed, 736 insertions, 9 deletions
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 75927822..417216d7 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -23,6 +23,11 @@ trait Attr { fn get(attr: &onnx::AttributeProto) -> Result<&Self>; } +trait AttrOwned: Sized { + const TYPE: AttributeType; + fn get(attr: &onnx::AttributeProto) -> Result<Self>; +} + impl Attr for i64 { const TYPE: AttributeType = AttributeType::Int; fn get(attr: &onnx::AttributeProto) -> Result<&Self> { @@ -51,6 +56,50 @@ impl Attr for str { } } +impl AttrOwned for Tensor { + const TYPE: AttributeType = AttributeType::Tensor; + fn get(attr: &onnx::AttributeProto) -> Result<Self> { + let tensor_proto = match &attr.t { + Some(value) => value, + None => bail!( + "attribute {} was of type TENSOR, but no tensor was found", + attr.name + ), + }; + + let data_type = match DataType::try_from(tensor_proto.data_type) { + Ok(value) => value, + Err(_) => bail!( + "attribute {} of type TENSOR was an invalid data_type number {}", + attr.name, + tensor_proto.data_type + ), + }; + + let dtype = match dtype(data_type) { + Some(value) => value, + None => bail!( + "attribute {} of type TENSOR has an unsupported data_type {}", + attr.name, + data_type.as_str_name() + ), + }; + + let mut dims = Vec::with_capacity(tensor_proto.dims.len()); + for dim in &tensor_proto.dims { + if dim < &0 { + bail!( + "attribute {} of type TENSOR has a negative dimension, which is unsupported", + attr.name + ) + } + dims.push(*dim as usize) + } + + Tensor::from_raw_buffer(&tensor_proto.raw_data, dtype, &dims, &Device::Cpu) + } +} + fn get_attr_<'a>(node: &'a onnx::NodeProto, name: &str) -> Result<&'a onnx::AttributeProto> { match node.attribute.iter().find(|attr| attr.name == name) { None => { @@ -98,6 +147,24 @@ fn get_attr_opt<'a, T: Attr + ?Sized>( } } +fn get_attr_opt_owned<T: AttrOwned>(node: &onnx::NodeProto, name: &str) -> Result<Option<T>> { + match node.attribute.iter().find(|attr| attr.name == name) { + None => Ok(None), + Some(attr) => { + if attr.r#type() != T::TYPE { + bail!( + "unsupported type {:?} for '{name}' attribute in '{}' for {}", + attr.r#type, + node.op_type, + node.name + ) + } + let val = T::get(attr)?; + Ok(Some(val)) + } + } +} + pub fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> { let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect(); match DataType::try_from(t.data_type) { @@ -458,14 +525,17 @@ pub fn simple_eval( } values.insert(node.output[0].clone(), xs); } + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConstantOfShape "ConstantOfShape" => { - let dims = get(&node.input[0])?; - let shape = dims - .to_vec1::<i64>()? - .into_iter() - .map(|v| v as usize) - .collect::<Vec<_>>(); - let xs = Tensor::zeros(shape, DType::F32, dims.device())?; + let input = get(&node.input[0])?; + let value = get_attr_opt_owned::<Tensor>(node, "value")?.unwrap_or(Tensor::zeros( + (), + DType::F32, + &Device::Cpu, + )?); + + let xs = Tensor::ones(input.shape(), value.dtype(), input.device())? + .broadcast_mul(&value)?; values.insert(node.output[0].clone(), xs); } "Unsqueeze" => { @@ -552,6 +622,82 @@ pub fn simple_eval( let dims = Tensor::from_vec(dims, xs.rank(), xs.device())?; values.insert(node.output[0].clone(), dims); } + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sqrt + "Sqrt" => { + let xs = get(&node.input[0])?; + let output = xs.sqrt()?; + values.insert(node.output[0].clone(), output); + } + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Range + "Range" => { + let start = get(&node.input[0])?; + let limit = get(&node.input[1])?; + let delta = get(&node.input[2])?; + + macro_rules! arange_step { + ($t: ty) => { + Tensor::arange_step( + start.to_vec0::<$t>()?, + limit.to_vec0::<$t>()?, + delta.to_vec0::<$t>()?, + &Device::Cpu, + )? + }; + } + + let output = match start.dtype() { + DType::U8 => arange_step!(u8), + DType::U32 => arange_step!(u32), + DType::I64 => arange_step!(i64), + DType::BF16 => arange_step!(f32), + DType::F16 => arange_step!(f32), + DType::F32 => arange_step!(f32), + DType::F64 => arange_step!(f64), + }; + + values.insert(node.output[0].clone(), output); + } + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Greater + "Greater" => { + let a = get(&node.input[0])?; + let b = get(&node.input[1])?; + + let output = a.broadcast_gt(b)?; + values.insert(node.output[0].clone(), output); + } + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Less + "Less" => { + let a = get(&node.input[0])?; + let b = get(&node.input[1])?; + + let output = a.broadcast_lt(b)?; + values.insert(node.output[0].clone(), output); + } + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Log + "Log" => { + let a = get(&node.input[0])?; + + let output = a.log()?; + values.insert(node.output[0].clone(), output); + } + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Min + "Min" => { + let mut output = get(&node.input[0])?.clone(); + for input in node.input.iter() { + let input = get(input)?; + output = output.broadcast_minimum(input)? + } + + values.insert(node.output[0].clone(), output); + } + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Where + "Where" => { + let cond = get(&node.input[0])?; + let a = get(&node.input[1])?; + let b = get(&node.input[2])?; + let output = cond.where_cond(a, b)?; + values.insert(node.output[0].clone(), output); + } "Conv" => { // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv let dilations = get_attr_opt::<[i64]>(node, "dilations")?; diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs index fda76ec2..9b18170a 100644 --- a/candle-onnx/tests/ops.rs +++ b/candle-onnx/tests/ops.rs @@ -4,12 +4,16 @@ extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; -use candle::{Device, NdArray, Result, Tensor}; +use candle::{DType, Device, NdArray, Result, Tensor}; +use candle_onnx::onnx; +use candle_onnx::onnx::attribute_proto::AttributeType; +use candle_onnx::onnx::tensor_proto::DataType; use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto}; use std::collections::HashMap; const INPUT_X: &str = "x"; const INPUT_Y: &str = "y"; +const INPUT_A: &str = "a"; const OUTPUT_Z: &str = "z"; fn create_model_proto_with_graph(graph: Option<GraphProto>) -> ModelProto { @@ -820,7 +824,137 @@ fn test_flatten_operation() -> Result<()> { // #[test] // "ConstantOfShape" -// #[test] +#[test] +fn test_constant_of_shape() -> Result<()> { + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31 + test(&[4i64, 3, 2], Some(1.), &[1., 1., 1.])?; + + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31 + test(&[0.], Some(0i64), &[0i64])?; + + // "value" defaults to 0 f32 + test(&[1i64, 2, 3, 4], None as Option<i64>, &[0., 0., 0., 0.])?; + + fn test( + input: impl NdArray, + value: Option<impl NdArray>, + expected: impl NdArray, + ) -> Result<()> { + let mut attribute = vec![]; + + if let Some(value) = value { + let tensor = Tensor::new(value, &Device::Cpu)?; + + let (value, data_type) = match tensor.dtype() { + DType::U8 => ( + tensor.to_vec0::<u8>()?.to_le_bytes().to_vec(), + DataType::Uint8, + ), + DType::U32 => ( + tensor.to_vec0::<u32>()?.to_le_bytes().to_vec(), + DataType::Uint32, + ), + DType::I64 => ( + tensor.to_vec0::<i64>()?.to_le_bytes().to_vec(), + DataType::Int64, + ), + DType::F32 => ( + tensor.to_vec0::<f32>()?.to_le_bytes().to_vec(), + DataType::Float, + ), + DType::F64 => ( + tensor.to_vec0::<f64>()?.to_le_bytes().to_vec(), + DataType::Double, + ), + _ => panic!("unsupported DType in test"), + }; + let tensor = onnx::TensorProto { + data_type: data_type.into(), + dims: tensor.dims().iter().map(|v| *v as i64).collect(), + raw_data: value, + segment: None, + float_data: vec![], + int32_data: vec![], + string_data: vec![], + int64_data: vec![], + name: "".to_string(), + doc_string: "".to_string(), + external_data: vec![], + data_location: 0, + double_data: vec![], + uint64_data: vec![], + }; + + attribute.push(AttributeProto { + name: "value".to_string(), + ref_attr_name: "value".to_string(), + i: 0, + doc_string: "value".to_string(), + r#type: AttributeType::Tensor.into(), + f: 0.0, + s: vec![], + t: Some(tensor), + g: None, + sparse_tensor: None, + tp: None, + floats: vec![], + ints: vec![], + strings: vec![], + tensors: vec![], + graphs: vec![], + sparse_tensors: vec![], + type_protos: vec![], + }) + } + + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "ConstantOfShape".to_string(), + domain: "".to_string(), + attribute, + input: vec![INPUT_X.to_string()], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + + let mut inputs: HashMap<String, Tensor> = HashMap::new(); + inputs.insert(INPUT_X.to_string(), Tensor::new(input, &Device::Cpu)?); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let z = eval + .get(OUTPUT_Z) + .expect("Output 'z' not found") + .to_dtype(DType::F64)?; + + let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?; + match expected.dims().len() { + 0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?), + 1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?), + 2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?), + 3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?), + _ => unreachable!(), + }; + + Ok(()) + } + Ok(()) +} // "Unsqueeze" // #[test] @@ -1639,3 +1773,450 @@ fn test_reduce_mean() -> Result<()> { Ok(()) } + +// "Sqrt" +#[test] +fn test_sqrt() -> Result<()> { + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-155 + test(&[1., 4., 9.], &[1., 2., 3.])?; + + fn test(data: impl NdArray, expected: impl NdArray) -> Result<()> { + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Sqrt".to_string(), + domain: "".to_string(), + attribute: vec![], + input: vec![INPUT_X.to_string()], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + + let mut inputs: HashMap<String, Tensor> = HashMap::new(); + inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); + + let expected = Tensor::new(expected, &Device::Cpu)?; + match expected.dims().len() { + 0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?), + 1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?), + 2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?), + 3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?), + _ => unreachable!(), + }; + + Ok(()) + } + + Ok(()) +} + +// "Range" +#[test] +fn test_range() -> Result<()> { + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-113 + test(1., 5., 2., &[1., 3.])?; + + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-113 + test(10i64, 6i64, -3i64, &[10i64, 7i64])?; + + fn test( + start: impl NdArray, + limit: impl NdArray, + delta: impl NdArray, + expected: impl NdArray, + ) -> Result<()> { + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Range".to_string(), + domain: "".to_string(), + attribute: vec![], + input: vec![ + INPUT_X.to_string(), + INPUT_Y.to_string(), + INPUT_A.to_string(), + ], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + + let mut inputs: HashMap<String, Tensor> = HashMap::new(); + inputs.insert(INPUT_X.to_string(), Tensor::new(start, &Device::Cpu)?); + inputs.insert(INPUT_Y.to_string(), Tensor::new(limit, &Device::Cpu)?); + inputs.insert(INPUT_A.to_string(), Tensor::new(delta, &Device::Cpu)?); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let z = eval + .get(OUTPUT_Z) + .expect("Output 'z' not found") + .to_dtype(DType::F64)?; + + let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?; + match expected.dims().len() { + 0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?), + 1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?), + 2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?), + 3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?), + _ => unreachable!(), + }; + + Ok(()) + } + + Ok(()) +} + +// "Greater" +#[test] +fn test_greater() -> Result<()> { + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-63 + test(&[1., 2., 3.], &[3., 2., 1.], &[0u8, 0, 1])?; + + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-63 + test(&[1., 2., 3.], 2., &[0u8, 0, 1])?; + + fn test(a: impl NdArray, b: impl NdArray, expected: impl NdArray) -> Result<()> { + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Greater".to_string(), + domain: "".to_string(), + attribute: vec![], + input: vec![INPUT_X.to_string(), INPUT_Y.to_string()], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + + let mut inputs: HashMap<String, Tensor> = HashMap::new(); + inputs.insert(INPUT_X.to_string(), Tensor::new(a, &Device::Cpu)?); + inputs.insert(INPUT_Y.to_string(), Tensor::new(b, &Device::Cpu)?); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let z = eval + .get(OUTPUT_Z) + .expect("Output 'z' not found") + .to_dtype(DType::F64)?; + + let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?; + match expected.dims().len() { + 0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?), + 1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?), + 2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?), + 3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?), + _ => unreachable!(), + }; + + Ok(()) + } + + Ok(()) +} + +// "Less" +#[test] +fn test_less() -> Result<()> { + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-81 + test(&[1., 2., 3.], &[3., 2., 1.], &[1u8, 0, 0])?; + + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-81 + test(&[1., 2., 3.], 2., &[1u8, 0, 0])?; + + fn test(a: impl NdArray, b: impl NdArray, expected: impl NdArray) -> Result<()> { + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Less".to_string(), + domain: "".to_string(), + attribute: vec![], + input: vec![INPUT_X.to_string(), INPUT_Y.to_string()], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + + let mut inputs: HashMap<String, Tensor> = HashMap::new(); + inputs.insert(INPUT_X.to_string(), Tensor::new(a, &Device::Cpu)?); + inputs.insert(INPUT_Y.to_string(), Tensor::new(b, &Device::Cpu)?); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let z = eval + .get(OUTPUT_Z) + .expect("Output 'z' not found") + .to_dtype(DType::F64)?; + + let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?; + match expected.dims().len() { + 0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?), + 1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?), + 2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?), + 3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?), + _ => unreachable!(), + }; + + Ok(()) + } + + Ok(()) +} + +// "Log" +#[test] +fn test_log() -> Result<()> { + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-82 + test(&[1., 10.], &[0., std::f64::consts::LN_10])?; + + fn test(data: impl NdArray, expected: impl NdArray) -> Result<()> { + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Log".to_string(), + domain: "".to_string(), + attribute: vec![], + input: vec![INPUT_X.to_string()], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + + let mut inputs: HashMap<String, Tensor> = HashMap::new(); + inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); + + let expected = Tensor::new(expected, &Device::Cpu)?; + match expected.dims().len() { + 0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?), + 1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?), + 2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?), + 3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?), + _ => unreachable!(), + }; + + Ok(()) + } + + Ok(()) +} + +// "Min" +#[test] +fn test_min() -> Result<()> { + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-94 + test(&[3., 2., 1.], &[1., 4., 4.], &[2., 5., 0.], &[1., 2., 0.])?; + + fn test( + a: impl NdArray, + b: impl NdArray, + c: impl NdArray, + expected: impl NdArray, + ) -> Result<()> { + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Min".to_string(), + domain: "".to_string(), + attribute: vec![], + input: vec![ + INPUT_X.to_string(), + INPUT_Y.to_string(), + INPUT_A.to_string(), + ], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + + let mut inputs: HashMap<String, Tensor> = HashMap::new(); + inputs.insert(INPUT_X.to_string(), Tensor::new(a, &Device::Cpu)?); + inputs.insert(INPUT_Y.to_string(), Tensor::new(b, &Device::Cpu)?); + inputs.insert(INPUT_A.to_string(), Tensor::new(c, &Device::Cpu)?); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); + + let expected = Tensor::new(expected, &Device::Cpu)?; + match expected.dims().len() { + 0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?), + 1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?), + 2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?), + 3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?), + _ => unreachable!(), + }; + + Ok(()) + } + + Ok(()) +} + +// "Where" +#[test] +fn test_where() -> Result<()> { + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-173 + test( + &[[1u8, 0], [1, 1]], + &[[1i64, 2], [3, 4]], + &[[9i64, 8], [7, 6]], + &[[1i64, 8], [3, 4]], + )?; + + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-173 + test( + &[[1u8, 0], [1, 1]], + &[[1., 2.], [3., 4.]], + &[[9., 8.], [7., 6.]], + &[[1., 8.], [3., 4.]], + )?; + + fn test( + condition: impl NdArray, + x: impl NdArray, + y: impl NdArray, + expected: impl NdArray, + ) -> Result<()> { + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Where".to_string(), + domain: "".to_string(), + attribute: vec![], + input: vec![ + INPUT_X.to_string(), + INPUT_Y.to_string(), + INPUT_A.to_string(), + ], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + + let mut inputs: HashMap<String, Tensor> = HashMap::new(); + inputs.insert(INPUT_X.to_string(), Tensor::new(condition, &Device::Cpu)?); + inputs.insert(INPUT_Y.to_string(), Tensor::new(x, &Device::Cpu)?); + inputs.insert(INPUT_A.to_string(), Tensor::new(y, &Device::Cpu)?); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let z = eval + .get(OUTPUT_Z) + .expect("Output 'z' not found") + .to_dtype(DType::F64)?; + + let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?; + match expected.dims().len() { + 0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?), + 1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?), + 2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?), + 3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?), + _ => unreachable!(), + }; + + Ok(()) + } + + Ok(()) +} |