diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-11-04 06:36:05 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-04 06:36:05 +0100 |
commit | 8cbb9d0e6ce57a8dbfc685f3121ed9d01b02726e (patch) | |
tree | a49ad8154b547caa83065089bbca9066d981f03e /candle-onnx/src/eval.rs | |
parent | bfe95115c6c55f90a4aa8712664259b5623e2935 (diff) | |
download | candle-8cbb9d0e6ce57a8dbfc685f3121ed9d01b02726e.tar.gz candle-8cbb9d0e6ce57a8dbfc685f3121ed9d01b02726e.tar.bz2 candle-8cbb9d0e6ce57a8dbfc685f3121ed9d01b02726e.zip |
Add some preliminary ONNX support (#1260)
* Add the onnx protos.
* Move the reading bits.
* Install protoc on the CI.
* Install protoc on the cuda CI too.
* Use clap for the onnx tool.
* Tweak the CI protoc install.
* Add some simple evalution function.
* Add some binary operator support.
Diffstat (limited to 'candle-onnx/src/eval.rs')
-rw-r--r-- | candle-onnx/src/eval.rs | 81 |
1 files changed, 81 insertions, 0 deletions
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs new file mode 100644 index 00000000..fe112fdd --- /dev/null +++ b/candle-onnx/src/eval.rs @@ -0,0 +1,81 @@ +use crate::onnx; +use candle::{Result, Tensor}; +use std::collections::HashMap; + +pub type Value = Tensor; + +// This function provides a direct evaluation of the proto. +// Longer-term, we should first convert the proto to an intermediate representation of the compute +// graph so as to make multiple evaluations more efficient. +// An example upside of this would be to remove intermediary values when they are not needed +// anymore. +pub fn simple_eval( + model: &onnx::ModelProto, + inputs: HashMap<String, Value>, +) -> Result<HashMap<String, Value>> { + let graph = match &model.graph { + None => candle::bail!("no graph defined in proto"), + Some(graph) => graph, + }; + // TODO: validate the inputs. + let mut values = inputs; + // The nodes are topologically sorted so we can just process them in order. + for node in graph.node.iter() { + let get = |input_name: &str| match values.get(input_name) { + Some(value) => Ok(value), + None => candle::bail!("cannot find {input_name} for op {}", node.name), + }; + // TODO: Validate node.input for each operator. + match node.op_type.as_str() { + "Add" => { + let input0 = get(&node.input[0])?; + let input1 = get(&node.input[0])?; + let output = input0.broadcast_add(input1)?; + values.insert(node.output[0].clone(), output); + } + "Sub" => { + let input0 = get(&node.input[0])?; + let input1 = get(&node.input[0])?; + let output = input0.broadcast_sub(input1)?; + values.insert(node.output[0].clone(), output); + } + "Mul" => { + let input0 = get(&node.input[0])?; + let input1 = get(&node.input[0])?; + let output = input0.broadcast_mul(input1)?; + values.insert(node.output[0].clone(), output); + } + "Div" => { + let input0 = get(&node.input[0])?; + let input1 = get(&node.input[0])?; + let output = input0.broadcast_div(input1)?; + values.insert(node.output[0].clone(), output); + } + "MatMul" => { + let input0 = get(&node.input[0])?; + let input1 = get(&node.input[0])?; + let output = input0.broadcast_matmul(input1)?; + values.insert(node.output[0].clone(), output); + } + "Gelu" => { + let input = get(&node.input[0])?; + let output = input.gelu_erf()?; + values.insert(node.output[0].clone(), output); + } + "Relu" => { + let input = get(&node.input[0])?; + let output = input.relu()?; + values.insert(node.output[0].clone(), output); + } + op_type => candle::bail!("unsupported op_type {op_type} for op {}", node.name), + } + } + graph + .output + .iter() + .map(|output| match values.remove(&output.name) { + None => candle::bail!("cannot find output {}", output.name), + Some(value) => Ok((output.name.clone(), value)), + }) + .collect() +} |