diff options
Diffstat (limited to 'candle-onnx/src/eval.rs')
-rw-r--r-- | candle-onnx/src/eval.rs | 101 |
1 file changed, 98 insertions, 3 deletions
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index e72002e6..f52e6c5c 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -1,6 +1,6 @@
-use crate::onnx;
 use crate::onnx::attribute_proto::AttributeType;
 use crate::onnx::tensor_proto::DataType;
+use crate::onnx::{self, GraphProto};
 use candle::{bail, DType, Device, Result, Tensor};
 use std::{collections::HashMap, usize};
 
@@ -56,6 +56,15 @@ impl Attr for str {
     }
 }
 
+impl Attr for GraphProto {
+    const TYPE: AttributeType = AttributeType::Graph;
+    fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
+        attr.g
+            .as_ref()
+            .ok_or_else(|| candle::Error::Msg("attribute does not contain graph".to_string()))
+    }
+}
+
 impl AttrOwned for Tensor {
     const TYPE: AttributeType = AttributeType::Tensor;
     fn get(attr: &onnx::AttributeProto) -> Result<Self> {
@@ -214,13 +223,19 @@ pub fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
 // anymore.
 pub fn simple_eval(
     model: &onnx::ModelProto,
-    inputs: HashMap<String, Value>,
+    mut inputs: HashMap<String, Value>,
 ) -> Result<HashMap<String, Value>> {
     let graph = match &model.graph {
         None => bail!("no graph defined in proto"),
         Some(graph) => graph,
     };
-    let mut values = inputs;
+    simple_eval_(graph, &mut inputs)
+}
+
+fn simple_eval_(
+    graph: &onnx::GraphProto,
+    values: &mut HashMap<String, Value>,
+) -> Result<HashMap<String, Value>> {
     for t in graph.initializer.iter() {
         let tensor = get_tensor(t, t.name.as_str())?;
         values.insert(t.name.to_string(), tensor);
@@ -958,6 +973,86 @@ pub fn simple_eval(
                 let input = get(&node.input[0])?;
                 values.insert(node.output[0].clone(), input.clone());
             }
+            // https://github.com/onnx/onnx/blob/main/docs/Operators.md#if
+            "If" => {
+                // protobuf encodes boolean false as 0 and true as 1
+                let cond = get(&node.input[0])?.get(0)?.to_scalar::<u8>()?;
+                let attr_name = if cond != 0 {
+                    "then_branch"
+                } else {
+                    "else_branch"
+                };
+                let sub_graph = get_attr::<GraphProto>(node, attr_name)?;
+                if sub_graph.output.len() != node.output.len() {
+                    bail!(
+                        "If node {:?} is malformed: branch outputs ({}) don't match node outputs ({})",
+                        node.name,
+                        sub_graph.output.len(),
+                        node.output.len()
+                    );
+                }
+                let branch_out = simple_eval_(sub_graph, values)?;
+                for (i, out) in node.output.iter().enumerate() {
+                    values.insert(
+                        out.clone(),
+                        branch_out.get(&sub_graph.output[i].name).unwrap().clone(),
+                    );
+                }
+            }
+            // https://github.com/onnx/onnx/blob/main/docs/Operators.md#pad
+            "Pad" => {
+                let mode = get_attr_opt(node, "mode")?.unwrap_or("constant");
+                let data = get(&node.input[0])?;
+                let pads = get(&node.input[1])?;
+                if node.input.len() > 2 {
+                    bail!(
+                        "unsupported number of inputs {} for Pad node {:?}, expected 2",
+                        node.input.len(),
+                        node.name
+                    );
+                }
+                if pads.rank() != 1 {
+                    bail!("Pad expects 'pads' input to be 1D vector: {pads:?}");
+                }
+                if pads.dim(0).unwrap() != 2 * data.rank() {
+                    bail!("Pad expects 'pads' input len to be 2 * rank of 'data' input: pads: {}, data rank: {}", pads, data.rank());
+                }
+
+                let pads = pads.to_vec1::<i64>()?;
+                let (pads_pre, pads_post) = pads.split_at(pads.len() / 2);
+
+                match mode {
+                    "reflect" => {
+                        let mut out = data.clone();
+                        for (i, &dim) in data.dims().iter().enumerate().rev() {
+                            if pads_pre[i] == 0 && pads_post[i] == 0 {
+                                continue;
+                            }
+                            // Endless mirrored index walk: 0,1,..,max,max-1,..,1,0,1,..
+                            fn zigzag(min: i64, max: i64) -> impl Iterator<Item = i64> {
+                                std::iter::repeat((min..max).chain((min + 1..=max).rev())).flatten()
+                            }
+                            let idx = if dim > 1 {
+                                let cycle_len = dim * 2 - 2; // one mirror period over 0..=dim-1
+                                let skip = cycle_len - ((pads_pre[i] as usize) % cycle_len); // start pads_pre[i] steps *before* index 0
+                                let idx = zigzag(0, (dim - 1) as i64)
+                                    .skip(skip)
+                                    .take((pads_pre[i] as usize) + dim + (pads_post[i] as usize));
+                                Tensor::from_iter(idx, out.device())?
+                            } else {
+                                Tensor::full(0i64, (dim,), out.device())?
+                            };
+
+                            out = out.index_select(&idx, i)?;
+                        }
+
+                        values.insert(node.output[0].clone(), out);
+                    }
+                    _ => bail!(
+                        "unsupported 'mode' value {mode:?} for Pad node {:?}",
+                        node.name
+                    ),
+                }
+            }
             // https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13
             // TODO: This version is only compatible with ReduceMean V13 and below.
             "ReduceMean" => {