path: root/candle-examples/examples/onnx_basics.rs
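//! Basic ONNX tooling built on candle-onnx: print a model's graph, or run it on
//! zero-filled dummy inputs and print the resulting outputs.
//!
//! Usage sketch (subcommand names follow clap's default kebab-casing; depending on
//! how the workspace is configured, an `onnx` cargo feature may also need to be
//! enabled for this example):
//!
//!   cargo run --example onnx_basics -- print --file model.onnx
//!   cargo run --example onnx_basics -- simple-eval --file model.onnx
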
use anyhow::Result;
use candle::{Device, Tensor};

use clap::{Parser, Subcommand};

#[derive(Subcommand, Debug, Clone)]
enum Command {
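    /// Parse an ONNX file and print the full model followed by each graph node.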
    Print {
        #[arg(long)]
        file: String,
    },
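    /// Parse an ONNX file and evaluate its graph on zero-filled dummy inputs.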
    SimpleEval {
        #[arg(long)]
        file: String,
    },
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
    #[command(subcommand)]
    command: Command,
}

pub fn main() -> Result<()> {
    let args = Args::parse();
    match args.command {
        Command::Print { file } => {
            let model = candle_onnx::read_file(file)?;
            println!("{model:?}");
            let graph = model.graph.unwrap();
            for node in graph.node.iter() {
                println!("{node:?}");
            }
        }
        Command::SimpleEval { file } => {
            let model = candle_onnx::read_file(file)?;
            let graph = model.graph.as_ref().unwrap();
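            // Initializers are constants baked into the graph, so any input that also
            // appears as an initializer does not need to be provided by the caller.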
            let constants: std::collections::HashSet<_> =
                graph.initializer.iter().map(|i| i.name.as_str()).collect();
            let mut inputs = std::collections::HashMap::new();
            for input in graph.input.iter() {
                use candle_onnx::onnx::tensor_proto::DataType;
                if constants.contains(input.name.as_str()) {
                    continue;
                }

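                // Read the declared tensor type and shape so a dummy input with the
                // right dtype and dimensions can be constructed.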
                let type_ = input.r#type.as_ref().expect("no type for input");
                let type_ = type_.value.as_ref().expect("no type.value for input");
                let value = match type_ {
                    candle_onnx::onnx::type_proto::Value::TensorType(tt) => {
                        // Map the ONNX element type onto a candle DType, bailing out on
                        // anything candle-onnx does not support.
                        let dt = match DataType::try_from(tt.elem_type) {
                            Ok(dt) => match candle_onnx::dtype(dt) {
                                Some(dt) => dt,
                                None => {
                                    anyhow::bail!(
                                        "unsupported 'value' data-type {dt:?} for {}",
                                        input.name
                                    )
                                }
                            },
                            Err(err) => anyhow::bail!(
                                "unknown element data-type {} for {}: {err:?}",
                                tt.elem_type,
                                input.name
                            ),
                        };
                        let shape = tt.shape.as_ref().expect("no tensortype.shape for input");
                        // A DimParam is a symbolic (named) dimension with no concrete
                        // size, so an arbitrary placeholder of 42 is used instead.
                        let dims = shape
                            .dim
                            .iter()
                            .map(|dim| match dim.value.as_ref().expect("no dim value") {
                                candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize),
                                candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimParam(_) => Ok(42),
                            })
                            .collect::<Result<Vec<usize>>>()?;
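                        // Feed an all-zeros tensor on the CPU as stand-in input data.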
                        Tensor::zeros(dims, dt, &Device::Cpu)?
                    }
                    type_ => anyhow::bail!("unsupported input type {type_:?}"),
                };
                println!("input {}: {value:?}", input.name);
                inputs.insert(input.name.clone(), value);
            }
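            // Run the graph on the dummy inputs and print every output tensor.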
            let outputs = candle_onnx::simple_eval(&model, inputs)?;
            for (name, value) in outputs.iter() {
                println!("output {name}: {value:?}")
            }
        }
    }
    Ok(())
}