author     Laurent Mazare <laurent.mazare@gmail.com>  2023-11-04 10:02:47 +0100
committer  GitHub <noreply@github.com>                2023-11-04 10:02:47 +0100
commit     bc9a1bf2399243c659b1e902b14e8572a12ec15b
tree       3ed7f98114a0bc36d26a7bc99e144406728b1571 /candle-onnx
parent     f7c957d64f09ca6569ab6db265664fc192113972
Improve the ONNX basic example + bugfixes (#1266)
* Generate zero tensors for the inputs in the onnx simple-eval example.
* Fix the casting operation.
* Support more ops.
* Handle reshape.
* Concat.
* Softmax.
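The `candle_onnx::dtype` helper added and re-exported by this commit backs both the example and the `Constant`/`Cast` fixes: it maps an ONNX element type to a candle `DType`, returning `None` for anything unsupported. A minimal usage sketch; the hardcoded `DataType::Float` stands in for a value that would normally be read out of a model proto:

```rust
use candle_onnx::onnx::tensor_proto::DataType;

fn main() -> anyhow::Result<()> {
    // Illustrative only: in real use this comes from TensorProto::data_type
    // or from a graph input's type annotation.
    let elem_type = DataType::Float;
    match candle_onnx::dtype(elem_type) {
        Some(dt) => println!("ONNX {elem_type:?} maps to candle {dt:?}"),
        None => anyhow::bail!("unsupported element type {elem_type:?}"),
    }
    Ok(())
}
```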
Diffstat (limited to 'candle-onnx')
-rw-r--r--  candle-onnx/examples/onnx_basics.rs |  36
-rw-r--r--  candle-onnx/src/eval.rs             | 204
-rw-r--r--  candle-onnx/src/lib.rs              |   2
3 files changed, 190 insertions(+), 52 deletions(-)
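Both the new `Softmax` and `Concat` arms in the diff below normalize a possibly negative ONNX axis (ONNX allows axes in `[-rank, rank-1]`, with negative values counting from the last dimension). A standalone sketch of that normalization; the helper name and structure are ours, not part of the patch:

```rust
/// Normalize an ONNX-style axis in [-rank, rank-1] to a usize index.
fn normalize_axis(axis: i64, rank: usize) -> anyhow::Result<usize> {
    let rank = rank as i64;
    if (-rank..rank).contains(&axis) {
        // Negative axes count from the end: -1 is the last dimension.
        Ok(if axis < 0 { (rank + axis) as usize } else { axis as usize })
    } else {
        anyhow::bail!("axis {axis} out of range for rank {rank}")
    }
}

fn main() -> anyhow::Result<()> {
    assert_eq!(normalize_axis(-1, 3)?, 2);
    assert_eq!(normalize_axis(1, 3)?, 1);
    assert!(normalize_axis(3, 3).is_err());
    Ok(())
}
```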
diff --git a/candle-onnx/examples/onnx_basics.rs b/candle-onnx/examples/onnx_basics.rs
index b91cbee6..2c52e68e 100644
--- a/candle-onnx/examples/onnx_basics.rs
+++ b/candle-onnx/examples/onnx_basics.rs
@@ -41,9 +41,39 @@ pub fn main() -> Result<()> {
         .unwrap()
         .input
         .iter()
-        .map(|name| {
-            let value = Tensor::new(&[-3.2, 2.7], &Device::Cpu)?;
-            Ok((name.name.clone(), value))
+        .map(|input| {
+            use candle_onnx::onnx::tensor_proto::DataType;
+
+            let type_ = input.r#type.as_ref().expect("no type for input");
+            let type_ = type_.value.as_ref().expect("no type.value for input");
+            let value = match type_ {
+                candle_onnx::onnx::type_proto::Value::TensorType(tt) => {
+                    let dt = match DataType::try_from(tt.elem_type) {
+                        Ok(dt) => match candle_onnx::dtype(dt) {
+                            Some(dt) => dt,
+                            None => {
+                                anyhow::bail!(
+                                    "unsupported 'value' data-type {dt:?} for {}",
+                                    input.name
+                                )
+                            }
+                        },
+                        type_ => anyhow::bail!("unsupported input type {type_:?}"),
+                    };
+                    let shape = tt.shape.as_ref().expect("no tensortype.shape for input");
+                    let dims = shape
+                        .dim
+                        .iter()
+                        .map(|dim| match dim.value.as_ref().expect("no dim value") {
+                            candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize),
+                            candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimParam(_) => anyhow::bail!("DimParam is unsupported for input {}", input.name),
+                        })
+                        .collect::<Result<Vec<usize>>>()?;
+                    Tensor::zeros(dims, dt, &Device::Cpu)?
+                }
+                type_ => anyhow::bail!("unsupported input type {type_:?}"),
+            };
+            Ok::<_, anyhow::Error>((input.name.clone(), value))
         })
         .collect::<Result<_>>()?;
     let outputs = candle_onnx::simple_eval(&model, inputs)?;
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index b9a0d9da..2a80f8c1 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -1,9 +1,22 @@
 use crate::onnx;
+use crate::onnx::tensor_proto::DataType;
 use candle::{bail, DType, Device, Result, Tensor};
 use std::collections::HashMap;
 
 pub type Value = Tensor;
 
+pub fn dtype(dt: DataType) -> Option<DType> {
+    match dt {
+        DataType::Uint8 => Some(DType::U8),
+        DataType::Uint32 => Some(DType::U32),
+        DataType::Int64 => Some(DType::I64),
+        DataType::Float16 => Some(DType::F16),
+        DataType::Float => Some(DType::F32),
+        DataType::Double => Some(DType::F64),
+        _ => None,
+    }
+}
+
 // This function provides a direct evaluation of the proto.
 // Longer-term, we should first convert the proto to an intermediate representation of the compute
 // graph so as to make multiple evaluations more efficient.
@@ -26,6 +39,26 @@ pub fn simple_eval(
            Some(value) => Ok(value),
            None => bail!("cannot find {input_name} for op {}", node.name),
        };
+        let get_attr_i = |name: &str| match node.attribute.iter().find(|attr| attr.name == name) {
+            None => {
+                bail!(
+                    "cannot find the '{name}' attribute in '{}' for {}",
+                    node.op_type,
+                    node.name
+                )
+            }
+            Some(dt) => {
+                match dt.r#type() {
+                    AttributeType::Int => (),
+                    rtype => bail!(
+                        "unsupported type {rtype:?} for '{name}' attribute in '{}' for {}",
+                        node.op_type,
+                        node.name
+                    ),
+                }
+                Ok(dt.i)
+            }
+        };
         // TODO: Validate node.input for each operator.
         match node.op_type.as_str() {
             "Add" => {
@@ -52,12 +85,114 @@ pub fn simple_eval(
                 let output = input0.broadcast_div(input1)?;
                 values.insert(node.output[0].clone(), output);
             }
+            "Equal" => {
+                let input0 = get(&node.input[0])?;
+                let input1 = get(&node.input[1])?;
+                let output = input0.eq(input1)?;
+                values.insert(node.output[0].clone(), output);
+            }
             "MatMul" => {
                 let input0 = get(&node.input[0])?;
                 let input1 = get(&node.input[1])?;
                 let output = input0.broadcast_matmul(input1)?;
                 values.insert(node.output[0].clone(), output);
             }
+            "Reshape" => {
+                let input0 = get(&node.input[0])?;
+                let input1 = get(&node.input[1])?.to_vec1::<i64>()?;
+                // TODO: Check that there is at most a single -1, handle other neg values.
+                let input1 = input1
+                    .iter()
+                    .map(|&v| {
+                        if v == -1 {
+                            input0.elem_count()
+                        } else {
+                            v as usize
+                        }
+                    })
+                    .collect::<Vec<usize>>();
+                let output = input0.reshape(input1)?;
+                values.insert(node.output[0].clone(), output);
+            }
+            "Softmax" => {
+                let input = get(&node.input[0])?;
+                let output = match get_attr_i("axis") {
+                    Err(_) => candle_nn::ops::softmax_last_dim(input)?,
+                    Ok(axis) => {
+                        let num_axis = input.rank() as i64;
+                        let axis = if axis >= 0 {
+                            axis as usize
+                        } else if axis < -num_axis {
+                            bail!("wrong axis in softmax {axis} for shape {:?}", input.shape())
+                        } else {
+                            (num_axis + axis) as usize
+                        };
+                        candle_nn::ops::softmax(input, axis)?
+                    }
+                };
+                values.insert(node.output[0].clone(), output);
+            }
+            "Concat" => {
+                // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Concat
+                let inputs = node
+                    .input
+                    .iter()
+                    .map(|n| Ok(get(n.as_str())?.clone()))
+                    .collect::<Result<Vec<Value>>>()?;
+                let axis = get_attr_i("axis")?;
+                let num_axis = if inputs.is_empty() {
+                    bail!("empty concat")
+                } else {
+                    inputs[0].rank() as i64
+                };
+                let axis = if axis >= 0 {
+                    axis as usize
+                } else if axis < -num_axis {
+                    bail!(
+                        "wrong axis in concat {axis} for shape {:?}",
+                        inputs[0].shape()
+                    )
+                } else {
+                    (num_axis + axis) as usize
+                };
+                let output = Tensor::cat(&inputs, axis)?;
+                values.insert(node.output[0].clone(), output);
+            }
+            "Abs" => {
+                let input = get(&node.input[0])?;
+                let output = input.abs()?;
+                values.insert(node.output[0].clone(), output);
+            }
+            "Cos" => {
+                let input = get(&node.input[0])?;
+                let output = input.cos()?;
+                values.insert(node.output[0].clone(), output);
+            }
+            "Sin" => {
+                let input = get(&node.input[0])?;
+                let output = input.sin()?;
+                values.insert(node.output[0].clone(), output);
+            }
+            "Neg" => {
+                let input = get(&node.input[0])?;
+                let output = input.neg()?;
+                values.insert(node.output[0].clone(), output);
+            }
+            "Erf" => {
+                let input = get(&node.input[0])?;
+                let output = input.erf()?;
+                values.insert(node.output[0].clone(), output);
+            }
+            "Tanh" => {
+                let input = get(&node.input[0])?;
+                let output = input.tanh()?;
+                values.insert(node.output[0].clone(), output);
+            }
+            "Sigmoid" => {
+                let input = get(&node.input[0])?;
+                let output = candle_nn::ops::sigmoid(input)?;
+                values.insert(node.output[0].clone(), output);
+            }
             "Gelu" => {
                 let input = get(&node.input[0])?;
                 let output = input.gelu_erf()?;
@@ -79,49 +214,20 @@
                };
                let output = match value.r#type() {
                    AttributeType::Tensor => {
-                        use crate::onnx::tensor_proto::DataType;
                        let t = value.t.as_ref().unwrap();
                        let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
                        match DataType::try_from(t.data_type) {
-                            Ok(DataType::Uint8) => Tensor::from_raw_buffer(
-                                t.raw_data.as_slice(),
-                                DType::U8,
-                                dims.as_slice(),
-                                &Device::Cpu,
-                            )?,
-                            Ok(DataType::Uint32) => Tensor::from_raw_buffer(
-                                t.raw_data.as_slice(),
-                                DType::U32,
-                                dims.as_slice(),
-                                &Device::Cpu,
-                            )?,
-                            Ok(DataType::Int64) => Tensor::from_raw_buffer(
-                                t.raw_data.as_slice(),
-                                DType::I64,
-                                dims.as_slice(),
-                                &Device::Cpu,
-                            )?,
-                            Ok(DataType::Float16) => Tensor::from_raw_buffer(
-                                t.raw_data.as_slice(),
-                                DType::F16,
-                                dims.as_slice(),
-                                &Device::Cpu,
-                            )?,
-                            Ok(DataType::Float) => Tensor::from_raw_buffer(
-                                t.raw_data.as_slice(),
-                                DType::F32,
-                                dims.as_slice(),
-                                &Device::Cpu,
-                            )?,
-                            Ok(DataType::Double) => Tensor::from_raw_buffer(
-                                t.raw_data.as_slice(),
-                                DType::F64,
-                                dims.as_slice(),
-                                &Device::Cpu,
-                            )?,
-                            Ok(dt) => {
-                                bail!("unsupported 'value' data-type {dt:?} for {}", node.name)
-                            }
+                            Ok(dt) => match dtype(dt) {
+                                Some(dt) => Tensor::from_raw_buffer(
+                                    t.raw_data.as_slice(),
+                                    dt,
+                                    dims.as_slice(),
+                                    &Device::Cpu,
+                                )?,
+                                None => {
+                                    bail!("unsupported 'value' data-type {dt:?} for {}", node.name)
+                                }
+                            },
                            Err(_) => {
                                bail!(
                                    "unsupported 'value' data-type {} for {}",
@@ -138,15 +244,17 @@
            // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
            "Cast" => {
                let input = get(&node.input[0])?;
-                let dtype = match node.attribute.iter().find(|attr| attr.name == "to") {
-                    None => {
-                        bail!("cannot find the 'to' attribute in 'Cast' for {}", node.name)
-                    }
-                    Some(dtype) => match dtype.r#type() {
-                        AttributeType::Floats => candle::DType::F32,
-                        AttributeType::Int => candle::DType::I64,
-                        rtype => bail!("unsupported 'to' type {rtype:?} for {}", node.name),
+                let dt = get_attr_i("to")?;
+                let dtype = match DataType::try_from(dt as i32) {
+                    Ok(dt) => match dtype(dt) {
+                        Some(dt) => dt,
+                        None => {
+                            bail!("unsupported 'to' value {dt:?} for cast {}", node.name)
+                        }
                    },
+                    Err(_) => {
+                        bail!("unsupported 'to' value {dt:?} for cast {}", node.name)
+                    }
                };
                let output = input.to_dtype(dtype)?;
                values.insert(node.output[0].clone(), output);
diff --git a/candle-onnx/src/lib.rs b/candle-onnx/src/lib.rs
index 3b36c4cf..1002a2c8 100644
--- a/candle-onnx/src/lib.rs
+++ b/candle-onnx/src/lib.rs
@@ -6,7 +6,7 @@ pub mod onnx {
 }
 
 mod eval;
-pub use eval::simple_eval;
+pub use eval::{dtype, simple_eval};
 
 pub fn read_file<P: AsRef<std::path::Path>>(p: P) -> Result<onnx::ModelProto> {
     let buf = std::fs::read(p)?;
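Taken together, the example change means a model can now be exercised end to end without real data. A condensed sketch of the pattern, assuming a hypothetical `model.onnx` and sidestepping the per-input dtype/shape discovery shown above with a fixed f32 placeholder shape:

```rust
use candle::{DType, Device, Tensor};
use std::collections::HashMap;

fn main() -> anyhow::Result<()> {
    let model = candle_onnx::read_file("model.onnx")?; // hypothetical path
    let graph = model.graph.as_ref().expect("no graph in model");

    // Zero-filled placeholder for every declared graph input; the real example
    // derives the dtype and dims from each input's type annotation instead.
    let mut inputs = HashMap::new();
    for input in graph.input.iter() {
        let value = Tensor::zeros((1,), DType::F32, &Device::Cpu)?;
        inputs.insert(input.name.clone(), value);
    }

    // Evaluate the graph once and print whatever the model declares as outputs.
    let outputs = candle_onnx::simple_eval(&model, inputs)?;
    for (name, value) in outputs.iter() {
        println!("output {name}: {value:?}");
    }
    Ok(())
}
```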