summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorshua <itis@isthisa.email>2024-06-06 22:36:23 +0200
committerGitHub <noreply@github.com>2024-06-06 22:36:23 +0200
commitb9fac7ec008bfccf8900552f51e6d0e865280ee9 (patch)
treedca75fc04462bbf1c990a93c2170a40cd92eec62
parentf65e90e7efd611c582c4e7dd7f7b387b74c61111 (diff)
downloadcandle-b9fac7ec008bfccf8900552f51e6d0e865280ee9.tar.gz
candle-b9fac7ec008bfccf8900552f51e6d0e865280ee9.tar.bz2
candle-b9fac7ec008bfccf8900552f51e6d0e865280ee9.zip
implement if, and pad reflect mode (#2251)
* implement if, and pad reflect mode The intent of this change is to allow eval of the current silero_vad.onnx (v4). This onnx file uses 'If' and 'Pad' nodes, which had not been supported by simple_eval until now * Cleanup (fmt, clippy, minor test tweaks). --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
-rw-r--r--candle-onnx/src/eval.rs101
-rw-r--r--candle-onnx/tests/ops.rs224
2 files changed, 271 insertions, 54 deletions
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index e72002e6..f52e6c5c 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -1,6 +1,6 @@
-use crate::onnx;
use crate::onnx::attribute_proto::AttributeType;
use crate::onnx::tensor_proto::DataType;
+use crate::onnx::{self, GraphProto};
use candle::{bail, DType, Device, Result, Tensor};
use std::{collections::HashMap, usize};
@@ -56,6 +56,15 @@ impl Attr for str {
}
}
+impl Attr for GraphProto {
+ const TYPE: AttributeType = AttributeType::Graph;
+ fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
+ attr.g
+ .as_ref()
+ .ok_or_else(|| candle::Error::Msg("attribute does not contain graph".to_string()))
+ }
+}
+
impl AttrOwned for Tensor {
const TYPE: AttributeType = AttributeType::Tensor;
fn get(attr: &onnx::AttributeProto) -> Result<Self> {
@@ -214,13 +223,19 @@ pub fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
// anymore.
pub fn simple_eval(
model: &onnx::ModelProto,
- inputs: HashMap<String, Value>,
+ mut inputs: HashMap<String, Value>,
) -> Result<HashMap<String, Value>> {
let graph = match &model.graph {
None => bail!("no graph defined in proto"),
Some(graph) => graph,
};
- let mut values = inputs;
+ simple_eval_(graph, &mut inputs)
+}
+
+fn simple_eval_(
+ graph: &onnx::GraphProto,
+ values: &mut HashMap<String, Value>,
+) -> Result<HashMap<String, Value>> {
for t in graph.initializer.iter() {
let tensor = get_tensor(t, t.name.as_str())?;
values.insert(t.name.to_string(), tensor);
@@ -958,6 +973,86 @@ pub fn simple_eval(
let input = get(&node.input[0])?;
values.insert(node.output[0].clone(), input.clone());
}
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#if
+ "If" => {
+ // protobuf encodes boolean false as 0 and true as 1
+ let cond = get(&node.input[0])?.get(0)?.to_scalar::<u8>()?;
+ let attr_name = if cond != 0 {
+ "then_branch"
+ } else {
+ "else_branch"
+ };
+ let sub_graph = get_attr::<GraphProto>(node, attr_name)?;
+ if sub_graph.output.len() != node.output.len() {
+ bail!(
+ "If node {:?} is malformed: branch outputs ({}) don't match node outputs ({})",
+ node.name,
+ sub_graph.output.len(),
+ node.output.len()
+ );
+ }
+ let branch_out = simple_eval_(sub_graph, values)?;
+ for (i, out) in node.output.iter().enumerate() {
+ values.insert(
+ out.clone(),
+ branch_out.get(&sub_graph.output[i].name).unwrap().clone(),
+ );
+ }
+ }
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#pad
+ "Pad" => {
+ let mode = get_attr_opt(node, "mode")?.unwrap_or("constant");
+ let data = get(&node.input[0])?;
+ let pads = get(&node.input[1])?;
+ if node.input.len() > 2 {
+ bail!(
+ "unsupported number of inputs {} for Pad node {:?}, expected 2",
+ node.input.len(),
+ node.name
+ );
+ }
+ if pads.rank() != 1 {
+ bail!("Pad expects 'pads' input to be 1D vector: {pads:?}");
+ }
+ if pads.dim(0).unwrap() != 2 * data.rank() {
+ bail!("Pad expects 'pads' input len to be 2 * rank of 'data' input: pads: {}, data rank: {}", pads, data.rank());
+ }
+
+ let pads = pads.to_vec1::<i64>()?;
+ let (pads_pre, pads_post) = pads.split_at(pads.len() / 2);
+
+ match mode {
+ "reflect" => {
+ let mut out = data.clone();
+ for (i, &dim) in data.dims().iter().enumerate().rev() {
+ if pads_pre[i] == 0 && pads_post[i] == 0 {
+ continue;
+ }
+ fn zigzag(min: i64, max: i64) -> impl Iterator<Item = i64> {
+ std::iter::repeat((min..max).chain((min + 1..=max).rev())).flatten()
+ }
+ let idx = if dim > 1 {
+ let cycle_len = dim * 2 - 1;
+ let skip = (pads_pre[i] as usize) % cycle_len;
+ let idx = zigzag(0, (dim - 1) as i64)
+ .skip(skip)
+ .take((pads_pre[i] as usize) + dim + (pads_post[i] as usize));
+ Tensor::from_iter(idx, out.device())?
+ } else {
+ Tensor::full(0i64, (dim,), out.device())?
+ };
+
+ out = out.index_select(&idx, i)?;
+ }
+
+ values.insert(node.output[0].clone(), out);
+ }
+ _ => bail!(
+ "unsupported 'mode' value {mode:?} for Pad node {:?}",
+ node.name
+ ),
+ }
+ }
// https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13
// TODO: This version is only compatible with ReduceMean V13 and below.
"ReduceMean" => {
diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs
index 2e60d22c..b4299af1 100644
--- a/candle-onnx/tests/ops.rs
+++ b/candle-onnx/tests/ops.rs
@@ -4,10 +4,12 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
+use candle::test_utils::to_vec2_round;
use candle::{DType, Device, NdArray, Result, Tensor};
-use candle_onnx::onnx;
use candle_onnx::onnx::attribute_proto::AttributeType;
use candle_onnx::onnx::tensor_proto::DataType;
+use candle_onnx::onnx::tensor_shape_proto::{dimension, Dimension};
+use candle_onnx::onnx::{type_proto, TensorProto, TensorShapeProto, TypeProto};
use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
use std::collections::HashMap;
@@ -35,14 +37,11 @@ fn create_model_proto_with_graph(graph: Option<GraphProto>) -> ModelProto {
#[test]
fn test_evaluation_fails_without_defined_graph() -> Result<()> {
let manual_graph = create_model_proto_with_graph(None);
-
let inputs: HashMap<String, Tensor> = HashMap::new();
-
match candle_onnx::simple_eval(&manual_graph, inputs) {
Err(err) => assert_eq!(err.to_string(), "no graph defined in proto"),
Ok(_) => panic!("Expected an error due to undefined graph"),
}
-
Ok(())
}
@@ -81,14 +80,8 @@ fn test_add_operation() -> Result<()> {
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
- let first = z
- .to_vec1::<f64>()?
- .to_vec()
- .get(0)
- .expect("Failed to get first element")
- .clone();
+ let first = z.to_vec1::<f64>()?[0];
assert_eq!(first, 4.0f64);
-
Ok(())
}
@@ -127,14 +120,8 @@ fn test_sub_operation() -> Result<()> {
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
- let first = z
- .to_vec1::<f64>()?
- .to_vec()
- .get(0)
- .expect("Failed to get first element")
- .clone();
+ let first = z.to_vec1::<f64>()?[0];
assert_eq!(first, 0.0f64);
-
Ok(())
}
@@ -173,14 +160,8 @@ fn test_mul_operation() -> Result<()> {
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
- let first = z
- .to_vec1::<f64>()?
- .to_vec()
- .get(0)
- .expect("Failed to get first element")
- .clone();
+ let first = z.to_vec1::<f64>()?[0];
assert_eq!(first, 4.0f64);
-
Ok(())
}
@@ -219,15 +200,8 @@ fn test_div_operation() -> Result<()> {
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
- let first = z
- .to_vec1::<f64>()?
- .to_vec()
- .get(0)
- .expect("Failed to get first element")
- .clone();
-
+ let first = z.to_vec1::<f64>()?[0];
assert_eq!(first, 1.0f64);
-
Ok(())
}
@@ -272,7 +246,7 @@ fn test_exp_operation() -> Result<()> {
assert_eq!(results[0][0], 0.36787944f32);
assert_eq!(results[0][1], 1.0f32);
- assert_eq!(results[1], vec![std::f32::consts::E, 7.38905609f32]);
+ assert_eq!(results[1], vec![std::f32::consts::E, 7.389056f32]);
Ok(())
}
@@ -914,7 +888,7 @@ fn test_constant_of_shape() -> Result<()> {
),
_ => panic!("unsupported DType in test"),
};
- let tensor = onnx::TensorProto {
+ let tensor = TensorProto {
data_type: data_type.into(),
dims: tensor.dims().iter().map(|v| *v as i64).collect(),
raw_data: value,
@@ -1293,14 +1267,7 @@ fn test_cos_operation() -> Result<()> {
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
-
- let results = z.to_vec2::<f32>()?;
-
- assert_eq!(
- results,
- vec![vec![1.0, 0.54030234], vec![-0.41614684, -0.9899925]]
- );
-
+ assert_eq!(to_vec2_round(z, 4)?, [[1.0, 0.5403], [-0.4161, -0.99]]);
Ok(())
}
@@ -1342,19 +1309,12 @@ fn test_sin_operation() -> Result<()> {
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?;
-
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
-
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
-
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
-
- let results = z.to_vec2::<f32>()?;
-
- assert_eq!(results, vec![vec![0.0, 0.841471], vec![0.9092974, 0.14112]]);
-
+ assert_eq!(to_vec2_round(z, 4)?, [[0.0, 0.8415], [0.9093, 0.1411]]);
Ok(())
}
@@ -3150,3 +3110,165 @@ fn test_leakyrelu() -> Result<()> {
Ok(())
}
+
+// "If"
+#[test]
+fn test_if() -> Result<()> {
+ let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
+ let y = vec![5.0, 4.0, 3.0, 2.0, 1.0];
+ let output_type_proto = Some(TypeProto {
+ value: Some(type_proto::Value::TensorType(type_proto::Tensor {
+ elem_type: DataType::Float.into(),
+ shape: Some(TensorShapeProto {
+ dim: vec![Dimension {
+ denotation: "".to_string(),
+ value: Some(dimension::Value::DimValue(5)),
+ }],
+ }),
+ })),
+ denotation: "".to_string(),
+ });
+ let then_branch = GraphProto {
+ output: vec![ValueInfoProto {
+ name: "then_out".to_string(),
+ r#type: output_type_proto.clone(),
+ doc_string: "".to_string(),
+ }],
+ node: vec![NodeProto {
+ op_type: "Constant".to_string(),
+ input: vec![],
+ output: vec!["then_out".to_string()],
+ attribute: vec![AttributeProto {
+ name: "value".to_string(),
+ r#type: AttributeType::Tensor.into(),
+ t: Some(TensorProto {
+ dims: vec![x.len() as i64],
+ float_data: x.clone(),
+ data_type: DataType::Float.into(),
+ ..TensorProto::default()
+ }),
+ ..AttributeProto::default()
+ }],
+ ..NodeProto::default()
+ }],
+ ..GraphProto::default()
+ };
+ let else_branch = GraphProto {
+ output: vec![ValueInfoProto {
+ name: "else_out".to_string(),
+ r#type: output_type_proto.clone(),
+ doc_string: "".to_string(),
+ }],
+ node: vec![NodeProto {
+ op_type: "Constant".to_string(),
+ input: vec![],
+ output: vec!["else_out".to_string()],
+ attribute: vec![AttributeProto {
+ name: "value".to_string(),
+ r#type: AttributeType::Tensor.into(),
+ t: Some(TensorProto {
+ dims: vec![y.len() as i64],
+ float_data: y.clone(),
+ data_type: DataType::Float.into(),
+ ..TensorProto::default()
+ }),
+ ..AttributeProto::default()
+ }],
+ ..NodeProto::default()
+ }],
+ ..GraphProto::default()
+ };
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "If".to_string(),
+ attribute: vec![
+ AttributeProto {
+ name: "then_branch".to_string(),
+ r#type: AttributeType::Graph.into(),
+ g: Some(then_branch),
+ ..AttributeProto::default()
+ },
+ AttributeProto {
+ name: "else_branch".to_string(),
+ r#type: AttributeType::Graph.into(),
+ g: Some(else_branch),
+ ..AttributeProto::default()
+ },
+ ],
+ input: vec!["cond".to_string()],
+ output: vec!["res".to_string()],
+ ..NodeProto::default()
+ }],
+ input: vec![],
+ output: vec![ValueInfoProto {
+ name: "res".to_string(),
+ doc_string: "".to_string(),
+ r#type: output_type_proto.clone(),
+ }],
+ ..GraphProto::default()
+ }));
+
+ for cond in [1u8, 0] {
+ let inputs =
+ HashMap::from_iter([("cond".to_string(), Tensor::full(cond, (1,), &Device::Cpu)?)]);
+ let outputs = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ let expected = if cond != 0 { &x } else { &y };
+ let Some(res) = outputs.get("res") else {
+ candle::bail!("outputs didn't contain expected key `res`: {outputs:?}");
+ };
+ assert_eq!(&res.to_vec1::<f32>()?, expected);
+ }
+ Ok(())
+}
+
+#[test]
+fn test_pad() -> Result<()> {
+ let data = Tensor::from_vec(vec![1.0, 1.2, 2.3, 3.4, 4.5, 5.7], (3, 2), &Device::Cpu)?;
+ let pads = Tensor::from_vec(vec![0i64, 2, 0, 0], (4,), &Device::Cpu)?;
+ let mode = "reflect";
+
+ let expected = Tensor::from_vec(
+ vec![1.0, 1.2, 1.0, 1.2, 2.3, 3.4, 2.3, 3.4, 4.5, 5.7, 4.5, 5.7],
+ (3, 4),
+ &Device::Cpu,
+ )?;
+
+ let model = create_model_proto_with_graph(Some(GraphProto {
+ input: vec![
+ ValueInfoProto {
+ name: "data".to_string(),
+ ..ValueInfoProto::default()
+ },
+ ValueInfoProto {
+ name: "pads".to_string(),
+ ..ValueInfoProto::default()
+ },
+ ],
+ output: vec![ValueInfoProto {
+ name: "output".to_string(),
+ ..ValueInfoProto::default()
+ }],
+ node: vec![NodeProto {
+ op_type: "Pad".to_string(),
+ input: vec!["data".to_string(), "pads".to_string()],
+ output: vec!["output".to_string()],
+ attribute: vec![AttributeProto {
+ name: "mode".to_string(),
+ r#type: AttributeType::String.into(),
+ s: mode.as_bytes().to_vec(),
+ ..AttributeProto::default()
+ }],
+ ..NodeProto::default()
+ }],
+ ..GraphProto::default()
+ }));
+
+ let inputs = HashMap::from_iter([("data".to_string(), data), ("pads".to_string(), pads)]);
+ let res = candle_onnx::simple_eval(&model, inputs)?;
+ let Some(actual) = res.get("output") else {
+ candle::bail!("outputs didn't contain expected key `output`: {res:?}");
+ };
+
+ assert_eq!(actual.to_vec2::<f64>()?, expected.to_vec2::<f64>()?);
+ Ok(())
+}