author     Laurent Mazare <laurent.mazare@gmail.com>  2023-11-04 10:02:47 +0100
committer  GitHub <noreply@github.com>  2023-11-04 10:02:47 +0100
commit     bc9a1bf2399243c659b1e902b14e8572a12ec15b (patch)
tree       3ed7f98114a0bc36d26a7bc99e144406728b1571 /candle-onnx
parent     f7c957d64f09ca6569ab6db265664fc192113972 (diff)
Improve the ONNX basic example + bugfixes (#1266)
* Generate a zeros tensor for each input in the onnx simple-eval example.
* Fix the casting operation.
* Support more ops.
* Handle reshape.
* Concat.
* Softmax.
Diffstat (limited to 'candle-onnx')
-rw-r--r--  candle-onnx/examples/onnx_basics.rs |  36
-rw-r--r--  candle-onnx/src/eval.rs             | 204
-rw-r--r--  candle-onnx/src/lib.rs              |   2
3 files changed, 190 insertions(+), 52 deletions(-)
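
For orientation, a minimal sketch of the API this commit touches: read_file loads the
ModelProto and simple_eval runs it over named input tensors. The model path and the
input name "x" below are placeholders, and the returned map is assumed to be keyed by
output name, as the example's usage suggests.

use std::collections::HashMap;

use candle::{DType, Device, Tensor};

fn main() -> anyhow::Result<()> {
    // "model.onnx" and "x" are placeholders for a real model and input name;
    // shapes and dtypes must match the model's declared inputs.
    let model = candle_onnx::read_file("model.onnx")?;
    let mut inputs = HashMap::new();
    inputs.insert("x".to_string(), Tensor::zeros((1, 4), DType::F32, &Device::Cpu)?);
    let outputs = candle_onnx::simple_eval(&model, inputs)?;
    for (name, value) in outputs.iter() {
        println!("{name}: {value:?}");
    }
    Ok(())
}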
diff --git a/candle-onnx/examples/onnx_basics.rs b/candle-onnx/examples/onnx_basics.rs
index b91cbee6..2c52e68e 100644
--- a/candle-onnx/examples/onnx_basics.rs
+++ b/candle-onnx/examples/onnx_basics.rs
@@ -41,9 +41,39 @@ pub fn main() -> Result<()> {
.unwrap()
.input
.iter()
- .map(|name| {
- let value = Tensor::new(&[-3.2, 2.7], &Device::Cpu)?;
- Ok((name.name.clone(), value))
+ .map(|input| {
+ use candle_onnx::onnx::tensor_proto::DataType;
+
+ let type_ = input.r#type.as_ref().expect("no type for input");
+ let type_ = type_.value.as_ref().expect("no type.value for input");
+ let value = match type_ {
+ candle_onnx::onnx::type_proto::Value::TensorType(tt) => {
+ let dt = match DataType::try_from(tt.elem_type) {
+ Ok(dt) => match candle_onnx::dtype(dt) {
+ Some(dt) => dt,
+ None => {
+ anyhow::bail!(
+                                "unsupported data-type {dt:?} for input {}",
+ input.name
+ )
+ }
+ },
+                        Err(_) => anyhow::bail!("unknown 'elem_type' {} for input {}", tt.elem_type, input.name),
+ };
+ let shape = tt.shape.as_ref().expect("no tensortype.shape for input");
+ let dims = shape
+ .dim
+ .iter()
+ .map(|dim| match dim.value.as_ref().expect("no dim value") {
+ candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize),
+ candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimParam(_) => anyhow::bail!("DimParam is unsupported for input {}", input.name),
+ })
+ .collect::<Result<Vec<usize>>>()?;
+ Tensor::zeros(dims, dt, &Device::Cpu)?
+ }
+ type_ => anyhow::bail!("unsupported input type {type_:?}"),
+ };
+ Ok::<_, anyhow::Error>((input.name.clone(), value))
})
.collect::<Result<_>>()?;
let outputs = candle_onnx::simple_eval(&model, inputs)?;
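
The example above deliberately bails on symbolic dimensions (DimParam), since a zeros
tensor needs concrete sizes. A hypothetical variant could instead substitute a fixed
size for symbolic dims; the helper below is illustrative only and not part of the crate.

use candle_onnx::onnx::tensor_shape_proto::dimension::Value as DimV;

// Hypothetical: resolve a symbolic dim (DimParam) to 1 instead of bailing.
// Whether 1 is the right substitute depends on the model; failing hard, as
// the example does, is the safer default.
fn dim_to_usize(value: &DimV) -> usize {
    match value {
        DimV::DimValue(v) => *v as usize,
        DimV::DimParam(_) => 1,
    }
}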
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index b9a0d9da..2a80f8c1 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -1,9 +1,22 @@
use crate::onnx;
+use crate::onnx::tensor_proto::DataType;
use candle::{bail, DType, Device, Result, Tensor};
use std::collections::HashMap;
pub type Value = Tensor;
+pub fn dtype(dt: DataType) -> Option<DType> {
+ match dt {
+ DataType::Uint8 => Some(DType::U8),
+ DataType::Uint32 => Some(DType::U32),
+ DataType::Int64 => Some(DType::I64),
+ DataType::Float16 => Some(DType::F16),
+ DataType::Float => Some(DType::F32),
+ DataType::Double => Some(DType::F64),
+ _ => None,
+ }
+}
+
// This function provides a direct evaluation of the proto.
// Longer-term, we should first convert the proto to an intermediate representation of the compute
// graph so as to make multiple evaluations more efficient.
@@ -26,6 +39,26 @@ pub fn simple_eval(
Some(value) => Ok(value),
None => bail!("cannot find {input_name} for op {}", node.name),
};
+ let get_attr_i = |name: &str| match node.attribute.iter().find(|attr| attr.name == name) {
+ None => {
+ bail!(
+ "cannot find the '{name}' attribute in '{}' for {}",
+ node.op_type,
+ node.name
+ )
+ }
+            Some(attr) => {
+                match attr.r#type() {
+                    AttributeType::Int => (),
+                    rtype => bail!(
+                        "unsupported type {rtype:?} for '{name}' attribute in '{}' for {}",
+                        node.op_type,
+                        node.name
+                    ),
+                }
+                Ok(attr.i)
+            }
+ };
// TODO: Validate node.input for each operator.
match node.op_type.as_str() {
"Add" => {
@@ -52,12 +85,114 @@ pub fn simple_eval(
let output = input0.broadcast_div(input1)?;
values.insert(node.output[0].clone(), output);
}
+ "Equal" => {
+ let input0 = get(&node.input[0])?;
+ let input1 = get(&node.input[1])?;
+ let output = input0.eq(input1)?;
+ values.insert(node.output[0].clone(), output);
+ }
"MatMul" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.broadcast_matmul(input1)?;
values.insert(node.output[0].clone(), output);
}
+ "Reshape" => {
+ let input0 = get(&node.input[0])?;
+ let input1 = get(&node.input[1])?.to_vec1::<i64>()?;
+            // TODO: Check that there is at most a single -1, handle other neg values.
+            // A -1 entry is inferred as elem_count / product of the other dims.
+            let d_prod: usize = input1
+                .iter()
+                .filter(|&&v| v != -1)
+                .map(|&v| v as usize)
+                .product();
+            let input1 = input1
+                .iter()
+                .map(|&v| if v == -1 { input0.elem_count() / d_prod } else { v as usize })
+                .collect::<Vec<usize>>();
+ let output = input0.reshape(input1)?;
+ values.insert(node.output[0].clone(), output);
+ }
+ "Softmax" => {
+ let input = get(&node.input[0])?;
+ let output = match get_attr_i("axis") {
+ Err(_) => candle_nn::ops::softmax_last_dim(input)?,
+ Ok(axis) => {
+ let num_axis = input.rank() as i64;
+                    let axis = if axis >= 0 {
+                        axis as usize
+                    } else if axis < -num_axis {
+                        bail!("wrong axis in softmax {axis} for shape {:?}", input.shape())
+                    } else {
+                        (num_axis + axis) as usize
+                    };
+ candle_nn::ops::softmax(input, axis)?
+ }
+ };
+ values.insert(node.output[0].clone(), output);
+ }
+ "Concat" => {
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Concat
+ let inputs = node
+ .input
+ .iter()
+ .map(|n| Ok(get(n.as_str())?.clone()))
+ .collect::<Result<Vec<Value>>>()?;
+ let axis = get_attr_i("axis")?;
+ let num_axis = if inputs.is_empty() {
+ bail!("empty concat")
+ } else {
+ inputs[0].rank() as i64
+ };
+ let axis = if axis >= 0 {
+ axis as usize
+ } else if axis < -num_axis {
+ bail!(
+ "wrong axis in concat {axis} for shape {:?}",
+ inputs[0].shape()
+ )
+ } else {
+                    (num_axis + axis) as usize
+ };
+ let output = Tensor::cat(&inputs, axis)?;
+ values.insert(node.output[0].clone(), output);
+ }
+ "Abs" => {
+ let input = get(&node.input[0])?;
+ let output = input.abs()?;
+ values.insert(node.output[0].clone(), output);
+ }
+ "Cos" => {
+ let input = get(&node.input[0])?;
+ let output = input.cos()?;
+ values.insert(node.output[0].clone(), output);
+ }
+ "Sin" => {
+ let input = get(&node.input[0])?;
+ let output = input.sin()?;
+ values.insert(node.output[0].clone(), output);
+ }
+ "Neg" => {
+ let input = get(&node.input[0])?;
+ let output = input.neg()?;
+ values.insert(node.output[0].clone(), output);
+ }
+ "Erf" => {
+ let input = get(&node.input[0])?;
+ let output = input.erf()?;
+ values.insert(node.output[0].clone(), output);
+ }
+ "Tanh" => {
+ let input = get(&node.input[0])?;
+ let output = input.tanh()?;
+ values.insert(node.output[0].clone(), output);
+ }
+ "Sigmoid" => {
+ let input = get(&node.input[0])?;
+ let output = candle_nn::ops::sigmoid(input)?;
+ values.insert(node.output[0].clone(), output);
+ }
"Gelu" => {
let input = get(&node.input[0])?;
let output = input.gelu_erf()?;
@@ -79,49 +214,20 @@ pub fn simple_eval(
};
let output = match value.r#type() {
AttributeType::Tensor => {
- use crate::onnx::tensor_proto::DataType;
let t = value.t.as_ref().unwrap();
let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
match DataType::try_from(t.data_type) {
- Ok(DataType::Uint8) => Tensor::from_raw_buffer(
- t.raw_data.as_slice(),
- DType::U8,
- dims.as_slice(),
- &Device::Cpu,
- )?,
- Ok(DataType::Uint32) => Tensor::from_raw_buffer(
- t.raw_data.as_slice(),
- DType::U32,
- dims.as_slice(),
- &Device::Cpu,
- )?,
- Ok(DataType::Int64) => Tensor::from_raw_buffer(
- t.raw_data.as_slice(),
- DType::I64,
- dims.as_slice(),
- &Device::Cpu,
- )?,
- Ok(DataType::Float16) => Tensor::from_raw_buffer(
- t.raw_data.as_slice(),
- DType::F16,
- dims.as_slice(),
- &Device::Cpu,
- )?,
- Ok(DataType::Float) => Tensor::from_raw_buffer(
- t.raw_data.as_slice(),
- DType::F32,
- dims.as_slice(),
- &Device::Cpu,
- )?,
- Ok(DataType::Double) => Tensor::from_raw_buffer(
- t.raw_data.as_slice(),
- DType::F64,
- dims.as_slice(),
- &Device::Cpu,
- )?,
- Ok(dt) => {
- bail!("unsupported 'value' data-type {dt:?} for {}", node.name)
- }
+ Ok(dt) => match dtype(dt) {
+ Some(dt) => Tensor::from_raw_buffer(
+ t.raw_data.as_slice(),
+ dt,
+ dims.as_slice(),
+ &Device::Cpu,
+ )?,
+ None => {
+ bail!("unsupported 'value' data-type {dt:?} for {}", node.name)
+ }
+ },
Err(_) => {
bail!(
"unsupported 'value' data-type {} for {}",
@@ -138,15 +244,17 @@ pub fn simple_eval(
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
"Cast" => {
let input = get(&node.input[0])?;
- let dtype = match node.attribute.iter().find(|attr| attr.name == "to") {
- None => {
- bail!("cannot find the 'to' attribute in 'Cast' for {}", node.name)
- }
- Some(dtype) => match dtype.r#type() {
- AttributeType::Floats => candle::DType::F32,
- AttributeType::Int => candle::DType::I64,
- rtype => bail!("unsupported 'to' type {rtype:?} for {}", node.name),
+ let dt = get_attr_i("to")?;
+ let dtype = match DataType::try_from(dt as i32) {
+ Ok(dt) => match dtype(dt) {
+ Some(dt) => dt,
+ None => {
+ bail!("unsupported 'to' value {dt:?} for cast {}", node.name)
+ }
},
+ Err(_) => {
+ bail!("unsupported 'to' value {dt:?} for cast {}", node.name)
+ }
};
let output = input.to_dtype(dtype)?;
values.insert(node.output[0].clone(), output);
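
The Softmax and Concat arms above share the same negative-axis convention: an axis of
-1 addresses the last dimension. A standalone sketch of that normalization follows;
normalize_axis is an illustrative helper, not crate API.

use anyhow::bail;

// Map an ONNX axis (possibly negative) to a concrete dimension index,
// e.g. axis -1 with rank 3 maps to 2.
fn normalize_axis(axis: i64, rank: usize) -> anyhow::Result<usize> {
    let rank = rank as i64;
    if 0 <= axis && axis < rank {
        Ok(axis as usize)
    } else if -rank <= axis && axis < 0 {
        Ok((rank + axis) as usize)
    } else {
        bail!("axis {axis} out of range for rank {rank}")
    }
}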
diff --git a/candle-onnx/src/lib.rs b/candle-onnx/src/lib.rs
index 3b36c4cf..1002a2c8 100644
--- a/candle-onnx/src/lib.rs
+++ b/candle-onnx/src/lib.rs
@@ -6,7 +6,7 @@ pub mod onnx {
}
mod eval;
-pub use eval::simple_eval;
+pub use eval::{dtype, simple_eval};
pub fn read_file<P: AsRef<std::path::Path>>(p: P) -> Result<onnx::ModelProto> {
let buf = std::fs::read(p)?;