author    Laurent Mazare <laurent.mazare@gmail.com>  2023-11-04 08:34:24 +0100
committer GitHub <noreply@github.com>                2023-11-04 08:34:24 +0100
commit    f7c957d64f09ca6569ab6db265664fc192113972 (patch)
tree      ad74bd0c1abe7eb9549fdce89610cc8cf6d9a385 /candle-onnx
parent    8cbb9d0e6ce57a8dbfc685f3121ed9d01b02726e (diff)
ONNX casting support. (#1265)
* ONNX casting support.
* Handle tensor constants.
* Bugfix the binary ops.
Diffstat (limited to 'candle-onnx')
-rw-r--r--  candle-onnx/src/eval.rs  104
1 file changed, 94 insertions(+), 10 deletions(-)
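For context, the "Constant" branch added in the diff below decodes a TensorProto's little-endian raw_data bytes into a candle tensor via Tensor::from_raw_buffer. A minimal standalone sketch of that decoding step, assuming a CPU device (the byte values and the main wrapper are illustrative, not part of the commit):

```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    // Two little-endian f32 values, mimicking a TensorProto's raw_data field.
    let mut raw = Vec::new();
    raw.extend_from_slice(&3.0f32.to_le_bytes());
    raw.extend_from_slice(&4.0f32.to_le_bytes());
    // Reinterpret the bytes as an F32 tensor of shape [2] on the CPU,
    // which is what the "Constant" branch does for DataType::Float.
    let t = Tensor::from_raw_buffer(&raw, DType::F32, &[2], &Device::Cpu)?;
    println!("{t}"); // a 1-D F32 tensor holding 3.0 and 4.0
    Ok(())
}
```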
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index fe112fdd..b9a0d9da 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -1,5 +1,5 @@
use crate::onnx;
-use candle::{Result, Tensor};
+use candle::{bail, DType, Device, Result, Tensor};
use std::collections::HashMap;
pub type Value = Tensor;
@@ -13,8 +13,9 @@ pub fn simple_eval(
model: &onnx::ModelProto,
inputs: HashMap<String, Value>,
) -> Result<HashMap<String, Value>> {
+ use crate::onnx::attribute_proto::AttributeType;
let graph = match &model.graph {
- None => candle::bail!("no graph defined in proto"),
+ None => bail!("no graph defined in proto"),
Some(graph) => graph,
};
// TODO: validate the inputs.
@@ -23,37 +24,37 @@ pub fn simple_eval(
for node in graph.node.iter() {
let get = |input_name: &str| match values.get(input_name) {
Some(value) => Ok(value),
- None => candle::bail!("cannot find {input_name} for op {}", node.name),
+ None => bail!("cannot find {input_name} for op {}", node.name),
};
// TODO: Validate node.input for each operator.
match node.op_type.as_str() {
"Add" => {
let input0 = get(&node.input[0])?;
- let input1 = get(&node.input[0])?;
+ let input1 = get(&node.input[1])?;
let output = input0.broadcast_add(input1)?;
values.insert(node.output[0].clone(), output);
}
"Sub" => {
let input0 = get(&node.input[0])?;
- let input1 = get(&node.input[0])?;
+ let input1 = get(&node.input[1])?;
let output = input0.broadcast_sub(input1)?;
values.insert(node.output[0].clone(), output);
}
"Mul" => {
let input0 = get(&node.input[0])?;
- let input1 = get(&node.input[0])?;
+ let input1 = get(&node.input[1])?;
let output = input0.broadcast_mul(input1)?;
values.insert(node.output[0].clone(), output);
}
"Div" => {
let input0 = get(&node.input[0])?;
- let input1 = get(&node.input[0])?;
+ let input1 = get(&node.input[1])?;
let output = input0.broadcast_div(input1)?;
values.insert(node.output[0].clone(), output);
}
"MatMul" => {
let input0 = get(&node.input[0])?;
- let input1 = get(&node.input[0])?;
+ let input1 = get(&node.input[1])?;
let output = input0.broadcast_matmul(input1)?;
values.insert(node.output[0].clone(), output);
}
@@ -67,14 +68,97 @@ pub fn simple_eval(
let output = input.relu()?;
values.insert(node.output[0].clone(), output);
}
- op_type => candle::bail!("unsupported op_type {op_type} for op {}", node.name),
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Constant
+ "Constant" => {
+ let value = match node.attribute.iter().find(|attr| attr.name == "value") {
+ None => {
+ // TODO: support sparse_value etc.
+ bail!("cannot find 'value' attr in 'Constant' for {}", node.name)
+ }
+ Some(value) => value,
+ };
+ let output = match value.r#type() {
+ AttributeType::Tensor => {
+ use crate::onnx::tensor_proto::DataType;
+ let t = value.t.as_ref().unwrap();
+ let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
+ match DataType::try_from(t.data_type) {
+ Ok(DataType::Uint8) => Tensor::from_raw_buffer(
+ t.raw_data.as_slice(),
+ DType::U8,
+ dims.as_slice(),
+ &Device::Cpu,
+ )?,
+ Ok(DataType::Uint32) => Tensor::from_raw_buffer(
+ t.raw_data.as_slice(),
+ DType::U32,
+ dims.as_slice(),
+ &Device::Cpu,
+ )?,
+ Ok(DataType::Int64) => Tensor::from_raw_buffer(
+ t.raw_data.as_slice(),
+ DType::I64,
+ dims.as_slice(),
+ &Device::Cpu,
+ )?,
+ Ok(DataType::Float16) => Tensor::from_raw_buffer(
+ t.raw_data.as_slice(),
+ DType::F16,
+ dims.as_slice(),
+ &Device::Cpu,
+ )?,
+ Ok(DataType::Float) => Tensor::from_raw_buffer(
+ t.raw_data.as_slice(),
+ DType::F32,
+ dims.as_slice(),
+ &Device::Cpu,
+ )?,
+ Ok(DataType::Double) => Tensor::from_raw_buffer(
+ t.raw_data.as_slice(),
+ DType::F64,
+ dims.as_slice(),
+ &Device::Cpu,
+ )?,
+ Ok(dt) => {
+ bail!("unsupported 'value' data-type {dt:?} for {}", node.name)
+ }
+ Err(_) => {
+ bail!(
+ "unsupported 'value' data-type {} for {}",
+ t.data_type,
+ node.name
+ )
+ }
+ }
+ }
+ rtype => bail!("unsupported 'value' type {rtype:?} for {}", node.name),
+ };
+ values.insert(node.output[0].clone(), output);
+ }
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
+ "Cast" => {
+ let input = get(&node.input[0])?;
+ let dtype = match node.attribute.iter().find(|attr| attr.name == "to") {
+ None => {
+ bail!("cannot find the 'to' attribute in 'Cast' for {}", node.name)
+ }
+ Some(dtype) => match dtype.r#type() {
+ AttributeType::Floats => candle::DType::F32,
+ AttributeType::Int => candle::DType::I64,
+ rtype => bail!("unsupported 'to' type {rtype:?} for {}", node.name),
+ },
+ };
+ let output = input.to_dtype(dtype)?;
+ values.insert(node.output[0].clone(), output);
+ }
+ op_type => bail!("unsupported op_type {op_type} for op {}", node.name),
}
}
graph
.output
.iter()
.map(|output| match values.remove(&output.name) {
- None => candle::bail!("cannot find output {}", output.name),
+ None => bail!("cannot find output {}", output.name),
Some(value) => Ok((output.name.clone(), value)),
})
.collect()
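The new "Cast" branch resolves the ONNX 'to' attribute to a candle DType and then delegates to Tensor::to_dtype. A minimal standalone sketch of that underlying call, assuming a CPU tensor (the values and the main wrapper are illustrative, not part of the commit):

```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let input = Tensor::new(&[1.5f32, 2.7, -3.2], &Device::Cpu)?;
    // Once the 'to' attribute has been mapped to a DType, the "Cast"
    // branch is just a dtype conversion on the input tensor.
    let as_i64 = input.to_dtype(DType::I64)?;
    println!("{as_i64}");
    Ok(())
}
```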