From f3a4f3db768d46defc16de48208107db1b32159d Mon Sep 17 00:00:00 2001
From: Lukas Kreussel <65088241+LLukas22@users.noreply.github.com>
Date: Wed, 8 Nov 2023 06:37:50 +0100
Subject: PyO3: Add optional `candle.onnx` module (#1282)

* Start onnx integration

* Merge remote-tracking branch 'upstream/main' into feat/pyo3-onnx

* Implement ONNXModel

* `fmt`

* add `onnx` flag to python ci

* Pin `protoc` to `25.0`

* Setup `protoc` in wheel builds

* Build wheels with `onnx`

* Install `protoc` in manylinux containers

* `apt` -> `yum`

* Download `protoc` via bash script

* Back to `manylinux: auto`

* Disable `onnx` builds for linux
---
 candle-pyo3/Cargo.toml                      |   3 +
 candle-pyo3/py_src/candle/onnx/__init__.py  |   5 +
 candle-pyo3/py_src/candle/onnx/__init__.pyi |  89 ++++++++++++
 candle-pyo3/src/lib.rs                      |  22 ++-
 candle-pyo3/src/onnx.rs                     | 212 ++++++++++++++++++++++++++++
 candle-pyo3/src/utils.rs                    |   6 +
 6 files changed, 334 insertions(+), 3 deletions(-)
 create mode 100644 candle-pyo3/py_src/candle/onnx/__init__.py
 create mode 100644 candle-pyo3/py_src/candle/onnx/__init__.pyi
 create mode 100644 candle-pyo3/src/onnx.rs
 create mode 100644 candle-pyo3/src/utils.rs

(limited to 'candle-pyo3')

diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml
index b0452404..f79277f2 100644
--- a/candle-pyo3/Cargo.toml
+++ b/candle-pyo3/Cargo.toml
@@ -17,6 +17,7 @@ crate-type = ["cdylib"]
 accelerate-src = { workspace = true, optional = true }
 candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
 candle-nn = { path = "../candle-nn", version = "0.3.0" }
+candle-onnx = {path= "../candle-onnx", version = "0.3.0", optional = true}
 half = { workspace = true }
 intel-mkl-src = { workspace = true, optional = true }
 pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }
@@ -29,3 +30,5 @@ default = []
 accelerate = ["dep:accelerate-src", "candle/accelerate"]
 cuda = ["candle/cuda"]
 mkl = ["dep:intel-mkl-src","candle/mkl"]
+onnx = ["dep:candle-onnx"]
+
diff --git a/candle-pyo3/py_src/candle/onnx/__init__.py b/candle-pyo3/py_src/candle/onnx/__init__.py
new file mode 100644
index 00000000..856ecd7d
--- /dev/null
+++ b/candle-pyo3/py_src/candle/onnx/__init__.py
@@ -0,0 +1,5 @@
+# Generated content DO NOT EDIT
+from .. import onnx
+
+ONNXModel = onnx.ONNXModel
+ONNXTensorDescription = onnx.ONNXTensorDescription
diff --git a/candle-pyo3/py_src/candle/onnx/__init__.pyi b/candle-pyo3/py_src/candle/onnx/__init__.pyi
new file mode 100644
index 00000000..8ce1b3aa
--- /dev/null
+++ b/candle-pyo3/py_src/candle/onnx/__init__.pyi
@@ -0,0 +1,89 @@
+# Generated content DO NOT EDIT
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
+from os import PathLike
+from candle.typing import _ArrayLike, Device, Scalar, Index, Shape
+from candle import Tensor, DType, QTensor
+
+class ONNXModel:
+    """
+    A wrapper around an ONNX model.
+    """
+
+    def __init__(self, path: str):
+        pass
+    @property
+    def doc_string(self) -> str:
+        """
+        The doc string of the model.
+        """
+        pass
+    @property
+    def domain(self) -> str:
+        """
+        The domain of the operator set of the model.
+        """
+        pass
+    def initializers(self) -> Dict[str, Tensor]:
+        """
+        Get the weights of the model.
+        """
+        pass
+    @property
+    def inputs(self) -> Optional[Dict[str, ONNXTensorDescription]]:
+        """
+        The inputs of the model.
+        """
+        pass
+    @property
+    def ir_version(self) -> int:
+        """
+        The version of the IR this model targets.
+        """
+        pass
+    @property
+    def model_version(self) -> int:
+        """
+        The version of the model.
+        """
+        pass
+    @property
+    def outputs(self) -> Optional[Dict[str, ONNXTensorDescription]]:
+        """
+        The outputs of the model.
+        """
+        pass
+    @property
+    def producer_name(self) -> str:
+        """
+        The producer of the model.
+        """
+        pass
+    @property
+    def producer_version(self) -> str:
+        """
+        The version of the producer of the model.
+        """
+        pass
+    def run(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        """
+        Run the model on the given inputs.
+        """
+        pass
+
+class ONNXTensorDescription:
+    """
+    A wrapper around an ONNX tensor description.
+    """
+
+    @property
+    def dtype(self) -> DType:
+        """
+        The data type of the tensor.
+        """
+        pass
+    @property
+    def shape(self) -> Tuple[Union[int, str, Any]]:
+        """
+        The shape of the tensor.
+        """
+        pass
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index ddd58fbe..05a786ef 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -19,12 +19,14 @@ extern crate accelerate_src;
 
 use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
 
+mod utils;
+use utils::wrap_err;
+
 mod shape;
 use shape::{PyShape, PyShapeWithHole};
 
-pub fn wrap_err(err: ::candle::Error) -> PyErr {
-    PyErr::new::<PyValueError, _>(format!("{err:?}"))
-}
+#[cfg(feature = "onnx")]
+mod onnx;
 
 #[derive(Clone, Debug)]
 #[pyclass(name = "Tensor")]
@@ -1559,6 +1561,14 @@ fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
     Ok(())
 }
 
