summaryrefslogtreecommitdiff
path: root/candle-onnx/examples/onnx_basics.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-11-04 10:02:47 +0100
committerGitHub <noreply@github.com>2023-11-04 10:02:47 +0100
commitbc9a1bf2399243c659b1e902b14e8572a12ec15b (patch)
tree3ed7f98114a0bc36d26a7bc99e144406728b1571 /candle-onnx/examples/onnx_basics.rs
parentf7c957d64f09ca6569ab6db265664fc192113972 (diff)
downloadcandle-bc9a1bf2399243c659b1e902b14e8572a12ec15b.tar.gz
candle-bc9a1bf2399243c659b1e902b14e8572a12ec15b.tar.bz2
candle-bc9a1bf2399243c659b1e902b14e8572a12ec15b.zip
Improve the ONNX basic example + bugfixes (#1266)
* Generate some zeros tensor in the onnx simple-eval example. * Fix the casting operation. * Support more ops. * Handle reshape. * Concat. * Softmax.
Diffstat (limited to 'candle-onnx/examples/onnx_basics.rs')
-rw-r--r--candle-onnx/examples/onnx_basics.rs36
1 files changed, 33 insertions, 3 deletions
diff --git a/candle-onnx/examples/onnx_basics.rs b/candle-onnx/examples/onnx_basics.rs
index b91cbee6..2c52e68e 100644
--- a/candle-onnx/examples/onnx_basics.rs
+++ b/candle-onnx/examples/onnx_basics.rs
@@ -41,9 +41,39 @@ pub fn main() -> Result<()> {
.unwrap()
.input
.iter()
- .map(|name| {
- let value = Tensor::new(&[-3.2, 2.7], &Device::Cpu)?;
- Ok((name.name.clone(), value))
+ .map(|input| {
+ use candle_onnx::onnx::tensor_proto::DataType;
+
+ let type_ = input.r#type.as_ref().expect("no type for input");
+ let type_ = type_.value.as_ref().expect("no type.value for input");
+ let value = match type_ {
+ candle_onnx::onnx::type_proto::Value::TensorType(tt) => {
+ let dt = match DataType::try_from(tt.elem_type) {
+ Ok(dt) => match candle_onnx::dtype(dt) {
+ Some(dt) => dt,
+ None => {
+ anyhow::bail!(
+ "unsupported 'value' data-type {dt:?} for {}",
+ input.name
+ )
+ }
+ },
+ type_ => anyhow::bail!("unsupported input type {type_:?}"),
+ };
+ let shape = tt.shape.as_ref().expect("no tensortype.shape for input");
+ let dims = shape
+ .dim
+ .iter()
+ .map(|dim| match dim.value.as_ref().expect("no dim value") {
+ candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize),
+ candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimParam(_) => anyhow::bail!("DimParam is unsupported for input {}", input.name),
+ })
+ .collect::<Result<Vec<usize>>>()?;
+ Tensor::zeros(dims, dt, &Device::Cpu)?
+ }
+ type_ => anyhow::bail!("unsupported input type {type_:?}"),
+ };
+ Ok::<_, anyhow::Error>((input.name.clone(), value))
})
.collect::<Result<_>>()?;
let outputs = candle_onnx::simple_eval(&model, inputs)?;