summaryrefslogtreecommitdiff
path: root/candle-onnx/src/eval.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-11-04 06:36:05 +0100
committerGitHub <noreply@github.com>2023-11-04 06:36:05 +0100
commit8cbb9d0e6ce57a8dbfc685f3121ed9d01b02726e (patch)
treea49ad8154b547caa83065089bbca9066d981f03e /candle-onnx/src/eval.rs
parentbfe95115c6c55f90a4aa8712664259b5623e2935 (diff)
downloadcandle-8cbb9d0e6ce57a8dbfc685f3121ed9d01b02726e.tar.gz
candle-8cbb9d0e6ce57a8dbfc685f3121ed9d01b02726e.tar.bz2
candle-8cbb9d0e6ce57a8dbfc685f3121ed9d01b02726e.zip
Add some preliminary ONNX support (#1260)
* Add the onnx protos. * Move the reading bits. * Install protoc on the CI. * Install protoc on the cuda CI too. * Use clap for the onnx tool. * Tweak the CI protoc install. * Add some simple evalution function. * Add some binary operator support.
Diffstat (limited to 'candle-onnx/src/eval.rs')
-rw-r--r--candle-onnx/src/eval.rs81
1 file changed, 81 insertions, 0 deletions
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
new file mode 100644
index 00000000..fe112fdd
--- /dev/null
+++ b/candle-onnx/src/eval.rs
@@ -0,0 +1,81 @@
+use crate::onnx;
+use candle::{Result, Tensor};
+use std::collections::HashMap;
+
+pub type Value = Tensor; // Every value flowing through the graph is a candle tensor for now.
+
+// This function provides a direct evaluation of the proto.
+// Longer-term, we should first convert the proto to an intermediate representation of the compute
+// graph so as to make multiple evaluations more efficient.
+// An example upside of this would be to remove intermediary values when they are not needed
+// anymore.
+pub fn simple_eval(
+    model: &onnx::ModelProto,
+    inputs: HashMap<String, Value>,
+) -> Result<HashMap<String, Value>> {
+    let graph = match &model.graph {
+        None => candle::bail!("no graph defined in proto"),
+        Some(graph) => graph,
+    };
+    // TODO: validate the inputs.
+    let mut values = inputs; // Name -> tensor map, extended with each node's outputs.
+    // The nodes are topologically sorted so we can just process them in order.
+    for node in graph.node.iter() {
+        let get = |input_name: &str| match values.get(input_name) {
+            Some(value) => Ok(value),
+            None => candle::bail!("cannot find {input_name} for op {}", node.name),
+        };
+        // TODO: Validate node.input for each operator.
+        match node.op_type.as_str() {
+            "Add" => {
+                let input0 = get(&node.input[0])?;
+                let input1 = get(&node.input[1])?; // was input[0]: rhs must be the second input
+                let output = input0.broadcast_add(input1)?;
+                values.insert(node.output[0].clone(), output);
+            }
+            "Sub" => {
+                let input0 = get(&node.input[0])?;
+                let input1 = get(&node.input[1])?; // was input[0], which made Sub always yield 0
+                let output = input0.broadcast_sub(input1)?;
+                values.insert(node.output[0].clone(), output);
+            }
+            "Mul" => {
+                let input0 = get(&node.input[0])?;
+                let input1 = get(&node.input[1])?; // was input[0], which squared the lhs instead
+                let output = input0.broadcast_mul(input1)?;
+                values.insert(node.output[0].clone(), output);
+            }
+            "Div" => {
+                let input0 = get(&node.input[0])?;
+                let input1 = get(&node.input[1])?; // was input[0], which made Div always yield 1
+                let output = input0.broadcast_div(input1)?;
+                values.insert(node.output[0].clone(), output);
+            }
+            "MatMul" => {
+                let input0 = get(&node.input[0])?;
+                let input1 = get(&node.input[1])?; // was input[0]: A @ A instead of A @ B
+                let output = input0.broadcast_matmul(input1)?;
+                values.insert(node.output[0].clone(), output);
+            }
+            "Gelu" => {
+                let input = get(&node.input[0])?;
+                let output = input.gelu_erf()?;
+                values.insert(node.output[0].clone(), output);
+            }
+            "Relu" => {
+                let input = get(&node.input[0])?;
+                let output = input.relu()?;
+                values.insert(node.output[0].clone(), output);
+            }
+            op_type => candle::bail!("unsupported op_type {op_type} for op {}", node.name),
+        }
+    }
+    // Collect exactly the graph outputs, erroring out if any were never produced.
+    graph
+        .output
+        .iter()
+        .map(|output| match values.remove(&output.name) {
+            None => candle::bail!("cannot find output {}", output.name),
+            Some(value) => Ok((output.name.clone(), value)),
+        })
+        .collect()
+}