diff options
author | shua <itis@isthisa.email> | 2024-06-06 22:36:23 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-06 22:36:23 +0200 |
commit | b9fac7ec008bfccf8900552f51e6d0e865280ee9 (patch) | |
tree | dca75fc04462bbf1c990a93c2170a40cd92eec62 | |
parent | f65e90e7efd611c582c4e7dd7f7b387b74c61111 (diff) | |
download | candle-b9fac7ec008bfccf8900552f51e6d0e865280ee9.tar.gz candle-b9fac7ec008bfccf8900552f51e6d0e865280ee9.tar.bz2 candle-b9fac7ec008bfccf8900552f51e6d0e865280ee9.zip |
implement if, and pad reflect mode (#2251)
* implement if, and pad reflect mode
The intent of this change is to allow evaluation of the current silero_vad.onnx (v4).
This ONNX file uses 'If' and 'Pad' nodes, which had not been supported
by simple_eval until now.
* Cleanup (fmt, clippy, minor test tweaks).
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
-rw-r--r-- | candle-onnx/src/eval.rs | 101 | ||||
-rw-r--r-- | candle-onnx/tests/ops.rs | 224 |
2 files changed, 271 insertions, 54 deletions
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index e72002e6..f52e6c5c 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -1,6 +1,6 @@ -use crate::onnx; use crate::onnx::attribute_proto::AttributeType; use crate::onnx::tensor_proto::DataType; +use crate::onnx::{self, GraphProto}; use candle::{bail, DType, Device, Result, Tensor}; use std::{collections::HashMap, usize}; @@ -56,6 +56,15 @@ impl Attr for str { } } +impl Attr for GraphProto { + const TYPE: AttributeType = AttributeType::Graph; + fn get(attr: &onnx::AttributeProto) -> Result<&Self> { + attr.g + .as_ref() + .ok_or_else(|| candle::Error::Msg("attribute does not contain graph".to_string())) + } +} + impl AttrOwned for Tensor { const TYPE: AttributeType = AttributeType::Tensor; fn get(attr: &onnx::AttributeProto) -> Result<Self> { @@ -214,13 +223,19 @@ pub fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> { // anymore. pub fn simple_eval( model: &onnx::ModelProto, - inputs: HashMap<String, Value>, + mut inputs: HashMap<String, Value>, ) -> Result<HashMap<String, Value>> { let graph = match &model.graph { None => bail!("no graph defined in proto"), Some(graph) => graph, }; - let mut values = inputs; + simple_eval_(graph, &mut inputs) +} + +fn simple_eval_( + graph: &onnx::GraphProto, + values: &mut HashMap<String, Value>, +) -> Result<HashMap<String, Value>> { for t in graph.initializer.iter() { let tensor = get_tensor(t, t.name.as_str())?; values.insert(t.name.to_string(), tensor); @@ -958,6 +973,86 @@ pub fn simple_eval( let input = get(&node.input[0])?; values.insert(node.output[0].clone(), input.clone()); } + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#if + "If" => { + // protobuf encodes boolean false as 0 and true as 1 + let cond = get(&node.input[0])?.get(0)?.to_scalar::<u8>()?; + let attr_name = if cond != 0 { + "then_branch" + } else { + "else_branch" + }; + let sub_graph = get_attr::<GraphProto>(node, attr_name)?; + if 
sub_graph.output.len() != node.output.len() { + bail!( + "If node {:?} is malformed: branch outputs ({}) don't match node outputs ({})", + node.name, + sub_graph.output.len(), + node.output.len() + ); + } + let branch_out = simple_eval_(sub_graph, values)?; + for (i, out) in node.output.iter().enumerate() { + values.insert( + out.clone(), + branch_out.get(&sub_graph.output[i].name).unwrap().clone(), + ); + } + } + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#pad + "Pad" => { + let mode = get_attr_opt(node, "mode")?.unwrap_or("constant"); + let data = get(&node.input[0])?; + let pads = get(&node.input[1])?; + if node.input.len() > 2 { + bail!( + "unsupported number of inputs {} for Pad node {:?}, expected 2", + node.input.len(), + node.name + ); + } + if pads.rank() != 1 { + bail!("Pad expects 'pads' input to be 1D vector: {pads:?}"); + } + if pads.dim(0).unwrap() != 2 * data.rank() { + bail!("Pad expects 'pads' input len to be 2 * rank of 'data' input: pads: {}, data rank: {}", pads, data.rank()); + } + + let pads = pads.to_vec1::<i64>()?; + let (pads_pre, pads_post) = pads.split_at(pads.len() / 2); + + match mode { + "reflect" => { + let mut out = data.clone(); + for (i, &dim) in data.dims().iter().enumerate().rev() { + if pads_pre[i] == 0 && pads_post[i] == 0 { + continue; + } + fn zigzag(min: i64, max: i64) -> impl Iterator<Item = i64> { + std::iter::repeat((min..max).chain((min + 1..=max).rev())).flatten() + } + let idx = if dim > 1 { + let cycle_len = dim * 2 - 1; + let skip = (pads_pre[i] as usize) % cycle_len; + let idx = zigzag(0, (dim - 1) as i64) + .skip(skip) + .take((pads_pre[i] as usize) + dim + (pads_post[i] as usize)); + Tensor::from_iter(idx, out.device())? + } else { + Tensor::full(0i64, (dim,), out.device())? 
+ }; + + out = out.index_select(&idx, i)?; + } + + values.insert(node.output[0].clone(), out); + } + _ => bail!( + "unsupported 'mode' value {mode:?} for Pad node {:?}", + node.name + ), + } + } // https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13 // TODO: This version is only compatible with ReduceMean V13 and below. "ReduceMean" => { diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs index 2e60d22c..b4299af1 100644 --- a/candle-onnx/tests/ops.rs +++ b/candle-onnx/tests/ops.rs @@ -4,10 +4,12 @@ extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; +use candle::test_utils::to_vec2_round; use candle::{DType, Device, NdArray, Result, Tensor}; -use candle_onnx::onnx; use candle_onnx::onnx::attribute_proto::AttributeType; use candle_onnx::onnx::tensor_proto::DataType; +use candle_onnx::onnx::tensor_shape_proto::{dimension, Dimension}; +use candle_onnx::onnx::{type_proto, TensorProto, TensorShapeProto, TypeProto}; use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto}; use std::collections::HashMap; @@ -35,14 +37,11 @@ fn create_model_proto_with_graph(graph: Option<GraphProto>) -> ModelProto { #[test] fn test_evaluation_fails_without_defined_graph() -> Result<()> { let manual_graph = create_model_proto_with_graph(None); - let inputs: HashMap<String, Tensor> = HashMap::new(); - match candle_onnx::simple_eval(&manual_graph, inputs) { Err(err) => assert_eq!(err.to_string(), "no graph defined in proto"), Ok(_) => panic!("Expected an error due to undefined graph"), } - Ok(()) } @@ -81,14 +80,8 @@ fn test_add_operation() -> Result<()> { assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); - let first = z - .to_vec1::<f64>()? 
- .to_vec() - .get(0) - .expect("Failed to get first element") - .clone(); + let first = z.to_vec1::<f64>()?[0]; assert_eq!(first, 4.0f64); - Ok(()) } @@ -127,14 +120,8 @@ fn test_sub_operation() -> Result<()> { assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); - let first = z - .to_vec1::<f64>()? - .to_vec() - .get(0) - .expect("Failed to get first element") - .clone(); + let first = z.to_vec1::<f64>()?[0]; assert_eq!(first, 0.0f64); - Ok(()) } @@ -173,14 +160,8 @@ fn test_mul_operation() -> Result<()> { assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); - let first = z - .to_vec1::<f64>()? - .to_vec() - .get(0) - .expect("Failed to get first element") - .clone(); + let first = z.to_vec1::<f64>()?[0]; assert_eq!(first, 4.0f64); - Ok(()) } @@ -219,15 +200,8 @@ fn test_div_operation() -> Result<()> { assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); - let first = z - .to_vec1::<f64>()? 
- .to_vec() - .get(0) - .expect("Failed to get first element") - .clone(); - + let first = z.to_vec1::<f64>()?[0]; assert_eq!(first, 1.0f64); - Ok(()) } @@ -272,7 +246,7 @@ fn test_exp_operation() -> Result<()> { assert_eq!(results[0][0], 0.36787944f32); assert_eq!(results[0][1], 1.0f32); - assert_eq!(results[1], vec![std::f32::consts::E, 7.38905609f32]); + assert_eq!(results[1], vec![std::f32::consts::E, 7.389056f32]); Ok(()) } @@ -914,7 +888,7 @@ fn test_constant_of_shape() -> Result<()> { ), _ => panic!("unsupported DType in test"), }; - let tensor = onnx::TensorProto { + let tensor = TensorProto { data_type: data_type.into(), dims: tensor.dims().iter().map(|v| *v as i64).collect(), raw_data: value, @@ -1293,14 +1267,7 @@ fn test_cos_operation() -> Result<()> { assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); - - let results = z.to_vec2::<f32>()?; - - assert_eq!( - results, - vec![vec![1.0, 0.54030234], vec![-0.41614684, -0.9899925]] - ); - + assert_eq!(to_vec2_round(z, 4)?, [[1.0, 0.5403], [-0.4161, -0.99]]); Ok(()) } @@ -1342,19 +1309,12 @@ fn test_sin_operation() -> Result<()> { quantization_annotation: vec![], })); let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?; - let mut inputs: HashMap<String, Tensor> = HashMap::new(); inputs.insert(INPUT_X.to_string(), x); - let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); - let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); - - let results = z.to_vec2::<f32>()?; - - assert_eq!(results, vec![vec![0.0, 0.841471], vec![0.9092974, 0.14112]]); - + assert_eq!(to_vec2_round(z, 4)?, [[0.0, 0.8415], [0.9093, 0.1411]]); Ok(()) } @@ -3150,3 +3110,165 @@ fn test_leakyrelu() -> Result<()> { Ok(()) } + +// "If" +#[test] +fn test_if() -> Result<()> { + let x = vec![1.0, 2.0, 3.0, 4.0, 5.0]; + let y = vec![5.0, 4.0, 3.0, 2.0, 1.0]; + let output_type_proto = Some(TypeProto { + value: 
Some(type_proto::Value::TensorType(type_proto::Tensor { + elem_type: DataType::Float.into(), + shape: Some(TensorShapeProto { + dim: vec![Dimension { + denotation: "".to_string(), + value: Some(dimension::Value::DimValue(5)), + }], + }), + })), + denotation: "".to_string(), + }); + let then_branch = GraphProto { + output: vec![ValueInfoProto { + name: "then_out".to_string(), + r#type: output_type_proto.clone(), + doc_string: "".to_string(), + }], + node: vec![NodeProto { + op_type: "Constant".to_string(), + input: vec![], + output: vec!["then_out".to_string()], + attribute: vec![AttributeProto { + name: "value".to_string(), + r#type: AttributeType::Tensor.into(), + t: Some(TensorProto { + dims: vec![x.len() as i64], + float_data: x.clone(), + data_type: DataType::Float.into(), + ..TensorProto::default() + }), + ..AttributeProto::default() + }], + ..NodeProto::default() + }], + ..GraphProto::default() + }; + let else_branch = GraphProto { + output: vec![ValueInfoProto { + name: "else_out".to_string(), + r#type: output_type_proto.clone(), + doc_string: "".to_string(), + }], + node: vec![NodeProto { + op_type: "Constant".to_string(), + input: vec![], + output: vec!["else_out".to_string()], + attribute: vec![AttributeProto { + name: "value".to_string(), + r#type: AttributeType::Tensor.into(), + t: Some(TensorProto { + dims: vec![y.len() as i64], + float_data: y.clone(), + data_type: DataType::Float.into(), + ..TensorProto::default() + }), + ..AttributeProto::default() + }], + ..NodeProto::default() + }], + ..GraphProto::default() + }; + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "If".to_string(), + attribute: vec![ + AttributeProto { + name: "then_branch".to_string(), + r#type: AttributeType::Graph.into(), + g: Some(then_branch), + ..AttributeProto::default() + }, + AttributeProto { + name: "else_branch".to_string(), + r#type: AttributeType::Graph.into(), + g: Some(else_branch), + ..AttributeProto::default() + 
}, + ], + input: vec!["cond".to_string()], + output: vec!["res".to_string()], + ..NodeProto::default() + }], + input: vec![], + output: vec![ValueInfoProto { + name: "res".to_string(), + doc_string: "".to_string(), + r#type: output_type_proto.clone(), + }], + ..GraphProto::default() + })); + + for cond in [1u8, 0] { + let inputs = + HashMap::from_iter([("cond".to_string(), Tensor::full(cond, (1,), &Device::Cpu)?)]); + let outputs = candle_onnx::simple_eval(&manual_graph, inputs)?; + let expected = if cond != 0 { &x } else { &y }; + let Some(res) = outputs.get("res") else { + candle::bail!("outputs didn't contain expected key `res`: {outputs:?}"); + }; + assert_eq!(&res.to_vec1::<f32>()?, expected); + } + Ok(()) +} + +#[test] +fn test_pad() -> Result<()> { + let data = Tensor::from_vec(vec![1.0, 1.2, 2.3, 3.4, 4.5, 5.7], (3, 2), &Device::Cpu)?; + let pads = Tensor::from_vec(vec![0i64, 2, 0, 0], (4,), &Device::Cpu)?; + let mode = "reflect"; + + let expected = Tensor::from_vec( + vec![1.0, 1.2, 1.0, 1.2, 2.3, 3.4, 2.3, 3.4, 4.5, 5.7, 4.5, 5.7], + (3, 4), + &Device::Cpu, + )?; + + let model = create_model_proto_with_graph(Some(GraphProto { + input: vec![ + ValueInfoProto { + name: "data".to_string(), + ..ValueInfoProto::default() + }, + ValueInfoProto { + name: "pads".to_string(), + ..ValueInfoProto::default() + }, + ], + output: vec![ValueInfoProto { + name: "output".to_string(), + ..ValueInfoProto::default() + }], + node: vec![NodeProto { + op_type: "Pad".to_string(), + input: vec!["data".to_string(), "pads".to_string()], + output: vec!["output".to_string()], + attribute: vec![AttributeProto { + name: "mode".to_string(), + r#type: AttributeType::String.into(), + s: mode.as_bytes().to_vec(), + ..AttributeProto::default() + }], + ..NodeProto::default() + }], + ..GraphProto::default() + })); + + let inputs = HashMap::from_iter([("data".to_string(), data), ("pads".to_string(), pads)]); + let res = candle_onnx::simple_eval(&model, inputs)?; + let Some(actual) = 
res.get("output") else { + candle::bail!("outputs didn't contain expected key `output`: {res:?}"); + }; + + assert_eq!(actual.to_vec2::<f64>()?, expected.to_vec2::<f64>()?); + Ok(()) +} |