diff options
Diffstat (limited to 'candle-onnx/examples/onnx_basics.rs')
-rw-r--r-- | candle-onnx/examples/onnx_basics.rs | 36 |
1 files changed, 33 insertions, 3 deletions
diff --git a/candle-onnx/examples/onnx_basics.rs b/candle-onnx/examples/onnx_basics.rs index b91cbee6..2c52e68e 100644 --- a/candle-onnx/examples/onnx_basics.rs +++ b/candle-onnx/examples/onnx_basics.rs @@ -41,9 +41,39 @@ pub fn main() -> Result<()> { .unwrap() .input .iter() - .map(|name| { - let value = Tensor::new(&[-3.2, 2.7], &Device::Cpu)?; - Ok((name.name.clone(), value)) + .map(|input| { + use candle_onnx::onnx::tensor_proto::DataType; + + let type_ = input.r#type.as_ref().expect("no type for input"); + let type_ = type_.value.as_ref().expect("no type.value for input"); + let value = match type_ { + candle_onnx::onnx::type_proto::Value::TensorType(tt) => { + let dt = match DataType::try_from(tt.elem_type) { + Ok(dt) => match candle_onnx::dtype(dt) { + Some(dt) => dt, + None => { + anyhow::bail!( + "unsupported 'value' data-type {dt:?} for {}", + input.name + ) + } + }, + type_ => anyhow::bail!("unsupported input type {type_:?}"), + }; + let shape = tt.shape.as_ref().expect("no tensortype.shape for input"); + let dims = shape + .dim + .iter() + .map(|dim| match dim.value.as_ref().expect("no dim value") { + candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize), + candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimParam(_) => anyhow::bail!("DimParam is unsupported for input {}", input.name), + }) + .collect::<Result<Vec<usize>>>()?; + Tensor::zeros(dims, dt, &Device::Cpu)? + } + type_ => anyhow::bail!("unsupported input type {type_:?}"), + }; + Ok::<_, anyhow::Error>((input.name.clone(), value)) }) .collect::<Result<_>>()?; let outputs = candle_onnx::simple_eval(&model, inputs)?; |