1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
|
use crate::onnx;
use candle::{Result, Tensor};
use std::collections::HashMap;
/// Value bound to a name while evaluating a graph; currently an alias for [`Tensor`].
pub type Value = Tensor;
/// Directly evaluates the graph stored in an ONNX model proto.
///
/// `inputs` maps value names to tensors and seeds the evaluation environment. Nodes are
/// processed in the order they appear in the graph, and each node's result is stored under the
/// name of its first output. The returned map holds one entry per declared graph output.
///
/// Supported operators: `Add`, `Sub`, `Mul`, `Div`, `MatMul` (binary, broadcasting) and `Gelu`,
/// `Relu` (unary).
///
/// # Errors
///
/// Returns an error when the proto has no graph, a node uses an unsupported `op_type`, a node
/// declares fewer inputs or outputs than its operator needs, a referenced value has not been
/// computed or provided, a tensor operation fails, or a declared graph output is missing.
// Longer-term, we should first convert the proto to an intermediate representation of the compute
// graph so as to make multiple evaluations more efficient.
// An example upside of this would be to remove intermediary values when they are not needed
// anymore.
pub fn simple_eval(
    model: &onnx::ModelProto,
    inputs: HashMap<String, Value>,
) -> Result<HashMap<String, Value>> {
    let graph = match &model.graph {
        None => candle::bail!("no graph defined in proto"),
        Some(graph) => graph,
    };
    // TODO: validate the inputs.
    let mut values = inputs;
    // The nodes are topologically sorted so we can just process them in order.
    for node in graph.node.iter() {
        // Resolves the `index`-th input of the current node to its already computed value,
        // reporting an error (rather than panicking) when the node declares too few inputs.
        let get = |index: usize| {
            let input_name = match node.input.get(index) {
                Some(input_name) => input_name,
                None => candle::bail!("missing input #{index} for op {}", node.name),
            };
            match values.get(input_name) {
                Some(value) => Ok(value),
                None => candle::bail!("cannot find {input_name} for op {}", node.name),
            }
        };
        let output = match node.op_type.as_str() {
            // Binary operators: the left operand is input 0, the right operand is input 1.
            "Add" => {
                let input0 = get(0)?;
                let input1 = get(1)?;
                input0.broadcast_add(input1)?
            }
            "Sub" => {
                let input0 = get(0)?;
                let input1 = get(1)?;
                input0.broadcast_sub(input1)?
            }
            "Mul" => {
                let input0 = get(0)?;
                let input1 = get(1)?;
                input0.broadcast_mul(input1)?
            }
            "Div" => {
                let input0 = get(0)?;
                let input1 = get(1)?;
                input0.broadcast_div(input1)?
            }
            "MatMul" => {
                let input0 = get(0)?;
                let input1 = get(1)?;
                input0.broadcast_matmul(input1)?
            }
            // Unary operators.
            "Gelu" => get(0)?.gelu_erf()?,
            "Relu" => get(0)?.relu()?,
            op_type => candle::bail!("unsupported op_type {op_type} for op {}", node.name),
        };
        let output_name = match node.output.first() {
            Some(output_name) => output_name.clone(),
            None => candle::bail!("no output defined for op {}", node.name),
        };
        values.insert(output_name, output);
    }
    // Move the requested outputs out of the environment; everything else is dropped.
    graph
        .output
        .iter()
        .map(|output| match values.remove(&output.name) {
            None => candle::bail!("cannot find output {}", output.name),
            Some(value) => Ok((output.name.clone(), value)),
        })
        .collect()
}
|