author    Laurent Mazare <laurent.mazare@gmail.com>    2023-11-04 06:36:05 +0100
committer GitHub <noreply@github.com>    2023-11-04 06:36:05 +0100
commit    8cbb9d0e6ce57a8dbfc685f3121ed9d01b02726e (patch)
tree      a49ad8154b547caa83065089bbca9066d981f03e
parent    bfe95115c6c55f90a4aa8712664259b5623e2935 (diff)
Add some preliminary ONNX support (#1260)
* Add the onnx protos.
* Move the reading bits.
* Install protoc on the CI.
* Install protoc on the cuda CI too.
* Use clap for the onnx tool.
* Tweak the CI protoc install.
* Add a simple evaluation function.
* Add some binary operator support.
-rw-r--r--  .github/workflows/ci_cuda.yaml           2
-rw-r--r--  .github/workflows/rust-ci.yml            4
-rw-r--r--  Cargo.toml                               1
-rw-r--r--  candle-onnx/Cargo.toml                  22
-rw-r--r--  candle-onnx/build.rs                     6
-rw-r--r--  candle-onnx/examples/onnx_basics.rs     56
-rw-r--r--  candle-onnx/src/eval.rs                 81
-rw-r--r--  candle-onnx/src/lib.rs                  14
-rw-r--r--  candle-onnx/src/onnx.proto3            836
-rw-r--r--  test.onnx                               12
10 files changed, 1033 insertions, 1 deletion
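Taken together, the new crate exposes two entry points: candle_onnx::read_file, which decodes a .onnx protobuf into a ModelProto, and candle_onnx::simple_eval, which evaluates it. A minimal sketch of a consumer, not part of this commit, assuming a model whose inputs include one named "x":

use std::collections::HashMap;
use candle::{Device, Tensor};

fn main() -> anyhow::Result<()> {
    // Decode the ONNX protobuf into the prost-generated ModelProto.
    let model = candle_onnx::read_file("test.onnx")?;
    // Bind a tensor to each graph input by name.
    let mut inputs = HashMap::new();
    inputs.insert("x".to_string(), Tensor::new(&[-3.2f32, 2.7], &Device::Cpu)?);
    // Walk the topologically sorted nodes and collect the named outputs.
    let outputs = candle_onnx::simple_eval(&model, inputs)?;
    for (name, value) in outputs.iter() {
        println!("{name}: {value:?}");
    }
    Ok(())
}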
diff --git a/.github/workflows/ci_cuda.yaml b/.github/workflows/ci_cuda.yaml
index 8953c444..ec792a25 100644
--- a/.github/workflows/ci_cuda.yaml
+++ b/.github/workflows/ci_cuda.yaml
@@ -59,7 +59,7 @@ jobs:
- name: Install Rust Stable
run: curl https://sh.rustup.rs -sSf | sh -s -- -y
- uses: Swatinem/rust-cache@v2
- - run: apt-get update -y && apt-get install libssl-dev -y
+ - run: apt-get update -y && apt-get install libssl-dev protobuf-compiler -y
- name: Test (cuda)
run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
stop-runner:
diff --git a/.github/workflows/rust-ci.yml b/.github/workflows/rust-ci.yml
index 2ca53b23..b435bdfa 100644
--- a/.github/workflows/rust-ci.yml
+++ b/.github/workflows/rust-ci.yml
@@ -16,6 +16,7 @@ jobs:
rust: [stable]
steps:
- uses: actions/checkout@v2
+ - uses: arduino/setup-protoc@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
@@ -35,6 +36,7 @@ jobs:
rust: [stable]
steps:
- uses: actions/checkout@v2
+ - uses: arduino/setup-protoc@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
@@ -50,6 +52,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
+ - uses: arduino/setup-protoc@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
@@ -66,6 +69,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
+ - uses: arduino/setup-protoc@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
diff --git a/Cargo.toml b/Cargo.toml
index 89ffe530..5a7ea759 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -5,6 +5,7 @@ members = [
"candle-examples",
"candle-book",
"candle-nn",
+ "candle-onnx",
"candle-pyo3",
"candle-transformers",
"candle-wasm-examples/*",
diff --git a/candle-onnx/Cargo.toml b/candle-onnx/Cargo.toml
new file mode 100644
index 00000000..a4817e43
--- /dev/null
+++ b/candle-onnx/Cargo.toml
@@ -0,0 +1,22 @@
+[package]
+name = "candle-onnx"
+version.workspace = true
+edition.workspace = true
+description.workspace = true
+repository.workspace = true
+keywords.workspace = true
+categories.workspace = true
+license.workspace = true
+
+[dependencies]
+candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
+candle-nn = { path = "../candle-nn", version = "0.3.0" }
+prost = "0.12.1"
+
+[build-dependencies]
+prost-build = "0.12.1"
+
+[dev-dependencies]
+anyhow = { workspace = true }
+clap = { workspace = true }
+
diff --git a/candle-onnx/build.rs b/candle-onnx/build.rs
new file mode 100644
index 00000000..79e7a39d
--- /dev/null
+++ b/candle-onnx/build.rs
@@ -0,0 +1,6 @@
+use std::io::Result;
+
+fn main() -> Result<()> {
+ prost_build::compile_protos(&["src/onnx.proto3"], &["src/"])?;
+ Ok(())
+}
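Note that prost_build::compile_protos does not bundle its own protoc: it shells out to a protoc binary located through PATH (or the PROTOC environment variable). That requirement is what the protobuf-compiler and arduino/setup-protoc additions in the CI workflows above satisfy.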
diff --git a/candle-onnx/examples/onnx_basics.rs b/candle-onnx/examples/onnx_basics.rs
new file mode 100644
index 00000000..b91cbee6
--- /dev/null
+++ b/candle-onnx/examples/onnx_basics.rs
@@ -0,0 +1,56 @@
+use anyhow::Result;
+use candle::{Device, Tensor};
+
+use clap::{Parser, Subcommand};
+
+#[derive(Subcommand, Debug, Clone)]
+enum Command {
+ Print {
+ #[arg(long)]
+ file: String,
+ },
+ SimpleEval {
+ #[arg(long)]
+ file: String,
+ },
+}
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+pub struct Args {
+ #[command(subcommand)]
+ command: Command,
+}
+
+pub fn main() -> Result<()> {
+ let args = Args::parse();
+ match args.command {
+ Command::Print { file } => {
+ let model = candle_onnx::read_file(file)?;
+ println!("{model:?}");
+ let graph = model.graph.unwrap();
+ for node in graph.node.iter() {
+ println!("{node:?}");
+ }
+ }
+ Command::SimpleEval { file } => {
+ let model = candle_onnx::read_file(file)?;
+ let inputs = model
+ .graph
+ .as_ref()
+ .unwrap()
+ .input
+ .iter()
+ .map(|name| {
+ let value = Tensor::new(&[-3.2, 2.7], &Device::Cpu)?;
+ Ok((name.name.clone(), value))
+ })
+ .collect::<Result<_>>()?;
+ let outputs = candle_onnx::simple_eval(&model, inputs)?;
+ for (name, value) in outputs.iter() {
+ println!("{name}: {value:?}")
+ }
+ }
+ }
+ Ok(())
+}
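With clap's derive conventions the subcommands surface in kebab-case, so the tool would presumably be run as cargo run --example onnx_basics -- print --file test.onnx or cargo run --example onnx_basics -- simple-eval --file test.onnx; the latter binds the fixed [-3.2, 2.7] tensor to every input of the graph.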
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
new file mode 100644
index 00000000..fe112fdd
--- /dev/null
+++ b/candle-onnx/src/eval.rs
@@ -0,0 +1,81 @@
+use crate::onnx;
+use candle::{Result, Tensor};
+use std::collections::HashMap;
+
+pub type Value = Tensor;
+
+// This function provides a direct evaluation of the proto.
+// Longer-term, we should first convert the proto to an intermediate representation of the compute
+// graph so as to make multiple evaluations more efficient.
+// An example upside of this would be to remove intermediary values when they are not needed
+// anymore.
+pub fn simple_eval(
+ model: &onnx::ModelProto,
+ inputs: HashMap<String, Value>,
+) -> Result<HashMap<String, Value>> {
+ let graph = match &model.graph {
+ None => candle::bail!("no graph defined in proto"),
+ Some(graph) => graph,
+ };
+ // TODO: validate the inputs.
+ let mut values = inputs;
+ // The nodes are topologically sorted so we can just process them in order.
+ for node in graph.node.iter() {
+ let get = |input_name: &str| match values.get(input_name) {
+ Some(value) => Ok(value),
+ None => candle::bail!("cannot find {input_name} for op {}", node.name),
+ };
+ // TODO: Validate node.input for each operator.
+ match node.op_type.as_str() {
+ "Add" => {
+ let input0 = get(&node.input[0])?;
+ let input1 = get(&node.input[1])?;
+ let output = input0.broadcast_add(input1)?;
+ values.insert(node.output[0].clone(), output);
+ }
+ "Sub" => {
+ let input0 = get(&node.input[0])?;
+ let input1 = get(&node.input[1])?;
+ let output = input0.broadcast_sub(input1)?;
+ values.insert(node.output[0].clone(), output);
+ }
+ "Mul" => {
+ let input0 = get(&node.input[0])?;
+ let input1 = get(&node.input[1])?;
+ let output = input0.broadcast_mul(input1)?;
+ values.insert(node.output[0].clone(), output);
+ }
+ "Div" => {
+ let input0 = get(&node.input[0])?;
+ let input1 = get(&node.input[1])?;
+ let output = input0.broadcast_div(input1)?;
+ values.insert(node.output[0].clone(), output);
+ }
+ "MatMul" => {
+ let input0 = get(&node.input[0])?;
+ let input1 = get(&node.input[1])?;
+ let output = input0.broadcast_matmul(input1)?;
+ values.insert(node.output[0].clone(), output);
+ }
+ "Gelu" => {
+ let input = get(&node.input[0])?;
+ let output = input.gelu_erf()?;
+ values.insert(node.output[0].clone(), output);
+ }
+ "Relu" => {
+ let input = get(&node.input[0])?;
+ let output = input.relu()?;
+ values.insert(node.output[0].clone(), output);
+ }
+ op_type => candle::bail!("unsupported op_type {op_type} for op {}", node.name),
+ }
+ }
+ graph
+ .output
+ .iter()
+ .map(|output| match values.remove(&output.name) {
+ None => candle::bail!("cannot find output {}", output.name),
+ Some(value) => Ok((output.name.clone(), value)),
+ })
+ .collect()
+}
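Because the prost-generated types implement Default, simple_eval can also be exercised without a file on disk. A sketch of a hand-built single-Relu graph, mirroring the test.onnx fixture at the end of this commit (field names assume prost's usual proto-to-Rust mapping; this is not part of the commit itself):

use std::collections::HashMap;
use candle::{Device, Tensor};
use candle_onnx::onnx;

fn main() -> candle::Result<()> {
    // One node: y = Relu(x).
    let node = onnx::NodeProto {
        op_type: "Relu".to_string(),
        input: vec!["x".to_string()],
        output: vec!["y".to_string()],
        ..Default::default()
    };
    let graph = onnx::GraphProto {
        node: vec![node],
        input: vec![onnx::ValueInfoProto { name: "x".to_string(), ..Default::default() }],
        output: vec![onnx::ValueInfoProto { name: "y".to_string(), ..Default::default() }],
        ..Default::default()
    };
    let model = onnx::ModelProto { graph: Some(graph), ..Default::default() };

    let mut inputs = HashMap::new();
    inputs.insert("x".to_string(), Tensor::new(&[-1f32, 2.], &Device::Cpu)?);
    let outputs = candle_onnx::simple_eval(&model, inputs)?;
    // Relu clamps negatives to zero, so y should hold [0., 2.].
    println!("y: {:?}", outputs.get("y"));
    Ok(())
}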
diff --git a/candle-onnx/src/lib.rs b/candle-onnx/src/lib.rs
new file mode 100644
index 00000000..3b36c4cf
--- /dev/null
+++ b/candle-onnx/src/lib.rs
@@ -0,0 +1,14 @@
+use candle::Result;
+use prost::Message;
+
+pub mod onnx {
+ include!(concat!(env!("OUT_DIR"), "/onnx.rs"));
+}
+
+mod eval;
+pub use eval::simple_eval;
+
+pub fn read_file<P: AsRef<std::path::Path>>(p: P) -> Result<onnx::ModelProto> {
+ let buf = std::fs::read(p)?;
+ onnx::ModelProto::decode(buf.as_slice()).map_err(candle::Error::wrap)
+}
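The include! of $OUT_DIR/onnx.rs is the standard prost pattern: the build script compiles src/onnx.proto3 and prost-build writes the Rust generated for `package onnx;` to that path, so re-exporting it as pub mod onnx is what lets downstream code (such as the sketch after eval.rs above) name types like onnx::ModelProto directly.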
diff --git a/candle-onnx/src/onnx.proto3 b/candle-onnx/src/onnx.proto3
new file mode 100644
index 00000000..f47006f8
--- /dev/null
+++ b/candle-onnx/src/onnx.proto3
@@ -0,0 +1,836 @@
+//
+// WARNING: This file is automatically generated! Please edit onnx.in.proto.
+//
+
+
+// SPDX-License-Identifier: Apache-2.0
+
+
+syntax = "proto3";
+
+package onnx;
+
+// Overview
+//
+// ONNX is an open specification that is comprised of the following components:
+//
+// 1) A definition of an extensible computation graph model.
+// 2) Definitions of standard data types.
+// 3) Definitions of built-in operators.
+//
+// This document describes the syntax of models and their computation graphs,
+// as well as the standard data types. Together, they are referred to as the ONNX
+// Intermediate Representation, or 'IR' for short.
+//
+// The normative semantic specification of the ONNX IR is found in docs/IR.md.
+// Definitions of the built-in neural network operators may be found in docs/Operators.md.
+
+// Notes
+//
+// Protobuf compatibility
+//
+// To simplify framework compatibility, ONNX is defined using the subset of protobuf
+// that is compatible with both protobuf v2 and v3. This means that we do not use any
+// protobuf features that are only available in one of the two versions.
+//
+// Here are the most notable contortions we have to carry out to work around
+// these limitations:
+//
+// - No 'map' (added in protobuf 3.0). We instead represent mappings as lists
+// of key-value pairs, where order does not matter and duplicates
+// are not allowed.
+
+
+// Versioning
+//
+// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md
+//
+// To be compatible with both proto2 and proto3, we will use a version number
+// that is not defined by the default value but by an explicit enum number.
+enum Version {
+ // proto3 requires the first enum value to be zero.
+ // We add this just to appease the compiler.
+ _START_VERSION = 0;
+ // The version field is always serialized and we will use it to store the
+ // version that the graph is generated from. This helps us set up version
+ // control.
+ // For the IR, we are using simple numbers starting with 0x00000001,
+ // which was the version we published on Oct 10, 2017.
+ IR_VERSION_2017_10_10 = 0x0000000000000001;
+
+ // IR_VERSION 2 published on Oct 30, 2017
+ // - Added type discriminator to AttributeProto to support proto3 users
+ IR_VERSION_2017_10_30 = 0x0000000000000002;
+
+ // IR VERSION 3 published on Nov 3, 2017
+ // - For operator versioning:
+ // - Added new message OperatorSetIdProto
+ // - Added opset_import in ModelProto
+ // - For vendor extensions, added domain in NodeProto
+ IR_VERSION_2017_11_3 = 0x0000000000000003;
+
+ // IR VERSION 4 published on Jan 22, 2019
+ // - Relax constraint that initializers should be a subset of graph inputs
+ // - Add type BFLOAT16
+ IR_VERSION_2019_1_22 = 0x0000000000000004;
+
+ // IR VERSION 5 published on March 18, 2019
+ // - Add message TensorAnnotation.
+ // - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters.
+ IR_VERSION_2019_3_18 = 0x0000000000000005;
+
+ // IR VERSION 6 published on Sep 19, 2019
+ // - Add support for sparse tensor constants stored in model.
+ // - Add message SparseTensorProto
+ // - Add sparse initializers
+ IR_VERSION_2019_9_19 = 0x0000000000000006;
+
+ // IR VERSION 7 published on May 8, 2020
+ // - Add support to allow function body graph to rely on multiple external operator sets.
+ // - Add a list to promote inference graph's initializers to global and
+ // mutable variables. Global variables are visible in all graphs of the
+ // stored models.
+ // - Add message TrainingInfoProto to store initialization
+ // method and training algorithm. The execution of TrainingInfoProto
+ // can modify the values of mutable variables.
+ // - Implicitly add inference graph into each TrainingInfoProto's algorithm.
+ IR_VERSION_2020_5_8 = 0x0000000000000007;
+
+ // IR VERSION 8 published on July 30, 2021
+ // Introduce TypeProto.SparseTensor
+ // Introduce TypeProto.Optional
+ // Added a list of FunctionProtos local to the model
+ // Deprecated since_version and operator status from FunctionProto
+ IR_VERSION_2021_7_30 = 0x0000000000000008;
+
+ // IR VERSION 9 published on May 5, 2023
+ // Added AttributeProto to FunctionProto so that default attribute values can be set.
+ // Added FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ.
+ IR_VERSION = 0x0000000000000009;
+}
+
+// Attributes
+//
+// A named attribute containing either singular float, integer, string, graph,
+// and tensor values, or repeated float, integer, string, graph, and tensor values.
+// An AttributeProto MUST contain the name field, and *only one* of the
+// following content fields, effectively enforcing a C/C++ union equivalent.
+message AttributeProto {
+ reserved 12, 16 to 19;
+ reserved "v";
+
+ // Note: this enum is structurally identical to the OpSchema::AttrType
+ // enum defined in schema.h. If you rev one, you likely need to rev the other.
+ enum AttributeType {
+ UNDEFINED = 0;
+ FLOAT = 1;
+ INT = 2;
+ STRING = 3;
+ TENSOR = 4;
+ GRAPH = 5;
+ SPARSE_TENSOR = 11;
+ TYPE_PROTO = 13;
+
+ FLOATS = 6;
+ INTS = 7;
+ STRINGS = 8;
+ TENSORS = 9;
+ GRAPHS = 10;
+ SPARSE_TENSORS = 12;
+ TYPE_PROTOS = 14;
+ }
+
+ // The name field MUST be present for this version of the IR.
+ string name = 1; // namespace Attribute
+
+ // If ref_attr_name is not empty, ref_attr_name is the attribute name in the parent function.
+ // In this case, this AttributeProto does not contain data, and is instead a reference to an attribute
+ // in the parent scope.
+ // NOTE: This should ONLY be used in a function (sub-graph). It is invalid in the main graph.
+ string ref_attr_name = 21;
+
+ // A human-readable documentation for this attribute. Markdown is allowed.
+ string doc_string = 13;
+
+ // The type field MUST be present for this version of the IR.
+ // For 0.0.1 versions of the IR, this field was not defined, and
+ // implementations needed to use has_field heuristics to determine
+ // which value field was in use. For IR_VERSION 0.0.2 or later, this
+ // field MUST be set and match the f|i|s|t|... field in use. This
+ // change was made to accommodate proto3 implementations.
+ AttributeType type = 20; // discriminator that indicates which field below is in use
+
+ // Exactly ONE of the following fields must be present for this version of the IR
+ float f = 2; // float
+ int64 i = 3; // int
+ bytes s = 4; // UTF-8 string
+ TensorProto t = 5; // tensor value
+ GraphProto g = 6; // graph
+ SparseTensorProto sparse_tensor = 22; // sparse tensor value
+ // Do not use field below, it's deprecated.
+ // optional ValueProto v = 12; // value - subsumes everything but graph
+ TypeProto tp = 14; // type proto
+
+ repeated float floats = 7; // list of floats
+ repeated int64 ints = 8; // list of ints
+ repeated bytes strings = 9; // list of UTF-8 strings
+ repeated TensorProto tensors = 10; // list of tensors
+ repeated GraphProto graphs = 11; // list of graph
+ repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors
+ repeated TypeProto type_protos = 15;// list of type protos
+}
+
+// Defines information on value, including the name, the type, and
+// the shape of the value.
+message ValueInfoProto {
+ // This field MUST be present in this version of the IR.
+ string name = 1; // namespace Value
+ // This field MUST be present in this version of the IR for
+ // inputs and outputs of the top-level graph.
+ TypeProto type = 2;
+ // A human-readable documentation for this value. Markdown is allowed.
+ string doc_string = 3;
+}
+
+// Nodes
+//
+// Computation graphs are made up of a DAG of nodes, which represent what is
+// commonly called a "layer" or "pipeline stage" in machine learning frameworks.
+//
+// For example, it can be a node of type "Conv" that takes in an image, a filter
+// tensor and a bias tensor, and produces the convolved output.
+message NodeProto {
+ repeated string input = 1; // namespace Value
+ repeated string output = 2; // namespace Value
+
+ // An optional identifier for this node in a graph.
+ // This field MAY be absent in this version of the IR.
+ string name = 3; // namespace Node
+
+ // The symbolic identifier of the Operator to execute.
+ string op_type = 4; // namespace Operator
+ // The domain of the OperatorSet that specifies the operator named by op_type.
+ string domain = 7; // namespace Domain
+
+ // Additional named attributes.
+ repeated AttributeProto attribute = 5;
+
+ // A human-readable documentation for this node. Markdown is allowed.
+ string doc_string = 6;
+}
+
+// Training information
+// TrainingInfoProto stores information for training a model.
+// In particular, this defines two functionalities: an initialization-step
+// and a training-algorithm-step. Initialization resets the model
+// back to its original state as if no training has been performed.
+// Training algorithm improves the model based on input data.
+//
+// The semantics of the initialization-step is that the initializers
+// in ModelProto.graph and in TrainingInfoProto.algorithm are first
+// initialized as specified by the initializers in the graph, and then
+// updated by the "initialization_binding" in every instance in
+// ModelProto.training_info.
+//
+// The field "algorithm" defines a computation graph which represents a
+// training algorithm's step. After the execution of a
+// TrainingInfoProto.algorithm, the initializers specified by "update_binding"
+// may be immediately updated. If the targeted training algorithm contains
+// consecutive update steps (such as block coordinate descent methods),
+// the user needs to create a TrainingInfoProto for each step.
+message TrainingInfoProto {
+ // This field describes a graph to compute the initial tensors
+ // upon starting the training process. Initialization graph has no input
+ // and can have multiple outputs. Usually, trainable tensors in neural
+ // networks are randomly initialized. To achieve that, for each tensor,
+ // the user can put a random number operator such as RandomNormal or
+ // RandomUniform in TrainingInfoProto.initialization.node and assign its
+ // random output to the specific tensor using "initialization_binding".
+ // This graph can also set the initializers in "algorithm" in the same
+ // TrainingInfoProto; a use case is resetting the number of training
+ // iterations to zero.
+ //
+ // By default, this field is an empty graph and its evaluation does not
+ // produce any output. Thus, no initializer would be changed by default.
+ GraphProto initialization = 1;
+
+ // This field represents a training algorithm step. Given required inputs,
+ // it computes outputs to update initializers in its own or inference graph's
+ // initializer lists. In general, this field contains loss node, gradient node,
+ // optimizer node, increment of iteration count.
+ //
+ // An execution of the training algorithm step is performed by executing the
+ // graph obtained by combining the inference graph (namely "ModelProto.graph")
+ // and the "algorithm" graph. That is, the actual
+ // input/initializer/output/node/value_info/sparse_initializer list of
+ // the training graph is the concatenation of
+ // "ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer"
+ // and "algorithm.input/initializer/output/node/value_info/sparse_initializer"
+ // in that order. This combined graph must satisfy the normal ONNX conditions.
+ // Now, let's provide a visualization of graph combination for clarity.
+ // Let the inference graph (i.e., "ModelProto.graph") be
+ // tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d
+ // and the "algorithm" graph be
+ // tensor_d -> Add -> tensor_e
+ // The combination process results in
+ // tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d -> Add -> tensor_e
+ //
+ // Notice that an input of a node in the "algorithm" graph may reference the
+ // output of a node in the inference graph (but not the other way round). Also, inference
+ // node cannot reference inputs of "algorithm". With these restrictions, inference graph
+ // can always be run independently without training information.
+ //
+ // By default, this field is an empty graph and its evaluation does not
+ // produce any output. Evaluating the default training step never
+ // updates any initializers.
+ GraphProto algorithm = 2;
+
+ // This field specifies the bindings from the outputs of "initialization" to
+ // some initializers in "ModelProto.graph.initializer" and
+ // the "algorithm.initializer" in the same TrainingInfoProto.
+ // See "update_binding" below for details.
+ //
+ // By default, this field is empty and no initializer would be changed
+ // by the execution of "initialization".
+ repeated StringStringEntryProto initialization_binding = 3;
+
+ // Gradient-based training is usually an iterative procedure. In one gradient
+ // descent iteration, we apply
+ //
+ // x = x - r * g
+ //
+ // where "x" is the optimized tensor, "r" stands for learning rate, and "g" is
+ // gradient of "x" with respect to a chosen loss. To avoid adding assignments
+ // into the training graph, we split the update equation into
+ //
+ // y = x - r * g
+ // x = y
+ //
+ // The user needs to save "y = x - r * g" into TrainingInfoProto.algorithm. To
+ // tell that "y" should be assigned to "x", the field "update_binding" may
+ // contain a key-value pair of strings, "x" (key of StringStringEntryProto)
+ // and "y" (value of StringStringEntryProto).
+ // For a neural network with multiple trainable (mutable) tensors, there can
+ // be multiple key-value pairs in "update_binding".
+ //
+ // The initializers appearing as keys in "update_binding" are considered
+ // mutable variables. This implies some behaviors
+ // as described below.
+ //
+ // 1. We have only unique keys in all "update_binding"s so that two
+ // variables may not have the same name. This ensures that one
+ // variable is assigned at most once.
+ // 2. The keys must appear in names of "ModelProto.graph.initializer" or
+ // "TrainingInfoProto.algorithm.initializer".
+ // 3. The values must be output names of "algorithm" or "ModelProto.graph.output".
+ // 4. Mutable variables are initialized to the value specified by the
+ // corresponding initializer, and then potentially updated by
+ // "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s.
+ //
+ // This field usually contains names of trainable tensors
+ // (in ModelProto.graph), optimizer states such as momentums in advanced
+ // stochastic gradient methods (in TrainingInfoProto.graph),
+ // and number of training iterations (in TrainingInfoProto.graph).
+ //
+ // By default, this field is empty and no initializer would be changed
+ // by the execution of "algorithm".
+ repeated StringStringEntryProto update_binding = 4;
+}
+
+// Models
+//
+// ModelProto is a top-level file/container format for bundling a ML model and
+// associating its computation graph with metadata.
+//
+// The semantics of the model are described by the associated GraphProto's.
+message ModelProto {
+ // The version of the IR this model targets. See Version enum above.
+ // This field MUST be present.
+ int64 ir_version = 1;
+
+ // The OperatorSets this model relies on.
+ // All ModelProtos MUST have at least one entry that
+ // specifies which version of the ONNX OperatorSet is
+ // being imported.
+ //
+ // All nodes in the ModelProto's graph will bind against the operator
+ // with the same-domain/same-op_type operator with the HIGHEST version
+ // in the referenced operator sets.
+ repeated OperatorSetIdProto opset_import = 8;
+
+ // The name of the framework or tool used to generate this model.
+ // This field SHOULD be present to indicate which implementation/tool/framework
+ // emitted the model.
+ string producer_name = 2;
+
+ // The version of the framework or tool used to generate this model.
+ // This field SHOULD be present to indicate which implementation/tool/framework
+ // emitted the model.
+ string producer_version = 3;
+
+ // Domain name of the model.
+ // We use reverse domain names as name space indicators. For example:
+ // `com.facebook.fair` or `com.microsoft.cognitiveservices`
+ //
+ // Together with `model_version` and GraphProto.name, this forms the unique identity of
+ // the graph.
+ string domain = 4;
+
+ // The version of the graph encoded. See Version enum below.
+ int64 model_version = 5;
+
+ // A human-readable documentation for this model. Markdown is allowed.
+ string doc_string = 6;
+
+ // The parameterized graph that is evaluated to execute the model.
+ GraphProto graph = 7;
+
+ // Named metadata values; keys should be distinct.
+ repeated StringStringEntryProto metadata_props = 14;
+
+ // Training-specific information. Sequentially executing all stored
+ // `TrainingInfoProto.algorithm`s and assigning their outputs following
+ // the corresponding `TrainingInfoProto.update_binding`s is one training
+ // iteration. Similarly, to initialize the model
+ // (as if training hasn't happened), the user should sequentially execute
+ // all stored `TrainingInfoProto.initialization`s and assign their outputs
+ // using `TrainingInfoProto.initialization_binding`s.
+ //
+ // If this field is empty, the training behavior of the model is undefined.
+ repeated TrainingInfoProto training_info = 20;
+
+ // A list of function protos local to the model.
+ //
+ // Name of the function "FunctionProto.name" should be unique within the domain "FunctionProto.domain".
+ // In case of any conflicts the behavior (whether the model local functions are given higher priority,
+ // or standard operator sets are given higher priority, or this is treated as an error) is defined by
+ // the runtimes.
+ //
+ // The operator sets imported by FunctionProto should be compatible with the ones
+ // imported by ModelProto and other model local FunctionProtos.
+ // For example, if the same operator set, say 'A', is imported by a FunctionProto and the ModelProto,
+ // or by two FunctionProtos, the versions of the operator set may differ, but
+ // the operator schema returned for the (op_type, domain, version) combination
+ // must be the same for both versions for every node in the function body.
+ //
+ // One FunctionProto can reference other FunctionProto in the model, however, recursive reference
+ // is not allowed.
+ repeated FunctionProto functions = 25;
+};
+
+// StringStringEntryProto follows the pattern for cross-proto-version maps.
+// See https://developers.google.com/protocol-buffers/docs/proto3#maps
+message StringStringEntryProto {
+ string key = 1;
+ string value = 2;
+};
+
+message TensorAnnotation {
+ string tensor_name = 1;
+ // <key, value> pairs to annotate tensor specified by <tensor_name> above.
+ // The keys used in the mapping below must be pre-defined in ONNX spec.
+ // For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as
+ // quantization parameter keys.
+ repeated StringStringEntryProto quant_parameter_tensor_names = 2;
+}
+
+
+
+// Graphs
+//
+// A graph defines the computational logic of a model and is comprised of a parameterized
+// list of nodes that form a directed acyclic graph based on their inputs and outputs.
+// This is the equivalent of the "network" or "graph" in many deep learning
+// frameworks.
+message GraphProto {
+ // The nodes in the graph, sorted topologically.
+ repeated NodeProto node = 1;
+
+ // The name of the graph.
+ string name = 2; // namespace Graph
+
+ // A list of named tensor values, used to specify constant inputs of the graph.
+ // Each initializer (both TensorProto as well SparseTensorProto) MUST have a name.
+ // The name MUST be unique across both initializer and sparse_initializer,
+ // but the name MAY also appear in the input list.
+ repeated TensorProto initializer = 5;
+
+ // Initializers (see above) stored in sparse format.
+ repeated SparseTensorProto sparse_initializer = 15;
+
+ // A human-readable documentation for this graph. Markdown is allowed.
+ string doc_string = 10;
+
+ // The inputs and outputs of the graph.
+ repeated ValueInfoProto input = 11;
+ repeated ValueInfoProto output = 12;
+
+ // Information for the values in the graph. The ValueInfoProto.name's
+ // must be distinct. It is optional for a value to appear in value_info list.
+ repeated ValueInfoProto value_info = 13;
+
+ // This field carries information to indicate the mapping among a tensor and its
+ // quantization parameter tensors. For example:
+ // For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated,
+ // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model.
+ repeated TensorAnnotation quantization_annotation = 14;
+
+ reserved 3, 4, 6 to 9;
+ reserved "ir_version", "producer_version", "producer_tag", "domain";
+}
+
+// Tensors
+//
+// A serialized tensor value.
+message TensorProto {
+ enum DataType {
+ UNDEFINED = 0;
+ // Basic types.
+ FLOAT = 1; // float
+ UINT8 = 2; // uint8_t
+ INT8 = 3; // int8_t
+ UINT16 = 4; // uint16_t
+ INT16 = 5; // int16_t
+ INT32 = 6; // int32_t
+ INT64 = 7; // int64_t
+ STRING = 8; // string
+ BOOL = 9; // bool
+
+ // IEEE754 half-precision floating-point format (16 bits wide).
+ // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
+ FLOAT16 = 10;
+
+ DOUBLE = 11;
+ UINT32 = 12;
+ UINT64 = 13;
+ COMPLEX64 = 14; // complex with float32 real and imaginary components
+ COMPLEX128 = 15; // complex with float64 real and imaginary components
+
+ // Non-IEEE floating-point format based on IEEE754 single-precision
+ // floating-point number truncated to 16 bits.
+ // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
+ BFLOAT16 = 16;
+
+ // Non-IEEE floating-point format based on papers
+ // FP8 Formats for Deep Learning, https://arxiv.org/abs/2209.05433,
+ // 8-bit Numerical Formats For Deep Neural Networks, https://arxiv.org/pdf/2206.02915.pdf.
+ // Operators supported FP8 are Cast, CastLike, QuantizeLinear, DequantizeLinear.
+ // The computation usually happens inside a block quantize / dequantize
+ // fused by the runtime.
+ FLOAT8E4M3FN = 17; // float 8, mostly used for coefficients, supports nan, not inf
+ FLOAT8E4M3FNUZ = 18; // float 8, mostly used for coefficients, supports nan, not inf, no negative zero
+ FLOAT8E5M2 = 19; // follows IEEE 754, supports nan, inf, mostly used for gradients
+ FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, inf, mostly used for gradients, no negative zero
+
+ // Future extensions go here.
+ }
+
+ // The shape of the tensor.
+ repeated int64 dims = 1;
+
+ // The data type of the tensor.
+ // This field MUST have a valid TensorProto.DataType value
+ int32 data_type = 2;
+
+ // For very large tensors, we may want to store them in chunks, in which
+ // case the following fields will specify the segment that is stored in
+ // the current TensorProto.
+ message Segment {
+ int64 begin = 1;
+ int64 end = 2;
+ }
+ Segment segment = 3;
+
+ // Tensor content must be organized in row-major order.
+ //
+ // Depending on the data_type field, exactly one of the fields below with
+ // name ending in _data is used to store the elements of the tensor.
+
+ // For float and complex64 values
+ // Complex64 tensors are encoded as a single array of floats,
+ // with the real components appearing in odd numbered positions,
+ // and the corresponding imaginary component appearing in the
+ // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
+ // is encoded as [1.0, 2.0, 3.0, 4.0]).
+ // When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
+ repeated float float_data = 4 [packed = true];
+
+ // For int32, uint8, int8, uint16, int16, bool, float8, and float16 values.
+ // float16 and float8 values must be bit-wise converted to a uint16_t prior
+ // to writing to the buffer.
+ // When this field is present, the data_type field MUST be
+ // INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ
+ repeated int32 int32_data = 5 [packed = true];
+
+ // For strings.
+ // Each element of string_data is a UTF-8 encoded Unicode
+ // string. No trailing null, no leading BOM. The protobuf "string"
+ // scalar type is not used to match ML community conventions.
+ // When this field is present, the data_type field MUST be STRING
+ repeated bytes string_data = 6;
+
+ // For int64.
+ // When this field is present, the data_type field MUST be INT64
+ repeated int64 int64_data = 7 [packed = true];
+
+ // Optionally, a name for the tensor.
+ string name = 8; // namespace Value
+
+ // A human-readable documentation for this tensor. Markdown is allowed.
+ string doc_string = 12;
+
+ // Serializations can either use one of the fields above, or use this
+ // raw bytes field. The only exception is the string case, where one is
+ // required to store the content in the repeated bytes string_data field.
+ //
+ // When this raw_data field is used to store a tensor value, elements MUST
+ // be stored as fixed-width values in little-endian order.
+ // Floating-point data types MUST be stored in IEEE 754 format.
+ // Complex64 elements must be written as two consecutive FLOAT values, real component first.
+ // Complex128 elements must be written as two consecutive DOUBLE values, real component first.
+ // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
+ //
+ // Note: the advantage of the type-specific fields over the raw_data field is
+ // that in some cases (e.g. int data), protobuf achieves better packing via
+ // variable-length storage, which may lead to a smaller binary footprint.
+ // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED
+ bytes raw_data = 9;
+
+ // Data can be stored inside the protobuf file using type-specific fields or raw_data.
+ // Alternatively, raw bytes data can be stored in an external file, using the external_data field.
+ // external_data stores key-value pairs describing data location. Recognized keys are:
+ // - "location" (required) - POSIX filesystem path relative to the directory where the ONNX
+ // protobuf model was stored
+ // - "offset" (optional) - position of byte at which stored data begins. Integer stored as string.
+ // Offset values SHOULD be multiples 4096 (page size) to enable mmap support.
+ // - "length" (optional) - number of bytes containing data. Integer stored as string.
+ // - "checksum" (optional) - SHA1 digest of file specified in under 'location' key.
+ repeated StringStringEntryProto external_data = 13;
+
+ // Location of the data for this tensor. MUST be one of:
+ // - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field.
+ // - EXTERNAL - data stored in an external location as described by external_data field.
+ enum DataLocation {
+ DEFAULT = 0;
+ EXTERNAL = 1;
+ }
+
+ // If value not set, data is stored in raw_data (if set) otherwise in type-specified field.
+ DataLocation data_location = 14;
+
+ // For double
+ // Complex128 tensors are encoded as a single array of doubles,
+ // with the real components appearing in odd numbered positions,
+ // and the corresponding imaginary component appearing in the
+ // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
+ // is encoded as [1.0, 2.0, 3.0, 4.0]).
+ // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128
+ repeated double double_data = 10 [packed = true];
+
+ // For uint64 and uint32 values
+ // When this field is present, the data_type field MUST be
+ // UINT32 or UINT64
+ repeated uint64 uint64_data = 11 [packed = true];
+}
+
+// A serialized sparse-tensor value
+message SparseTensorProto {
+ // The sequence of non-default values is encoded as a tensor of shape [NNZ].
+ // The default-value is zero for numeric tensors, and empty-string for string tensors.
+ // values must have a non-empty name present which serves as a name for SparseTensorProto
+ // when used in sparse_initializer list.
+ TensorProto values = 1;
+
+ // The indices of the non-default values, which may be stored in one of two formats.
+ // (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value
+ // corresponding to the j-th index of the i-th value (in the values tensor).
+ // (b) Indices can be a tensor of shape [NNZ], in which case the i-th value
+ // must be the linearized-index of the i-th value (in the values tensor).
+ // The linearized-index can be converted into an index tuple (k_1,...,k_rank)
+ // using the shape provided below.
+ // The indices must appear in ascending order without duplication.
+ // In the first format, the ordering is lexicographic-ordering:
+ // e.g., index-value [1,4] must appear before [2,1]
+ TensorProto indices = 2;
+
+ // The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank]
+ repeated int64 dims = 3;
+}
+
+// Defines a tensor shape. A dimension can be either an integer value
+// or a symbolic variable. A symbolic variable represents an unknown
+// dimension.
+message TensorShapeProto {
+ message Dimension {
+ oneof value {
+ int64 dim_value = 1;
+ string dim_param = 2; // namespace Shape
+ };
+ // Standard denotation can optionally be used to denote tensor
+ // dimensions with standard semantic descriptions to ensure
+ // that operations are applied to the correct axis of a tensor.
+ // Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition
+ // for pre-defined dimension denotations.
+ string denotation = 3;
+ };
+ repeated Dimension dim = 1;
+}
+
+// Types
+//
+// The standard ONNX data types.
+message TypeProto {
+
+ message Tensor {
+ // This field MUST NOT have the value of UNDEFINED
+ // This field MUST have a valid TensorProto.DataType value
+ // This field MUST be present for this version of the IR.
+ int32 elem_type = 1;
+ TensorShapeProto shape = 2;
+ }
+
+ // repeated T
+ message Sequence {
+ // The type and optional shape of each element of the sequence.
+ // This field MUST be present for this version of the IR.
+ TypeProto elem_type = 1;
+ };
+
+ // map<K,V>
+ message Map {
+ // This field MUST have a valid TensorProto.DataType value
+ // This field MUST be present for this version of the IR.
+ // This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING
+ int32 key_type = 1;
+ // This field MUST be present for this version of the IR.
+ TypeProto value_type = 2;
+ };
+
+ // wrapper for Tensor, Sequence, or Map
+ message Optional {
+ // The type and optional shape of the element wrapped.
+ // This field MUST be present for this version of the IR.
+ // Possible values correspond to OptionalProto.DataType enum
+ TypeProto elem_type = 1;
+ };
+
+
+ message SparseTensor {
+ // This field MUST NOT have the value of UNDEFINED
+ // This field MUST have a valid TensorProto.DataType value
+ // This field MUST be present for this version of the IR.
+ int32 elem_type = 1;
+ TensorShapeProto shape = 2;
+ }
+
+
+ oneof value {
+ // The type of a tensor.
+ Tensor tensor_type = 1;
+
+ // NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values
+ // as input and output to graphs and nodes. These types are needed to naturally
+ // support classical ML operators. DNN operators SHOULD restrict their input
+ // and output types to tensors.
+
+ // The type of a sequence.
+ Sequence sequence_type = 4;
+
+ // The type of a map.
+ Map map_type = 5;
+
+ // The type of an optional.
+ Optional optional_type = 9;
+
+
+ // Type of the sparse tensor
+ SparseTensor sparse_tensor_type = 8;
+
+ }
+
+ // An optional denotation can be used to denote the whole
+ // type with a standard semantic description as to what is
+ // stored inside. Refer to https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md#type-denotation-definition
+ // for pre-defined type denotations.
+ string denotation = 6;
+}
+
+// Operator Sets
+//
+// OperatorSets are uniquely identified by a (domain, opset_version) pair.
+message OperatorSetIdProto {
+ // The domain of the operator set being identified.
+ // The empty string ("") or absence of this field implies the operator
+ // set that is defined as part of the ONNX specification.
+ // This field MUST be present in this version of the IR when referring to any other operator set.
+ string domain = 1;
+
+ // The version of the operator set being identified.
+ // This field MUST be present in this version of the IR.
+ int64 version = 2;
+}
+
+// Operator/function status.
+enum OperatorStatus {
+ EXPERIMENTAL = 0;
+ STABLE = 1;
+}
+
+message FunctionProto {
+ // The name of the function, similar usage of op_type in OperatorProto.
+ // Combined with FunctionProto.domain, this forms the unique identity of
+ // the FunctionProto.
+ string name = 1;
+
+ // Deprecated since IR Version 8
+ // optional int64 since_version = 2;
+ reserved 2;
+ reserved "since_version";
+
+ // Deprecated since IR Version 8
+ // optional OperatorStatus status = 3;
+ reserved 3;
+ reserved "status";
+
+ // The inputs and outputs of the function.
+ repeated string input = 4;
+ repeated string output = 5;
+
+ // The attribute parameters of the function.
+ // It is for function parameters without default values.
+ repeated string attribute = 6;
+
+ // The attribute protos of the function.
+ // It is for function attributes with default values.
+ // A function attribute shall be represented either as
+ // a string attribute or an AttributeProto, not both.
+ repeated AttributeProto attribute_proto = 11;
+
+ // The nodes in the function.
+ repeated NodeProto node = 7;
+ // A human-readable documentation for this function. Markdown is allowed.
+ string doc_string = 8;
+
+ // The OperatorSets this function body (graph) relies on.
+ //
+ // All nodes in the function body (graph) will bind against the operator
+ // with the same-domain/same-op_type operator with the HIGHEST version
+ // in the referenced operator sets. This means at most one version can be relied
+ // upon for one domain.
+ //
+ // The operator sets imported by FunctionProto should be compatible with the ones
+ // imported by ModelProto. For example, if the same operator set, say 'A', is imported by a FunctionProto
+ // and the ModelProto, the versions of the operator set may differ, but
+ // the operator schema returned for the (op_type, domain, version) combination
+ // must be the same for both versions.
+
+ repeated OperatorSetIdProto opset_import = 9;
+
+ // The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of
+ // the FunctionProto.
+ string domain = 10;
+}
+
+// For using protobuf-lite
+option optimize_for = LITE_RUNTIME;
+
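The TensorProto encoding rules above (type-specific *_data fields versus little-endian raw_data) are not yet consumed by this commit's evaluator, which has no initializer support. A hypothetical conversion for FLOAT tensors could look like the sketch below; tensor_from_proto is an illustrative name, not an API introduced here:

use candle::{Device, Result, Tensor};
use candle_onnx::onnx;

fn tensor_from_proto(t: &onnx::TensorProto) -> Result<Tensor> {
    // DataType::FLOAT == 1 in the enum above; this sketch handles nothing else.
    if t.data_type != 1 {
        candle::bail!("only FLOAT is handled in this sketch, got {}", t.data_type)
    }
    let dims: Vec<usize> = t.dims.iter().map(|&d| d as usize).collect();
    let data: Vec<f32> = if !t.float_data.is_empty() {
        // Type-specific storage: elements are already f32.
        t.float_data.clone()
    } else {
        // raw_data stores fixed-width little-endian IEEE 754 values.
        t.raw_data
            .chunks_exact(4)
            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
            .collect()
    };
    Tensor::from_vec(data, dims, &Device::Cpu)
}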
diff --git a/test.onnx b/test.onnx
new file mode 100644
index 00000000..b69364d1
--- /dev/null
+++ b/test.onnx
[binary ONNX payload: a minimal "backend-test" model whose graph "SingleRelu" contains a single Relu node named "test" mapping input "x" to output "y"; raw bytes omitted]