diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-11-04 10:02:47 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-04 10:02:47 +0100 |
commit | bc9a1bf2399243c659b1e902b14e8572a12ec15b (patch) | |
tree | 3ed7f98114a0bc36d26a7bc99e144406728b1571 /candle-onnx/examples/onnx_basics.rs | |
parent | f7c957d64f09ca6569ab6db265664fc192113972 (diff) | |
download | candle-bc9a1bf2399243c659b1e902b14e8572a12ec15b.tar.gz candle-bc9a1bf2399243c659b1e902b14e8572a12ec15b.tar.bz2 candle-bc9a1bf2399243c659b1e902b14e8572a12ec15b.zip |
Improve the ONNX basic example + bugfixes (#1266)
* Generate some zeros tensor in the onnx simple-eval example.
* Fix the casting operation.
* Support more ops.
* Handle reshape.
* Concat.
* Softmax.
Diffstat (limited to 'candle-onnx/examples/onnx_basics.rs')
-rw-r--r-- | candle-onnx/examples/onnx_basics.rs | 36 |
1 files changed, 33 insertions, 3 deletions
diff --git a/candle-onnx/examples/onnx_basics.rs b/candle-onnx/examples/onnx_basics.rs index b91cbee6..2c52e68e 100644 --- a/candle-onnx/examples/onnx_basics.rs +++ b/candle-onnx/examples/onnx_basics.rs @@ -41,9 +41,39 @@ pub fn main() -> Result<()> { .unwrap() .input .iter() - .map(|name| { - let value = Tensor::new(&[-3.2, 2.7], &Device::Cpu)?; - Ok((name.name.clone(), value)) + .map(|input| { + use candle_onnx::onnx::tensor_proto::DataType; + + let type_ = input.r#type.as_ref().expect("no type for input"); + let type_ = type_.value.as_ref().expect("no type.value for input"); + let value = match type_ { + candle_onnx::onnx::type_proto::Value::TensorType(tt) => { + let dt = match DataType::try_from(tt.elem_type) { + Ok(dt) => match candle_onnx::dtype(dt) { + Some(dt) => dt, + None => { + anyhow::bail!( + "unsupported 'value' data-type {dt:?} for {}", + input.name + ) + } + }, + type_ => anyhow::bail!("unsupported input type {type_:?}"), + }; + let shape = tt.shape.as_ref().expect("no tensortype.shape for input"); + let dims = shape + .dim + .iter() + .map(|dim| match dim.value.as_ref().expect("no dim value") { + candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize), + candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimParam(_) => anyhow::bail!("DimParam is unsupported for input {}", input.name), + }) + .collect::<Result<Vec<usize>>>()?; + Tensor::zeros(dims, dt, &Device::Cpu)? + } + type_ => anyhow::bail!("unsupported input type {type_:?}"), + }; + Ok::<_, anyhow::Error>((input.name.clone(), value)) }) .collect::<Result<_>>()?; let outputs = candle_onnx::simple_eval(&model, inputs)?; |