summaryrefslogtreecommitdiff
path: root/candle-onnx/examples/onnx_basics.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-onnx/examples/onnx_basics.rs')
-rw-r--r--candle-onnx/examples/onnx_basics.rs56
1 files changed, 56 insertions, 0 deletions
diff --git a/candle-onnx/examples/onnx_basics.rs b/candle-onnx/examples/onnx_basics.rs
new file mode 100644
index 00000000..b91cbee6
--- /dev/null
+++ b/candle-onnx/examples/onnx_basics.rs
@@ -0,0 +1,56 @@
+use anyhow::Result;
+use candle::{Device, Tensor};
+
+use clap::{Parser, Subcommand};
+
+#[derive(Subcommand, Debug, Clone)]
+enum Command {
+ Print {
+ #[arg(long)]
+ file: String,
+ },
+ SimpleEval {
+ #[arg(long)]
+ file: String,
+ },
+}
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+pub struct Args {
+ #[command(subcommand)]
+ command: Command,
+}
+
+pub fn main() -> Result<()> {
+ let args = Args::parse();
+ match args.command {
+ Command::Print { file } => {
+ let model = candle_onnx::read_file(file)?;
+ println!("{model:?}");
+ let graph = model.graph.unwrap();
+ for node in graph.node.iter() {
+ println!("{node:?}");
+ }
+ }
+ Command::SimpleEval { file } => {
+ let model = candle_onnx::read_file(file)?;
+ let inputs = model
+ .graph
+ .as_ref()
+ .unwrap()
+ .input
+ .iter()
+ .map(|name| {
+ let value = Tensor::new(&[-3.2, 2.7], &Device::Cpu)?;
+ Ok((name.name.clone(), value))
+ })
+ .collect::<Result<_>>()?;
+ let outputs = candle_onnx::simple_eval(&model, inputs)?;
+ for (name, value) in outputs.iter() {
+ println!("{name}: {value:?}")
+ }
+ }
+ }
+ Ok(())
+}