summaryrefslogtreecommitdiff
path: root/candle-onnx/src/eval.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-onnx/src/eval.rs')
-rw-r--r--candle-onnx/src/eval.rs101
1 files changed, 98 insertions, 3 deletions
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index e72002e6..f52e6c5c 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -1,6 +1,6 @@
-use crate::onnx;
use crate::onnx::attribute_proto::AttributeType;
use crate::onnx::tensor_proto::DataType;
+use crate::onnx::{self, GraphProto};
use candle::{bail, DType, Device, Result, Tensor};
use std::{collections::HashMap, usize};
@@ -56,6 +56,15 @@ impl Attr for str {
}
}
+impl Attr for GraphProto {
+ const TYPE: AttributeType = AttributeType::Graph;
+ fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
+ attr.g
+ .as_ref()
+ .ok_or_else(|| candle::Error::Msg("attribute does not contain graph".to_string()))
+ }
+}
+
impl AttrOwned for Tensor {
const TYPE: AttributeType = AttributeType::Tensor;
fn get(attr: &onnx::AttributeProto) -> Result<Self> {
@@ -214,13 +223,19 @@ pub fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
// anymore.
pub fn simple_eval(
model: &onnx::ModelProto,
- inputs: HashMap<String, Value>,
+ mut inputs: HashMap<String, Value>,
) -> Result<HashMap<String, Value>> {
let graph = match &model.graph {
None => bail!("no graph defined in proto"),
Some(graph) => graph,
};
- let mut values = inputs;
+ simple_eval_(graph, &mut inputs)
+}
+
+fn simple_eval_(
+ graph: &onnx::GraphProto,
+ values: &mut HashMap<String, Value>,
+) -> Result<HashMap<String, Value>> {
for t in graph.initializer.iter() {
let tensor = get_tensor(t, t.name.as_str())?;
values.insert(t.name.to_string(), tensor);
@@ -958,6 +973,86 @@ pub fn simple_eval(
let input = get(&node.input[0])?;
values.insert(node.output[0].clone(), input.clone());
}
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#if
+ "If" => {
+ // protobuf encodes boolean false as 0 and true as 1
+ let cond = get(&node.input[0])?.get(0)?.to_scalar::<u8>()?;
+ let attr_name = if cond != 0 {
+ "then_branch"
+ } else {
+ "else_branch"
+ };
+ let sub_graph = get_attr::<GraphProto>(node, attr_name)?;
+ if sub_graph.output.len() != node.output.len() {
+ bail!(
+ "If node {:?} is malformed: branch outputs ({}) don't match node outputs ({})",
+ node.name,
+ sub_graph.output.len(),
+ node.output.len()
+ );
+ }
+ let branch_out = simple_eval_(sub_graph, values)?;
+ for (i, out) in node.output.iter().enumerate() {
+ values.insert(
+ out.clone(),
+ branch_out.get(&sub_graph.output[i].name).unwrap().clone(),
+ );
+ }
+ }
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#pad
+ "Pad" => {
+ let mode = get_attr_opt(node, "mode")?.unwrap_or("constant");
+ let data = get(&node.input[0])?;
+ let pads = get(&node.input[1])?;
+ if node.input.len() > 2 {
+ bail!(
+ "unsupported number of inputs {} for Pad node {:?}, expected 2",
+ node.input.len(),
+ node.name
+ );
+ }
+ if pads.rank() != 1 {
+ bail!("Pad expects 'pads' input to be 1D vector: {pads:?}");
+ }
+ if pads.dim(0).unwrap() != 2 * data.rank() {
+ bail!("Pad expects 'pads' input len to be 2 * rank of 'data' input: pads: {}, data rank: {}", pads, data.rank());
+ }
+
+ let pads = pads.to_vec1::<i64>()?;
+ let (pads_pre, pads_post) = pads.split_at(pads.len() / 2);
+
+ match mode {
+ "reflect" => {
+ let mut out = data.clone();
+ for (i, &dim) in data.dims().iter().enumerate().rev() {
+ if pads_pre[i] == 0 && pads_post[i] == 0 {
+ continue;
+ }
+ fn zigzag(min: i64, max: i64) -> impl Iterator<Item = i64> {
+ std::iter::repeat((min..max).chain((min + 1..=max).rev())).flatten()
+ }
+ let idx = if dim > 1 {
+ let cycle_len = dim * 2 - 1;
+ let skip = (pads_pre[i] as usize) % cycle_len;
+ let idx = zigzag(0, (dim - 1) as i64)
+ .skip(skip)
+ .take((pads_pre[i] as usize) + dim + (pads_post[i] as usize));
+ Tensor::from_iter(idx, out.device())?
+ } else {
+ Tensor::full(0i64, (dim,), out.device())?
+ };
+
+ out = out.index_select(&idx, i)?;
+ }
+
+ values.insert(node.output[0].clone(), out);
+ }
+ _ => bail!(
+ "unsupported 'mode' value {mode:?} for Pad node {:?}",
+ node.name
+ ),
+ }
+ }
// https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13
// TODO: This version is only compatible with ReduceMean V13 and below.
"ReduceMean" => {