summaryrefslogtreecommitdiff
path: root/candle-pyo3/src/onnx.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-pyo3/src/onnx.rs')
-rw-r--r--candle-pyo3/src/onnx.rs212
1 files changed, 212 insertions, 0 deletions
diff --git a/candle-pyo3/src/onnx.rs b/candle-pyo3/src/onnx.rs
new file mode 100644
index 00000000..b9a0eb22
--- /dev/null
+++ b/candle-pyo3/src/onnx.rs
@@ -0,0 +1,212 @@
+use std::collections::HashMap;
+
+use crate::utils::wrap_err;
+use crate::{PyDType, PyTensor};
+use candle_onnx::eval::{dtype, get_tensor, simple_eval};
+use candle_onnx::onnx::tensor_proto::DataType;
+use candle_onnx::onnx::tensor_shape_proto::dimension::Value;
+use candle_onnx::onnx::type_proto::{Tensor as ONNXTensor, Value as ONNXValue};
+use candle_onnx::onnx::{ModelProto, ValueInfoProto};
+use pyo3::exceptions::PyValueError;
+use pyo3::prelude::*;
+use pyo3::types::{PyList, PyTuple};
+
+#[derive(Clone, Debug)]
+#[pyclass(name = "ONNXTensorDescription")]
+/// A wrapper around an ONNX tensor description.
+pub struct PyONNXTensorDescriptor(ONNXTensor);
+
+#[pymethods]
+impl PyONNXTensorDescriptor {
+ #[getter]
+ /// The data type of the tensor.
+ /// &RETURNS&: DType
+ fn dtype(&self) -> PyResult<PyDType> {
+ match DataType::try_from(self.0.elem_type) {
+ Ok(dt) => match dtype(dt) {
+ Some(dt) => Ok(PyDType(dt)),
+ None => Err(PyValueError::new_err(format!(
+ "unsupported 'value' data-type {dt:?}"
+ ))),
+ },
+ type_ => Err(PyValueError::new_err(format!(
+ "unsupported input type {type_:?}"
+ ))),
+ }
+ }
+
+ #[getter]
+ /// The shape of the tensor.
+ /// &RETURNS&: Tuple[Union[int,str,Any]]
+ fn shape(&self, py: Python) -> PyResult<Py<PyTuple>> {
+ let shape = PyList::empty(py);
+ if let Some(d) = &self.0.shape {
+ for dim in d.dim.iter() {
+ if let Some(value) = &dim.value {
+ match value {
+ Value::DimValue(v) => shape.append(*v)?,
+ Value::DimParam(s) => shape.append(s.clone())?,
+ };
+ } else {
+ return Err(PyValueError::new_err("None value in shape"));
+ }
+ }
+ }
+ Ok(shape.to_tuple().into())
+ }
+
+ fn __repr__(&self, py: Python) -> String {
+ match (self.shape(py), self.dtype()) {
+ (Ok(shape), Ok(dtype)) => format!(
+ "TensorDescriptor[shape: {:?}, dtype: {:?}]",
+ shape.to_string(),
+ dtype.__str__()
+ ),
+ (Err(_), Err(_)) => "TensorDescriptor[shape: unknown, dtype: unknown]".to_string(),
+ (Err(_), Ok(dtype)) => format!(
+ "TensorDescriptor[shape: unknown, dtype: {:?}]",
+ dtype.__str__()
+ ),
+ (Ok(shape), Err(_)) => format!(
+ "TensorDescriptor[shape: {:?}, dtype: unknown]",
+ shape.to_string()
+ ),
+ }
+ }
+
+ fn __str__(&self, py: Python) -> String {
+ self.__repr__(py)
+ }
+}
+
+#[derive(Clone, Debug)]
+#[pyclass(name = "ONNXModel")]
+/// A wrapper around an ONNX model.
+pub struct PyONNXModel(ModelProto);
+
+fn extract_tensor_descriptions(
+ value_infos: &[ValueInfoProto],
+) -> HashMap<String, PyONNXTensorDescriptor> {
+ let mut map = HashMap::new();
+ for value_info in value_infos.iter() {
+ let input_type = match &value_info.r#type {
+ Some(input_type) => input_type,
+ None => continue,
+ };
+ let input_type = match &input_type.value {
+ Some(input_type) => input_type,
+ None => continue,
+ };
+
+ let tensor_type: &ONNXTensor = match input_type {
+ ONNXValue::TensorType(tt) => tt,
+ _ => continue,
+ };
+ map.insert(
+ value_info.name.to_string(),
+ PyONNXTensorDescriptor(tensor_type.clone()),
+ );
+ }
+ map
+}
+
+#[pymethods]
+impl PyONNXModel {
+ #[new]
+ #[pyo3(text_signature = "(self, path:str)")]
+ /// Load an ONNX model from the given path.
+ fn new(path: String) -> PyResult<Self> {
+ let model: ModelProto = candle_onnx::read_file(path).map_err(wrap_err)?;
+ Ok(PyONNXModel(model))
+ }
+
+ #[getter]
+ /// The version of the IR this model targets.
+ /// &RETURNS&: int
+ fn ir_version(&self) -> i64 {
+ self.0.ir_version
+ }
+
+ #[getter]
+ /// The producer of the model.
+ /// &RETURNS&: str
+ fn producer_name(&self) -> String {
+ self.0.producer_name.clone()
+ }
+
+ #[getter]
+ /// The version of the producer of the model.
+ /// &RETURNS&: str
+ fn producer_version(&self) -> String {
+ self.0.producer_version.clone()
+ }
+
+ #[getter]
+ /// The domain of the operator set of the model.
+ /// &RETURNS&: str
+ fn domain(&self) -> String {
+ self.0.domain.clone()
+ }
+
+ #[getter]
+ /// The version of the model.
+ /// &RETURNS&: int
+ fn model_version(&self) -> i64 {
+ self.0.model_version
+ }
+
+ #[getter]
+ /// The doc string of the model.
+ /// &RETURNS&: str
+ fn doc_string(&self) -> String {
+ self.0.doc_string.clone()
+ }
+
+ /// Get the weights of the model.
+ /// &RETURNS&: Dict[str, Tensor]
+ fn initializers(&self) -> PyResult<HashMap<String, PyTensor>> {
+ let mut map = HashMap::new();
+ if let Some(graph) = self.0.graph.as_ref() {
+ for tensor_description in graph.initializer.iter() {
+ let tensor = get_tensor(tensor_description, tensor_description.name.as_str())
+ .map_err(wrap_err)?;
+ map.insert(tensor_description.name.to_string(), PyTensor(tensor));
+ }
+ }
+ Ok(map)
+ }
+
+ #[getter]
+ /// The inputs of the model.
+ /// &RETURNS&: Optional[Dict[str, ONNXTensorDescription]]
+ fn inputs(&self) -> Option<HashMap<String, PyONNXTensorDescriptor>> {
+ if let Some(graph) = self.0.graph.as_ref() {
+ return Some(extract_tensor_descriptions(&graph.input));
+ }
+ None
+ }
+
+ #[getter]
+ /// The outputs of the model.
+ /// &RETURNS&: Optional[Dict[str, ONNXTensorDescription]]
+ fn outputs(&self) -> Option<HashMap<String, PyONNXTensorDescriptor>> {
+ if let Some(graph) = self.0.graph.as_ref() {
+ return Some(extract_tensor_descriptions(&graph.output));
+ }
+ None
+ }
+
+ #[pyo3(text_signature = "(self, inputs:Dict[str,Tensor])")]
+ /// Run the model on the given inputs.
+ /// &RETURNS&: Dict[str,Tensor]
+ fn run(&self, inputs: HashMap<String, PyTensor>) -> PyResult<HashMap<String, PyTensor>> {
+ let unwrapped_tensors = inputs.into_iter().map(|(k, v)| (k.clone(), v.0)).collect();
+
+ let result = simple_eval(&self.0, unwrapped_tensors).map_err(wrap_err)?;
+
+ Ok(result
+ .into_iter()
+ .map(|(k, v)| (k.clone(), PyTensor(v)))
+ .collect())
+ }
+}