diff options
Diffstat (limited to 'candle-pyo3/src/lib.rs')
-rw-r--r-- | candle-pyo3/src/lib.rs | 22 |
1 files changed, 19 insertions, 3 deletions
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index ddd58fbe..05a786ef 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -19,12 +19,14 @@ extern crate accelerate_src; use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType}; +mod utils; +use utils::wrap_err; + mod shape; use shape::{PyShape, PyShapeWithHole}; -pub fn wrap_err(err: ::candle::Error) -> PyErr { - PyErr::new::<PyValueError, _>(format!("{err:?}")) -} +#[cfg(feature = "onnx")] +mod onnx; #[derive(Clone, Debug)] #[pyclass(name = "Tensor")] @@ -1559,6 +1561,14 @@ fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> { Ok(()) } +#[cfg(feature = "onnx")] +fn candle_onnx_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> { + use onnx::{PyONNXModel, PyONNXTensorDescriptor}; + m.add_class::<PyONNXModel>()?; + m.add_class::<PyONNXTensorDescriptor>()?; + Ok(()) +} + #[pymodule] fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> { let utils = PyModule::new(py, "utils")?; @@ -1567,6 +1577,12 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> { let nn = PyModule::new(py, "functional")?; candle_functional_m(py, nn)?; m.add_submodule(nn)?; + #[cfg(feature = "onnx")] + { + let onnx = PyModule::new(py, "onnx")?; + candle_onnx_m(py, onnx)?; + m.add_submodule(onnx)?; + } m.add_class::<PyTensor>()?; m.add_class::<PyQTensor>()?; m.add_class::<PyDType>()?; |