+#[cfg(feature = "onnx")]
+fn candle_onnx_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
+    use onnx::{PyONNXModel, PyONNXTensorDescriptor};
+    m.add_class::<PyONNXModel>()?;
+    m.add_class::<PyONNXTensorDescriptor>()?;
+    Ok(())
+}
+
 #[pymodule]
 fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
     let utils = PyModule::new(py, "utils")?;
@@ -1567,6 +1577,12 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
     let nn = PyModule::new(py, "functional")?;
     candle_functional_m(py, nn)?;
     m.add_submodule(nn)?;
+    #[cfg(feature = "onnx")]
+    {
+        let onnx = PyModule::new(py, "onnx")?;
+        candle_onnx_m(py, onnx)?;
+        m.add_submodule(onnx)?;
+    }
     m.add_class::<PyTensor>()?;
     m.add_class::<PyQTensor>()?;
     m.add_class::<PyDType>()?;
diff --git a/candle-pyo3/src/onnx.rs b/candle-pyo3/src/onnx.rs
new file mode 100644
index 00000000..b9a0eb22
--- /dev/null
+++ b/candle-pyo3/src/onnx.rs
@@ -0,0 +1,212 @@
+use std::collections::HashMap;
+
+use crate::utils::wrap_err;
+use crate::{PyDType, PyTensor};
+use candle_onnx::eval::{dtype, get_tensor, simple_eval};
+use candle_onnx::onnx::tensor_proto::DataType;
+use candle_onnx::onnx::tensor_shape_proto::dimension::Value;
+use candle_onnx::onnx::type_proto::{Tensor as ONNXTensor, Value as ONNXValue};
+use candle_onnx::onnx::{ModelProto, ValueInfoProto};
+use pyo3::exceptions::PyValueError;
+use pyo3::prelude::*;
+use pyo3::types::{PyList, PyTuple};
+
+#[derive(Clone, Debug)]
+#[pyclass(name = "ONNXTensorDescription")]
+/// A wrapper around an ONNX tensor description.
+pub struct PyONNXTensorDescriptor(ONNXTensor);
+
+#[pymethods]
+impl PyONNXTensorDescriptor {
+    #[getter]
+    /// The data type of the tensor.
+    /// &RETURNS&: DType
+    fn dtype(&self) -> PyResult<PyDType> {
+        match DataType::try_from(self.0.elem_type) {
+            Ok(dt) => match dtype(dt) {
+                Some(dt) => Ok(PyDType(dt)),
+                None => Err(PyValueError::new_err(format!(
+                    "unsupported 'value' data-type {dt:?}"
+                ))),
+            },
+            type_ => Err(PyValueError::new_err(format!(
+                "unsupported input type {type_:?}"
+            ))),
+        }
+    }
+
+    #[getter]
+    /// The shape of the tensor.
+    /// &RETURNS&: Tuple[Union[int,str,Any]]
+    fn shape(&self, py: Python) -> PyResult<Py<PyTuple>> {
+        let shape = PyList::empty(py);
+        if let Some(d) = &self.0.shape {
+            for dim in d.dim.iter() {
+                if let Some(value) = &dim.value {
+                    match value {
+                        Value::DimValue(v) => shape.append(*v)?,
+                        Value::DimParam(s) => shape.append(s.clone())?,
+                    };
+                } else {
+                    return Err(PyValueError::new_err("None value in shape"));
+                }
+            }
+        }
+        Ok(shape.to_tuple().into())
+    }
+
+    fn __repr__(&self, py: Python) -> String {
+        match (self.shape(py), self.dtype()) {
+            (Ok(shape), Ok(dtype)) => format!(
+                "TensorDescriptor[shape: {:?}, dtype: {:?}]",
+                shape.to_string(),
+                dtype.__str__()
+            ),
+            (Err(_), Err(_)) => "TensorDescriptor[shape: unknown, dtype: unknown]".to_string(),
+            (Err(_), Ok(dtype)) => format!(
+                "TensorDescriptor[shape: unknown, dtype: {:?}]",
+                dtype.__str__()
+            ),
+            (Ok(shape), Err(_)) => format!(
+                "TensorDescriptor[shape: {:?}, dtype: unknown]",
+                shape.to_string()
+            ),
+        }
+    }
+
+    fn __str__(&self, py: Python) -> String {
+        self.__repr__(py)
+    }
+}
+
+#[derive(Clone, Debug)]
+#[pyclass(name = "ONNXModel")]
+/// A wrapper around an ONNX model.
+pub struct PyONNXModel(ModelProto);
+
+fn extract_tensor_descriptions(
+    value_infos: &[ValueInfoProto],
+) -> HashMap<String, PyONNXTensorDescriptor> {
+    let mut map = HashMap::new();
+    for value_info in value_infos.iter() {
+        let input_type = match &value_info.r#type {
+            Some(input_type) => input_type,
+            None => continue,
+        };
+        let input_type = match &input_type.value {
+            Some(input_type) => input_type,
+            None => continue,
+        };
+
+        let tensor_type: &ONNXTensor = match input_type {
+            ONNXValue::TensorType(tt) => tt,
+            _ => continue,
+        };
+        map.insert(
+            value_info.name.to_string(),
+            PyONNXTensorDescriptor(tensor_type.clone()),
+        );
+    }
+    map
+}
+
+#[pymethods]
+impl PyONNXModel {
+    #[new]
+    #[pyo3(text_signature = "(self, path:str)")]
+    /// Load an ONNX model from the given path.
+    fn new(path: String) -> PyResult<Self> {
+        let model: ModelProto = candle_onnx::read_file(path).map_err(wrap_err)?;
+        Ok(PyONNXModel(model))
+    }
+
+    #[getter]
+    /// The version of the IR this model targets.
+    /// &RETURNS&: int
+    fn ir_version(&self) -> i64 {
+        self.0.ir_version
+    }
+
+    #[getter]
+    /// The producer of the model.  
+    /// &RETURNS&: str      
+    fn producer_name(&self) -> String {
+        self.0.producer_name.clone()
+    }
+
+    #[getter]
+    /// The version of the producer of the model.       
+    /// &RETURNS&: str
+    fn producer_version(&self) -> String {
+        self.0.producer_version.clone()
+    }
+
+    #[getter]
+    /// The domain of the operator set of the model.
+    /// &RETURNS&: str
+    fn domain(&self) -> String {
+        self.0.domain.clone()
+    }
+
+    #[getter]
+    /// The version of the model.
+    /// &RETURNS&: int
+    fn model_version(&self) -> i64 {
+        self.0.model_version
+    }
+
+    #[getter]
+    /// The doc string of the model.
+    /// &RETURNS&: str
+    fn doc_string(&self) -> String {
+        self.0.doc_string.clone()
+    }
+
+    /// Get the weights of the model.
+    /// &RETURNS&: Dict[str, Tensor]
+    fn initializers(&self) -> PyResult<HashMap<String, PyTensor>> {
+        let mut map = HashMap::new();
+        if let Some(graph) = self.0.graph.as_ref() {
+            for tensor_description in graph.initializer.iter() {
+                let tensor = get_tensor(tensor_description, tensor_description.name.as_str())
+                    .map_err(wrap_err)?;
+                map.insert(tensor_description.name.to_string(), PyTensor(tensor));
+            }
+        }
+        Ok(map)
+    }
+
+    #[getter]
+    /// The inputs of the model.
+    /// &RETURNS&: Optional[Dict[str, ONNXTensorDescription]]
+    fn inputs(&self) -> Option<HashMap<String, PyONNXTensorDescriptor>> {
+        if let Some(graph) = self.0.graph.as_ref() {
+            return Some(extract_tensor_descriptions(&graph.input));
+        }
+        None
+    }
+
+    #[getter]
+    /// The outputs of the model.
+    /// &RETURNS&: Optional[Dict[str, ONNXTensorDescription]]
+    fn outputs(&self) -> Option<HashMap<String, PyONNXTensorDescriptor>> {
+        if let Some(graph) = self.0.graph.as_ref() {
+            return Some(extract_tensor_descriptions(&graph.output));
+        }
+        None
+    }
+
+    #[pyo3(text_signature = "(self, inputs:Dict[str,Tensor])")]
+    /// Run the model on the given inputs.
+    /// &RETURNS&: Dict[str,Tensor]
+    fn run(&self, inputs: HashMap<String, PyTensor>) -> PyResult<HashMap<String, PyTensor>> {
+        let unwrapped_tensors = inputs.into_iter().map(|(k, v)| (k.clone(), v.0)).collect();
+
+        let result = simple_eval(&self.0, unwrapped_tensors).map_err(wrap_err)?;
+
+        Ok(result
+            .into_iter()
+            .map(|(k, v)| (k.clone(), PyTensor(v)))
+            .collect())
+    }
+}
diff --git a/candle-pyo3/src/utils.rs b/candle-pyo3/src/utils.rs
new file mode 100644
index 00000000..ad0a76a5
--- /dev/null
+++ b/candle-pyo3/src/utils.rs
@@ -0,0 +1,6 @@
+use pyo3::exceptions::PyValueError;
+use pyo3::prelude::*;
+
+pub fn wrap_err(err: ::candle::Error) -> PyErr {
+    PyErr::new::<PyValueError, _>(format!("{err:?}"))
+}
-- 
cgit v1.2.3