summaryrefslogtreecommitdiff
path: root/candle-onnx
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-11-05 16:57:26 +0100
committerGitHub <noreply@github.com>2023-11-05 16:57:26 +0100
commitf365a075e551dd50f7def29ecc2d8cba100c4625 (patch)
tree989c49f3317c59941284048c498e7801151b5e7c /candle-onnx
parent60fdab4e17d3e420f20610ec75df3deccd8e1f69 (diff)
downloadcandle-f365a075e551dd50f7def29ecc2d8cba100c4625.tar.gz
candle-f365a075e551dd50f7def29ecc2d8cba100c4625.tar.bz2
candle-f365a075e551dd50f7def29ecc2d8cba100c4625.zip
Add more models to the onnx example. (#1273)
* Add more models to the onnx example.
* Input validation.
* Input validation.
* Bugfix.
* Implement clip.
* BatchNorm support.
* Get the efficientnet onnx to work.
Diffstat (limited to 'candle-onnx')
-rw-r--r--candle-onnx/src/eval.rs167
1 files changed, 152 insertions, 15 deletions
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index c1c98101..54fae6c1 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -30,6 +30,13 @@ impl Attr for i64 {
}
}
+impl Attr for f32 { // Attr implementation for single f32-valued attributes.
+ const TYPE: AttributeType = AttributeType::Float; // Type tag paired with f32; presumably compared against the attribute's declared type by the caller (not visible here; confirm).
+ fn get(attr: &onnx::AttributeProto) -> Result<&Self> { // Borrows the value out of the given attribute proto.
+ Ok(&attr.f) // Never fails here: returns a reference to the proto's `f` field as-is.
+ }
+}
+
impl Attr for [i64] {
const TYPE: AttributeType = AttributeType::Ints;
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
@@ -134,12 +141,66 @@ pub fn simple_eval(
None => bail!("no graph defined in proto"),
Some(graph) => graph,
};
- // TODO: validate the inputs.
let mut values = inputs;
for t in graph.initializer.iter() {
let tensor = get_tensor(t, t.name.as_str())?;
values.insert(t.name.to_string(), tensor);
}
+ for input in graph.input.iter() {
+ let input_type = match &input.r#type {
+ Some(input_type) => input_type,
+ None => continue,
+ };
+ let input_type = match &input_type.value {
+ Some(input_type) => input_type,
+ None => continue,
+ };
+ let tensor_type = match input_type {
+ onnx::type_proto::Value::TensorType(tt) => tt,
+ _ => continue,
+ };
+
+ let tensor = match values.get(&input.name) {
+ None => bail!("missing input {}", input.name),
+ Some(tensor) => tensor,
+ };
+ let dt = match DataType::try_from(tensor_type.elem_type) {
+ Ok(dt) => match dtype(dt) {
+ Some(dt) => dt,
+ None => {
+ bail!("unsupported 'value' data-type {dt:?} for {}", input.name)
+ }
+ },
+ type_ => bail!("unsupported input type {type_:?}"),
+ };
+ let shape = match &tensor_type.shape {
+ None => continue,
+ Some(shape) => shape
+ .dim
+ .iter()
+ .map(|dim| match dim.value.as_ref().expect("no dim value") {
+ onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize),
+ onnx::tensor_shape_proto::dimension::Value::DimParam(_) => {
+ bail!("DimParam is unsupported for input {}", input.name)
+ }
+ })
+ .collect::<Result<Vec<usize>>>()?,
+ };
+ if dt != tensor.dtype() {
+ bail!(
+ "unexpected dtype for {}, got {:?}, expected {dt:?}",
+ input.name,
+ tensor.dtype()
+ )
+ }
+ if shape.as_slice() != tensor.dims() {
+ bail!(
+ "unexpected shape for {}, got {:?}, expected {shape:?}",
+ input.name,
+ tensor.dims()
+ )
+ }
+ }
// The nodes are topologically sorted so we can just process them in order.
for node in graph.node.iter() {
let get = |input_name: &str| match values.get(input_name) {
@@ -328,6 +389,79 @@ pub fn simple_eval(
};
values.insert(node.output[0].clone(), ys);
}
+ "BatchNormalization" => {
+ let training_mode = get_attr_opt::<i64>(node, "training_mode")?;
+ if training_mode.copied().unwrap_or(0) != 0 {
+ bail!("training mode is not supported for BatchNorm")
+ }
+ let eps = get_attr_opt::<f32>(node, "epsilon")?
+ .copied()
+ .unwrap_or(1e-5);
+ let xs = get(&node.input[0])?;
+ let weight = get(&node.input[1])?;
+ let bias = get(&node.input[2])?;
+ let running_mean = get(&node.input[3])?;
+ let running_var = get(&node.input[4])?;
+ let target_shape: Vec<usize> = xs
+ .dims()
+ .iter()
+ .enumerate()
+ .map(|(idx, v)| if idx == 1 { *v } else { 1 })
+ .collect();
+ let target_shape = target_shape.as_slice();
+ let xs = xs
+ .broadcast_sub(&running_mean.reshape(target_shape)?)?
+ .broadcast_div(&(running_var.reshape(target_shape)? + eps as f64)?.sqrt()?)?;
+ let weight = weight.reshape(target_shape)?;
+ let bias = bias.reshape(target_shape)?;
+ let xs = xs.broadcast_mul(&weight)?.broadcast_add(&bias)?;
+ values.insert(node.output[0].clone(), xs);
+ }
+ "Squeeze" => {
+ let xs = get(&node.input[0])?;
+ let mut axes = if node.input.len() <= 1 {
+ // contract all the dimensions with size 1 except the batch dim.
+ xs.dims()
+ .iter()
+ .enumerate()
+ .flat_map(|(idx, &s)| if s == 1 && idx > 0 { Some(idx) } else { None })
+ .collect()
+ } else {
+ get(&node.input[1])?
+ .to_vec1::<i64>()?
+ .iter()
+ .map(|&i| {
+ if i < 0 {
+ (xs.rank() as i64 + i) as usize
+ } else {
+ i as usize
+ }
+ })
+ .collect::<Vec<_>>()
+ };
+ axes.sort();
+ let mut xs = xs.clone();
+ for &axis in axes.iter().rev() {
+ xs = xs.squeeze(axis)?
+ }
+ values.insert(node.output[0].clone(), xs);
+ }
+ "Clip" => {
+ let xs = get(&node.input[0])?;
+ let xs = if node.input.len() >= 2 {
+ let mins = get(&node.input[1])?;
+ xs.broadcast_maximum(mins)?
+ } else {
+ xs.clone()
+ };
+ let xs = if node.input.len() >= 3 {
+ let maxs = get(&node.input[2])?;
+ xs.broadcast_minimum(maxs)?
+ } else {
+ xs.clone()
+ };
+ values.insert(node.output[0].clone(), xs);
+ }
"Conv" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
@@ -344,17 +478,15 @@ pub fn simple_eval(
let ws = get(&node.input[1])?;
let ys = match ws.rank() {
3 => {
- let pads = match pads {
- None => 0,
- Some([p]) => *p as usize,
+ let (pads, xs) = match pads {
+ None => (0, xs.clone()),
+ Some([p]) => (*p as usize, xs.clone()),
Some([p1, p2]) => {
if p1 != p2 {
- bail!(
- "left and right pad ({p1} <> {p2}) have to be the same {}",
- node.name
- )
+ (0usize, xs.pad_with_zeros(2, *p1 as usize, *p2 as usize)?)
+ } else {
+ (*p1 as usize, xs.clone())
}
- *p1 as usize
}
Some(pads) => {
bail!("more pads than expected in conv1d {pads:?} {}", node.name)
@@ -377,14 +509,19 @@ pub fn simple_eval(
xs.conv1d(ws, pads, strides, dilations, groups as usize)?
}
4 => {
- let pads = match pads {
- None => 0,
- Some([p]) => *p as usize,
- Some([p1, p2, p3, p4]) => {
+ let (pads, xs) = match pads {
+ None => (0, xs.clone()),
+ Some([p]) => (*p as usize, xs.clone()),
+ Some(&[p1, p2, p3, p4]) => {
+ let p1 = p1 as usize;
+ let p2 = p2 as usize;
+ let p3 = p3 as usize;
+ let p4 = p4 as usize;
if p1 != p2 || p1 != p3 || p1 != p4 {
- bail!("pads have to be the same {pads:?} {}", node.name)
+ (0, xs.pad_with_zeros(2, p1, p3)?.pad_with_zeros(3, p2, p4)?)
+ } else {
+ (p1, xs.clone())
}
- *p1 as usize
}
Some(pads) => {
bail!("more pads than expected in conv2d {pads:?} {}", node.name)