diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-11-04 08:34:24 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-04 08:34:24 +0100 |
commit | f7c957d64f09ca6569ab6db265664fc192113972 (patch) | |
tree | ad74bd0c1abe7eb9549fdce89610cc8cf6d9a385 /candle-onnx | |
parent | 8cbb9d0e6ce57a8dbfc685f3121ed9d01b02726e (diff) | |
download | candle-f7c957d64f09ca6569ab6db265664fc192113972.tar.gz candle-f7c957d64f09ca6569ab6db265664fc192113972.tar.bz2 candle-f7c957d64f09ca6569ab6db265664fc192113972.zip |
ONNX casting support. (#1265)
* ONNX casting support.
* Handle tensor constants.
* Bugfix the binary ops.
Diffstat (limited to 'candle-onnx')
-rw-r--r-- | candle-onnx/src/eval.rs | 104 |
1 files changed, 94 insertions, 10 deletions
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index fe112fdd..b9a0d9da 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -1,5 +1,5 @@ use crate::onnx; -use candle::{Result, Tensor}; +use candle::{bail, DType, Device, Result, Tensor}; use std::collections::HashMap; pub type Value = Tensor; @@ -13,8 +13,9 @@ pub fn simple_eval( model: &onnx::ModelProto, inputs: HashMap<String, Value>, ) -> Result<HashMap<String, Value>> { + use crate::onnx::attribute_proto::AttributeType; let graph = match &model.graph { - None => candle::bail!("no graph defined in proto"), + None => bail!("no graph defined in proto"), Some(graph) => graph, }; // TODO: validate the inputs. @@ -23,37 +24,37 @@ pub fn simple_eval( for node in graph.node.iter() { let get = |input_name: &str| match values.get(input_name) { Some(value) => Ok(value), - None => candle::bail!("cannot find {input_name} for op {}", node.name), + None => bail!("cannot find {input_name} for op {}", node.name), }; // TODO: Validate node.input for each operator. match node.op_type.as_str() { "Add" => { let input0 = get(&node.input[0])?; - let input1 = get(&node.input[0])?; + let input1 = get(&node.input[1])?; let output = input0.broadcast_add(input1)?; values.insert(node.output[0].clone(), output); } "Sub" => { let input0 = get(&node.input[0])?; - let input1 = get(&node.input[0])?; + let input1 = get(&node.input[1])?; let output = input0.broadcast_sub(input1)?; values.insert(node.output[0].clone(), output); } "Mul" => { let input0 = get(&node.input[0])?; - let input1 = get(&node.input[0])?; + let input1 = get(&node.input[1])?; let output = input0.broadcast_mul(input1)?; values.insert(node.output[0].clone(), output); } "Div" => { let input0 = get(&node.input[0])?; - let input1 = get(&node.input[0])?; + let input1 = get(&node.input[1])?; let output = input0.broadcast_div(input1)?; values.insert(node.output[0].clone(), output); } "MatMul" => { let input0 = get(&node.input[0])?; - let input1 = get(&node.input[0])?; + let input1 = get(&node.input[1])?; let output = input0.broadcast_matmul(input1)?; values.insert(node.output[0].clone(), output); } @@ -67,14 +68,97 @@ pub fn simple_eval( let output = input.relu()?; values.insert(node.output[0].clone(), output); } - op_type => candle::bail!("unsupported op_type {op_type} for op {}", node.name), + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Constant + "Constant" => { + let value = match node.attribute.iter().find(|attr| attr.name == "value") { + None => { + // TODO: support sparse_value etc. + bail!("cannot find 'value' attr in 'Constant' for {}", node.name) + } + Some(value) => value, + }; + let output = match value.r#type() { + AttributeType::Tensor => { + use crate::onnx::tensor_proto::DataType; + let t = value.t.as_ref().unwrap(); + let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect(); + match DataType::try_from(t.data_type) { + Ok(DataType::Uint8) => Tensor::from_raw_buffer( + t.raw_data.as_slice(), + DType::U8, + dims.as_slice(), + &Device::Cpu, + )?, + Ok(DataType::Uint32) => Tensor::from_raw_buffer( + t.raw_data.as_slice(), + DType::U32, + dims.as_slice(), + &Device::Cpu, + )?, + Ok(DataType::Int64) => Tensor::from_raw_buffer( + t.raw_data.as_slice(), + DType::I64, + dims.as_slice(), + &Device::Cpu, + )?, + Ok(DataType::Float16) => Tensor::from_raw_buffer( + t.raw_data.as_slice(), + DType::F16, + dims.as_slice(), + &Device::Cpu, + )?, + Ok(DataType::Float) => Tensor::from_raw_buffer( + t.raw_data.as_slice(), + DType::F32, + dims.as_slice(), + &Device::Cpu, + )?, + Ok(DataType::Double) => Tensor::from_raw_buffer( + t.raw_data.as_slice(), + DType::F64, + dims.as_slice(), + &Device::Cpu, + )?, + Ok(dt) => { + bail!("unsupported 'value' data-type {dt:?} for {}", node.name) + } + Err(_) => { + bail!( + "unsupported 'value' data-type {} for {}", + t.data_type, + node.name + ) + } + } + } + rtype => bail!("unsupported 'value' type {rtype:?} for {}", node.name), + }; + values.insert(node.output[0].clone(), output); + } + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast + "Cast" => { + let input = get(&node.input[0])?; + let dtype = match node.attribute.iter().find(|attr| attr.name == "to") { + None => { + bail!("cannot find the 'to' attribute in 'Cast' for {}", node.name) + } + Some(dtype) => match dtype.r#type() { + AttributeType::Floats => candle::DType::F32, + AttributeType::Int => candle::DType::I64, + rtype => bail!("unsupported 'to' type {rtype:?} for {}", node.name), + }, + }; + let output = input.to_dtype(dtype)?; + values.insert(node.output[0].clone(), output); + } + op_type => bail!("unsupported op_type {op_type} for op {}", node.name), } } graph .output .iter() .map(|output| match values.remove(&output.name) { - None => candle::bail!("cannot find output {}", output.name), + None => bail!("cannot find output {}", output.name), Some(value) => Ok((output.name.clone(), value)), }) .collect() |