summaryrefslogtreecommitdiff
path: root/candle-pyo3/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-pyo3/src/lib.rs')
-rw-r--r--candle-pyo3/src/lib.rs22
1 files changed, 19 insertions, 3 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index ddd58fbe..05a786ef 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -19,12 +19,14 @@ extern crate accelerate_src;
use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
+mod utils;
+use utils::wrap_err;
+
mod shape;
use shape::{PyShape, PyShapeWithHole};
-pub fn wrap_err(err: ::candle::Error) -> PyErr {
- PyErr::new::<PyValueError, _>(format!("{err:?}"))
-}
+#[cfg(feature = "onnx")]
+mod onnx;
#[derive(Clone, Debug)]
#[pyclass(name = "Tensor")]
@@ -1559,6 +1561,14 @@ fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
Ok(())
}
+#[cfg(feature = "onnx")]
+fn candle_onnx_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
+ use onnx::{PyONNXModel, PyONNXTensorDescriptor};
+ m.add_class::<PyONNXModel>()?;
+ m.add_class::<PyONNXTensorDescriptor>()?;
+ Ok(())
+}
+
#[pymodule]
fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
let utils = PyModule::new(py, "utils")?;
@@ -1567,6 +1577,12 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
let nn = PyModule::new(py, "functional")?;
candle_functional_m(py, nn)?;
m.add_submodule(nn)?;
+ #[cfg(feature = "onnx")]
+ {
+ let onnx = PyModule::new(py, "onnx")?;
+ candle_onnx_m(py, onnx)?;
+ m.add_submodule(onnx)?;
+ }
m.add_class::<PyTensor>()?;
m.add_class::<PyQTensor>()?;
m.add_class::<PyDType>()?;