summaryrefslogtreecommitdiff
path: root/candle-onnx/src
diff options
context:
space:
mode:
authorLukas Kreussel <65088241+LLukas22@users.noreply.github.com>2023-11-08 06:37:50 +0100
committerGitHub <noreply@github.com>2023-11-08 06:37:50 +0100
commitf3a4f3db768d46defc16de48208107db1b32159d (patch)
tree21ae0872e46621656559ec0caf6d7625e6ca7e76 /candle-onnx/src
parent7920b45c8ac737b67e23f04297f6bd7e4860f373 (diff)
downloadcandle-f3a4f3db768d46defc16de48208107db1b32159d.tar.gz
candle-f3a4f3db768d46defc16de48208107db1b32159d.tar.bz2
candle-f3a4f3db768d46defc16de48208107db1b32159d.zip
PyO3: Add optional `candle.onnx` module (#1282)
* Start onnx integration * Merge remote-tracking branch 'upstream/main' into feat/pyo3-onnx * Implement ONNXModel * `fmt` * add `onnx` flag to python ci * Pin `protoc` to `25.0` * Setup `protoc` in wheel builds * Build wheels with `onnx` * Install `protoc` in manylinux containers * `apt` -> `yum` * Download `protoc` via bash script * Back to `manylinux: auto` * Disable `onnx` builds for linux
Diffstat (limited to 'candle-onnx/src')
-rw-r--r--candle-onnx/src/eval.rs2
-rw-r--r--candle-onnx/src/lib.rs2
2 files changed, 2 insertions, 2 deletions
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index 51e2aa0c..b7e325e1 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -98,7 +98,7 @@ fn get_attr_opt<'a, T: Attr + ?Sized>(
}
}
-fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
+pub fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
match DataType::try_from(t.data_type) {
Ok(DataType::Int32) => {
diff --git a/candle-onnx/src/lib.rs b/candle-onnx/src/lib.rs
index 1002a2c8..efd6f760 100644
--- a/candle-onnx/src/lib.rs
+++ b/candle-onnx/src/lib.rs
@@ -5,7 +5,7 @@ pub mod onnx {
include!(concat!(env!("OUT_DIR"), "/onnx.rs"));
}
-mod eval;
+pub mod eval;
pub use eval::{dtype, simple_eval};
pub fn read_file<P: AsRef<std::path::Path>>(p: P) -> Result<onnx::ModelProto> {