1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
|
use std::collections::HashMap;
use crate::utils::wrap_err;
use crate::{PyDType, PyTensor};
use candle_onnx::eval::{dtype, get_tensor, simple_eval};
use candle_onnx::onnx::tensor_proto::DataType;
use candle_onnx::onnx::tensor_shape_proto::dimension::Value;
use candle_onnx::onnx::type_proto::{Tensor as ONNXTensor, Value as ONNXValue};
use candle_onnx::onnx::{ModelProto, ValueInfoProto};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::{PyList, PyTuple};
// Newtype over the ONNX tensor type proto (`type_proto::Tensor`). The accessor
// methods on this type read its `elem_type` and `shape` fields.
// NOTE(review): pyo3 surfaces the `///` text below as the Python class docstring,
// so rewording it is a user-visible change.
#[derive(Clone, Debug)]
#[pyclass(name = "ONNXTensorDescription")]
/// A wrapper around an ONNX tensor description.
pub struct PyONNXTensorDescriptor(ONNXTensor);
#[pymethods]
impl PyONNXTensorDescriptor {
#[getter]
/// The data type of the tensor.
/// &RETURNS&: DType
fn dtype(&self) -> PyResult<PyDType> {
match DataType::try_from(self.0.elem_type) {
Ok(dt) => match dtype(dt) {
Some(dt) => Ok(PyDType(dt)),
None => Err(PyValueError::new_err(format!(
"unsupported 'value' data-type {dt:?}"
))),
},
type_ => Err(PyValueError::new_err(format!(
"unsupported input type {type_:?}"
))),
}
}
#[getter]
/// The shape of the tensor.
/// &RETURNS&: Tuple[Union[int,str,Any]]
fn shape(&self, py: Python) -> PyResult<Py<PyTuple>> {
let shape = PyList::empty(py);
if let Some(d) = &self.0.shape {
for dim in d.dim.iter() {
if let Some(value) = &dim.value {
match value {
Value::DimValue(v) => shape.append(*v)?,
Value::DimParam(s) => shape.append(s.clone())?,
};
} else {
return Err(PyValueError::new_err("None value in shape"));
}
}
}
Ok(shape.to_tuple().into())
}
fn __repr__(&self, py: Python) -> String {
match (self.shape(py), self.dtype()) {
(Ok(shape), Ok(dtype)) => format!(
"TensorDescriptor[shape: {:?}, dtype: {:?}]",
shape.to_string(),
dtype.__str__()
),
(Err(_), Err(_)) => "TensorDescriptor[shape: unknown, dtype: unknown]".to_string(),
(Err(_), Ok(dtype)) => format!(
"TensorDescriptor[shape: unknown, dtype: {:?}]",
dtype.__str__()
),
(Ok(shape), Err(_)) => format!(
"TensorDescriptor[shape: {:?}, dtype: unknown]",
shape.to_string()
),
}
}
fn __str__(&self, py: Python) -> String {
self.__repr__(py)
}
}
// Newtype over the `ModelProto` from `candle_onnx`, exposed to Python under the
// name `ONNXModel`.
// NOTE(review): pyo3 surfaces the `///` text below as the Python class docstring,
// so rewording it is a user-visible change.
#[derive(Clone, Debug)]
#[pyclass(name = "ONNXModel")]
/// A wrapper around an ONNX model.
pub struct PyONNXModel(ModelProto);
/// Builds a map from value name to tensor descriptor for the given value infos.
///
/// Entries are silently skipped when they carry no type, when the type has no
/// value, or when the value is anything other than a tensor type. If a name
/// occurs more than once, the last occurrence wins.
fn extract_tensor_descriptions(
    value_infos: &[ValueInfoProto],
) -> HashMap<String, PyONNXTensorDescriptor> {
    value_infos
        .iter()
        .filter_map(|info| {
            // `?` on the `Option`s drops entries that have no type information.
            let declared = info.r#type.as_ref()?;
            match declared.value.as_ref()? {
                ONNXValue::TensorType(tensor) => Some((
                    info.name.to_string(),
                    PyONNXTensorDescriptor(tensor.clone()),
                )),
                // Non-tensor values are not described.
                _ => None,
            }
        })
        .collect()
}
#[pymethods]
impl PyONNXModel {
#[new]
#[pyo3(text_signature = "(self, path:str)")]
/// Load an ONNX model from the given path.
fn new(path: String) -> PyResult<Self> {
let model: ModelProto = candle_onnx::read_file(path).map_err(wrap_err)?;
Ok(PyONNXModel(model))
}
#[getter]
/// The version of the IR this model targets.
/// &RETURNS&: int
fn ir_version(&self) -> i64 {
self.0.ir_version
}
#[getter]
/// The producer of the model.
/// &RETURNS&: str
fn producer_name(&self) -> String {
self.0.producer_name.clone()
}
#[getter]
/// The version of the producer of the model.
/// &RETURNS&: str
fn producer_version(&self) -> String {
self.0.producer_version.clone()
}
#[getter]
/// The domain of the operator set of the model.
/// &RETURNS&: str
fn domain(&self) -> String {
self.0.domain.clone()
}
#[getter]
/// The version of the model.
/// &RETURNS&: int
fn model_version(&self) -> i64 {
self.0.model_version
}
#[getter]
/// The doc string of the model.
/// &RETURNS&: str
fn doc_string(&self) -> String {
self.0.doc_string.clone()
}
/// Get the weights of the model.
/// &RETURNS&: Dict[str, Tensor]
fn initializers(&self) -> PyResult<HashMap<String, PyTensor>> {
let mut map = HashMap::new();
if let Some(graph) = self.0.graph.as_ref() {
for tensor_description in graph.initializer.iter() {
let tensor = get_tensor(tensor_description, tensor_description.name.as_str())
.map_err(wrap_err)?;
map.insert(tensor_description.name.to_string(), PyTensor(tensor));
}
}
Ok(map)
}
#[getter]
/// The inputs of the model.
/// &RETURNS&: Optional[Dict[str, ONNXTensorDescription]]
fn inputs(&self) -> Option<HashMap<String, PyONNXTensorDescriptor>> {
if let Some(graph) = self.0.graph.as_ref() {
return Some(extract_tensor_descriptions(&graph.input));
}
None
}
#[getter]
/// The outputs of the model.
/// &RETURNS&: Optional[Dict[str, ONNXTensorDescription]]
fn outputs(&self) -> Option<HashMap<String, PyONNXTensorDescriptor>> {
if let Some(graph) = self.0.graph.as_ref() {
return Some(extract_tensor_descriptions(&graph.output));
}
None
}
#[pyo3(text_signature = "(self, inputs:Dict[str,Tensor])")]
/// Run the model on the given inputs.
/// &RETURNS&: Dict[str,Tensor]
fn run(&self, inputs: HashMap<String, PyTensor>) -> PyResult<HashMap<String, PyTensor>> {
let unwrapped_tensors = inputs.into_iter().map(|(k, v)| (k.clone(), v.0)).collect();
let result = simple_eval(&self.0, unwrapped_tensors).map_err(wrap_err)?;
Ok(result
.into_iter()
.map(|(k, v)| (k.clone(), PyTensor(v)))
.collect())
}
}
|