use crate::onnx;
use crate::onnx::attribute_proto::AttributeType;
use crate::onnx::tensor_proto::DataType;
use candle::{bail, DType, Device, Result, Tensor};
use std::collections::HashMap;

pub type Value = Tensor;

pub fn dtype(dt: DataType) -> Option<DType> {
    match dt {
        DataType::Uint8 => Some(DType::U8),
        DataType::Uint32 => Some(DType::U32),
        DataType::Int64 => Some(DType::I64),
        DataType::Float16 => Some(DType::F16),
        DataType::Float => Some(DType::F32),
        DataType::Double => Some(DType::F64),
        _ => None,
    }
}

trait Attr {
    const TYPE: AttributeType;
    fn get(attr: &onnx::AttributeProto) -> Result<&Self>;
}

impl Attr for i64 {
    const TYPE: AttributeType = AttributeType::Int;
    fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
        Ok(&attr.i)
    }
}

impl Attr for f32 {
    const TYPE: AttributeType = AttributeType::Float;
    fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
        Ok(&attr.f)
    }
}

impl Attr for [i64] {
    const TYPE: AttributeType = AttributeType::Ints;
    fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
        Ok(attr.ints.as_slice())
    }
}

impl Attr for str {
    const TYPE: AttributeType = AttributeType::String;
    fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
        std::str::from_utf8(&attr.s).map_err(candle::Error::wrap)
    }
}

fn get_attr_<'a>(node: &'a onnx::NodeProto, name: &str) -> Result<&'a onnx::AttributeProto> {
    match node.attribute.iter().find(|attr| attr.name == name) {
        None => {
            bail!(
                "cannot find the '{name}' attribute in '{}' for {}",
                node.op_type,
                node.name
            )
        }
        Some(dt) => Ok(dt),
    }
}

fn get_attr<'a, T: Attr + ?Sized>(node: &'a onnx::NodeProto, name: &str) -> Result<&'a T> {
    let attr = get_attr_(node, name)?;
    if attr.r#type() != T::TYPE {
        bail!(
            "unsupported type {:?} for '{name}' attribute in '{}' for {}",
            attr.r#type,
            node.op_type,
            node.name
        )
    }
    T::get(attr)
}

fn get_attr_opt<'a, T: Attr + ?Sized>(
    node: &'a onnx::NodeProto,
    name: &str,
) -> Result<Option<&'a T>> {
    match node.attribute.iter().find(|attr| attr.name == name) {
        None => Ok(None),
        Some(attr) => {
            if attr.r#type() != T::TYPE {
                bail!(
                    "unsupported type {:?} for '{name}' attribute in '{}' for {}",
                    attr.r#type,
                    node.op_type,
                    node.name
                )
            }
            let val = T::get(attr)?;
            Ok(Some(val))
        }
    }
}

pub fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
    let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
    match DataType::try_from(t.data_type) {
        Ok(DataType::Int32) => {
            if t.int32_data.is_empty() {
                // Candle has no i32 dtype so i32 tensors are widened to i64. The raw
                // buffer is decoded via from_le_bytes rather than a pointer cast to
                // avoid unaligned reads (ONNX raw_data is little-endian).
                let data = t
                    .raw_data
                    .chunks_exact(4)
                    .map(|c| i32::from_le_bytes(c.try_into().unwrap()) as i64)
                    .collect::<Vec<i64>>();
                let len = data.len();
                Tensor::from_vec(data, len, &Device::Cpu)
            } else {
                let data = t.int32_data.iter().map(|v| *v as i64).collect::<Vec<i64>>();
                Tensor::from_vec(data, t.int32_data.len(), &Device::Cpu)
            }
        }
        Ok(dt) => match dtype(dt) {
            Some(dt) => {
                if dt == DType::F32 && !t.float_data.is_empty() {
                    Tensor::from_slice(&t.float_data, dims.as_slice(), &Device::Cpu)
                } else if dt == DType::F64 && !t.double_data.is_empty() {
                    Tensor::from_slice(&t.double_data, dims.as_slice(), &Device::Cpu)
                } else if dt == DType::I64 && !t.int64_data.is_empty() {
                    Tensor::from_slice(&t.int64_data, dims.as_slice(), &Device::Cpu)
                } else {
                    Tensor::from_raw_buffer(
                        t.raw_data.as_slice(),
                        dt,
                        dims.as_slice(),
                        &Device::Cpu,
                    )
                }
            }
            None => {
                bail!("unsupported 'value' data-type {dt:?} for {name}")
            }
        },
        Err(_) => {
            bail!("unsupported 'value' data-type {} for {name}", t.data_type)
        }
    }
}
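// A minimal usage sketch for `get_tensor`, assuming the prost-generated protos
// implement `Default` (the field values below are made up for illustration):
//
//     let t = onnx::TensorProto {
//         data_type: DataType::Float as i32,
//         dims: vec![2, 2],
//         float_data: vec![0., 1., 2., 3.],
//         ..Default::default()
//     };
//     let tensor = get_tensor(&t, "example")?; // a 2x2 f32 tensor on the CPU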
// This function provides a direct evaluation of the proto.
// Longer-term, we should first convert the proto to an intermediate representation of the compute
// graph so as to make multiple evaluations more efficient.
// An example upside of this would be to remove intermediary values when they are not needed
// anymore.
pub fn simple_eval(
    model: &onnx::ModelProto,
    inputs: HashMap<String, Value>,
) -> Result<HashMap<String, Value>> {
    let graph = match &model.graph {
        None => bail!("no graph defined in proto"),
        Some(graph) => graph,
    };
    let mut values = inputs;
    for t in graph.initializer.iter() {
        let tensor = get_tensor(t, t.name.as_str())?;
        values.insert(t.name.to_string(), tensor);
    }
    for input in graph.input.iter() {
        let input_type = match &input.r#type {
            Some(input_type) => input_type,
            None => continue,
        };
        let input_type = match &input_type.value {
            Some(input_type) => input_type,
            None => continue,
        };
        let tensor_type = match input_type {
            onnx::type_proto::Value::TensorType(tt) => tt,
            _ => continue,
        };
        let tensor = match values.get(&input.name) {
            None => bail!("missing input {}", input.name),
            Some(tensor) => tensor,
        };
        let dt = match DataType::try_from(tensor_type.elem_type) {
            Ok(dt) => match dtype(dt) {
                Some(dt) => dt,
                None => {
                    bail!("unsupported 'value' data-type {dt:?} for {}", input.name)
                }
            },
            type_ => bail!("unsupported input type {type_:?}"),
        };
        match &tensor_type.shape {
            None => continue,
            Some(shape) => {
                if shape.dim.len() != tensor.rank() {
                    bail!(
                        "unexpected rank for {}, got {:?}, expected {:?}",
                        input.name,
                        shape.dim,
                        tensor.shape()
                    )
                }
                for (idx, (d, &dim)) in shape.dim.iter().zip(tensor.dims().iter()).enumerate() {
                    match &d.value {
                        Some(onnx::tensor_shape_proto::dimension::Value::DimValue(v)) => {
                            if *v as usize != dim {
                                bail!(
                                    "unexpected dim {idx} for {}, got {:?}, expected {:?}",
                                    input.name,
                                    shape.dim,
                                    tensor.shape()
                                )
                            }
                        }
                        // We do not check equality constraints for the DimParam dimensions for now.
                        Some(onnx::tensor_shape_proto::dimension::Value::DimParam(_)) | None => (),
                    }
                }
            }
        };
        if dt != tensor.dtype() {
            bail!(
                "unexpected dtype for {}, got {:?}, expected {dt:?}",
                input.name,
                tensor.dtype()
            )
        }
    }
    // The nodes are topologically sorted so we can just process them in order.
    for node in graph.node.iter() {
        let get = |input_name: &str| match values.get(input_name) {
            Some(value) => Ok(value),
            None => bail!("cannot find {input_name} for op {}", node.name),
        };
        // TODO: Validate node.input for each operator.
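        // Each arm below reads its inputs from the `values` map and inserts its
        // outputs back into it under the names given by `node.output`, so that
        // subsequent nodes can look them up by name.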
        match node.op_type.as_str() {
            "Add" => {
                let input0 = get(&node.input[0])?;
                let input1 = get(&node.input[1])?;
                let output = input0.broadcast_add(input1)?;
                values.insert(node.output[0].clone(), output);
            }
            "Sub" => {
                let input0 = get(&node.input[0])?;
                let input1 = get(&node.input[1])?;
                let output = input0.broadcast_sub(input1)?;
                values.insert(node.output[0].clone(), output);
            }
            "Mul" => {
                let input0 = get(&node.input[0])?;
                let input1 = get(&node.input[1])?;
                let output = input0.broadcast_mul(input1)?;
                values.insert(node.output[0].clone(), output);
            }
            "Div" => {
                let input0 = get(&node.input[0])?;
                let input1 = get(&node.input[1])?;
                let output = input0.broadcast_div(input1)?;
                values.insert(node.output[0].clone(), output);
            }
            "Pow" => {
                let input0 = get(&node.input[0])?;
                let input1 = get(&node.input[1])?;
                let output = input0.broadcast_pow(input1)?;
                values.insert(node.output[0].clone(), output);
            }
            "Equal" => {
                let input0 = get(&node.input[0])?;
                let input1 = get(&node.input[1])?;
                let output = input0.broadcast_eq(input1)?;
                values.insert(node.output[0].clone(), output);
            }
            "Not" => {
                let xs = get(&node.input[0])?;
                let xs = xs.eq(&xs.zeros_like()?)?;
                values.insert(node.output[0].clone(), xs);
            }
            "MatMul" => {
                let input0 = get(&node.input[0])?;
                let input1 = get(&node.input[1])?;
                let output = input0.broadcast_matmul(input1)?;
                values.insert(node.output[0].clone(), output);
            }
            "Reshape" => {
                let input0 = get(&node.input[0])?;
                let input1 = get(&node.input[1])?.to_vec1::<i64>()?;
                // TODO: Check that there is at most a single -1 or 0, handle other neg values.
                let mut other_than_minus1 = 1usize;
                for &v in input1.iter() {
                    if v != -1 && v != 0 {
                        other_than_minus1 *= v as usize
                    }
                }
                let input1 = input1
                    .iter()
                    .enumerate()
                    .map(|(idx, &v)| match v {
                        // -1 means "infer this dim from the element count", 0 means
                        // "keep the original dim at this index".
                        -1 => Ok(input0.elem_count() / other_than_minus1),
                        0 => input0.dim(idx),
                        _ => Ok(v as usize),
                    })
                    .collect::<Result<Vec<usize>>>()?;
                let output = input0.reshape(input1)?;
                values.insert(node.output[0].clone(), output);
            }
            "LogSoftmax" => {
                let input = get(&node.input[0])?;
                let output = match get_attr_opt::<i64>(node, "axis")? {
                    // The default axis is -1, i.e. the last dimension. Note that this
                    // has to be log_softmax, not softmax as in the Softmax op below.
                    None => candle_nn::ops::log_softmax(input, input.rank() - 1)?,
                    Some(&axis) => {
                        let axis = input.normalize_axis(axis)?;
                        candle_nn::ops::log_softmax(input, axis)?
                    }
                };
                values.insert(node.output[0].clone(), output);
            }
            "Softmax" => {
                let input = get(&node.input[0])?;
                let output = match get_attr_opt::<i64>(node, "axis")? {
                    None => candle_nn::ops::softmax_last_dim(input)?,
                    Some(&axis) => {
                        let axis = input.normalize_axis(axis)?;
                        candle_nn::ops::softmax(input, axis)?
                    }
                };
                values.insert(node.output[0].clone(), output);
            }
            "Transpose" => {
                let input = get(&node.input[0])?;
                let output = match get_attr_opt::<[i64]>(node, "perm")? {
                    // Per the ONNX spec, the default is to reverse all the dimensions.
                    None => {
                        let perm = (0..input.rank()).rev().collect::<Vec<usize>>();
                        input.permute(perm)?
                    }
                    Some(perm) => {
                        let perm = perm.iter().map(|&v| v as usize).collect::<Vec<usize>>();
                        input.permute(perm)?
                    }
                };
                values.insert(node.output[0].clone(), output);
            }
            "Dropout" => {
                let input = get(&node.input[0])?;
                // Do not apply dropout at the moment, consider that we're only doing inference.
                values.insert(node.output[0].clone(), input.clone());
            }
            "MaxPool" => {
                // https://github.com/onnx/onnx/blob/main/docs/Operators.md#MaxPool
                let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
                let kernel_shape = get_attr::<[i64]>(node, "kernel_shape")?;
                let pads = get_attr_opt::<[i64]>(node, "pads")?;
                let strides = get_attr_opt::<[i64]>(node, "strides")?;
                let auto_pad = get_attr_opt::<str>(node, "auto_pad")?;
                match auto_pad {
                    None | Some("NOTSET") => (),
                    Some(s) => bail!("unsupported auto_pad {s}"),
                };
                if let Some(d) = dilations {
                    if d.iter().any(|&v| v != 1) {
                        bail!("MaxPool with dilation != 1, {dilations:?}")
                    }
                }
                if let Some(d) = pads {
                    if d.iter().any(|&v| v != 0) {
                        bail!("MaxPool with pads != 0, {pads:?}")
                    }
                }
                let xs = get(&node.input[0])?;
                let (k1, k2) = match kernel_shape {
                    [k1, k2] => (*k1 as usize, *k2 as usize),
                    _ => bail!("only 2d MaxPool is supported, kernel shape {kernel_shape:?}"),
                };
                let ys = match strides {
                    None => xs.max_pool2d((k1, k2))?,
                    Some([s1, s2]) => {
                        xs.max_pool2d_with_stride((k1, k2), (*s1 as usize, *s2 as usize))?
                    }
                    Some(strides) => bail!("only 2d MaxPool is supported, strides {strides:?}"),
                };
                values.insert(node.output[0].clone(), ys);
            }
            "AveragePool" => {
                // https://github.com/onnx/onnx/blob/main/docs/Operators.md#AveragePool
                let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
                let kernel_shape = get_attr::<[i64]>(node, "kernel_shape")?;
                let pads = get_attr_opt::<[i64]>(node, "pads")?;
                let strides = get_attr_opt::<[i64]>(node, "strides")?;
                let auto_pad = get_attr_opt::<str>(node, "auto_pad")?;
                match auto_pad {
                    None | Some("NOTSET") => (),
                    Some(s) => bail!("unsupported auto_pad {s}"),
                };
                if let Some(d) = dilations {
                    if d.iter().any(|&v| v != 1) {
                        bail!("AvgPool with dilation != 1, {dilations:?}")
                    }
                }
                if let Some(d) = pads {
                    if d.iter().any(|&v| v != 0) {
                        bail!("AvgPool with pads != 0, {pads:?}")
                    }
                }
                let xs = get(&node.input[0])?;
                let (k1, k2) = match kernel_shape {
                    [k1, k2] => (*k1 as usize, *k2 as usize),
                    _ => bail!("only 2d AvgPool is supported, kernel shape {kernel_shape:?}"),
                };
                let ys = match strides {
                    None => xs.avg_pool2d((k1, k2))?,
                    Some([s1, s2]) => {
                        xs.avg_pool2d_with_stride((k1, k2), (*s1 as usize, *s2 as usize))?
                    }
                    Some(strides) => bail!("only 2d AvgPool is supported, strides {strides:?}"),
                };
                values.insert(node.output[0].clone(), ys);
            }
            "BatchNormalization" => {
                let training_mode = get_attr_opt::<i64>(node, "training_mode")?;
                if training_mode.copied().unwrap_or(0) != 0 {
                    bail!("training mode is not supported for BatchNorm")
                }
                let eps = get_attr_opt::<f32>(node, "epsilon")?
                    .copied()
                    .unwrap_or(1e-5);
                let xs = get(&node.input[0])?;
                let weight = get(&node.input[1])?;
                let bias = get(&node.input[2])?;
                let running_mean = get(&node.input[3])?;
                let running_var = get(&node.input[4])?;
                let target_shape: Vec<usize> = xs
                    .dims()
                    .iter()
                    .enumerate()
                    .map(|(idx, v)| if idx == 1 { *v } else { 1 })
                    .collect();
                let target_shape = target_shape.as_slice();
                let xs = xs
                    .broadcast_sub(&running_mean.reshape(target_shape)?)?
                    .broadcast_div(&(running_var.reshape(target_shape)? + eps as f64)?.sqrt()?)?;
                let weight = weight.reshape(target_shape)?;
                let bias = bias.reshape(target_shape)?;
                let xs = xs.broadcast_mul(&weight)?.broadcast_add(&bias)?;
                values.insert(node.output[0].clone(), xs);
            }
            "Squeeze" => {
                let xs = get(&node.input[0])?;
                let mut axes = if node.input.len() <= 1 {
                    // contract all the dimensions with size 1 except the batch dim.
                    xs.dims()
                        .iter()
                        .enumerate()
                        .flat_map(|(idx, &s)| if s == 1 && idx > 0 { Some(idx) } else { None })
                        .collect()
                } else {
                    get(&node.input[1])?
                        .to_vec1::<i64>()?
                        .iter()
                        .map(|&i| xs.normalize_axis(i))
                        .collect::<Result<Vec<_>>>()?
                };
                axes.sort();
                let mut xs = xs.clone();
                for &axis in axes.iter().rev() {
                    xs = xs.squeeze(axis)?
                }
                values.insert(node.output[0].clone(), xs);
            }
            "ConstantOfShape" => {
                // TODO: Support the 'value' attribute; only the ONNX default (a
                // float32 zero fill) is handled here.
                let dims = get(&node.input[0])?;
                let shape = dims
                    .to_vec1::<i64>()?
                    .into_iter()
                    .map(|v| v as usize)
                    .collect::<Vec<usize>>();
                let xs = Tensor::zeros(shape, DType::F32, dims.device())?;
                values.insert(node.output[0].clone(), xs);
            }
            "Unsqueeze" => {
                let xs = get(&node.input[0])?;
                let axes = match get_attr_opt::<[i64]>(node, "axes")? {
                    Some(axis) => axis.to_vec(),
                    None => get(&node.input[1])?.to_vec1::<i64>()?,
                };
                let mut axes = axes
                    .iter()
                    .map(|&i| {
                        if i == xs.rank() as i64 {
                            Ok(xs.rank())
                        } else {
                            xs.normalize_axis(i)
                        }
                    })
                    .collect::<Result<Vec<usize>>>()?;
                axes.sort();
                let mut xs = xs.clone();
                for &axis in axes.iter().rev() {
                    xs = xs.unsqueeze(axis)?
                }
                values.insert(node.output[0].clone(), xs);
            }
            "Clip" => {
                let xs = get(&node.input[0])?;
                let xs = if node.input.len() >= 2 {
                    let mins = get(&node.input[1])?;
                    xs.broadcast_maximum(mins)?
                } else {
                    xs.clone()
                };
                let xs = if node.input.len() >= 3 {
                    let maxs = get(&node.input[2])?;
                    xs.broadcast_minimum(maxs)?
                } else {
                    xs.clone()
                };
                values.insert(node.output[0].clone(), xs);
            }
            "Gather" => {
                let xs = get(&node.input[0])?;
                let indices = get(&node.input[1])?;
                let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
                let axis = xs.normalize_axis(axis)?;
                // TODO: Provide an op to handle the ONNX generalized gather op ideally in a
                // differentiable way.
                let xs = if indices.rank() == 0 {
                    let index = indices.to_vec0::<i64>()? as usize;
                    xs.narrow(axis, index, 1)?.squeeze(axis)?
                } else {
                    todo!("implement gather for {xs:?} {indices:?} axis {axis}")
                };
                values.insert(node.output[0].clone(), xs);
            }
            "Shape" => {
                // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape
                let xs = get(&node.input[0])?;
                let start = get_attr_opt::<i64>(node, "start")?.copied().unwrap_or(0);
                let end = get_attr_opt::<i64>(node, "end")?.copied().unwrap_or(-1);
                let start = xs.normalize_axis(start)?;
                let end = xs.normalize_axis(end)?;
                let mut dims = vec![];
                for idx in start..=end {
                    dims.push(xs.dim(idx)? as i64)
                }
                // The output is a 1d tensor with one element per selected dim, which
                // can be fewer than xs.rank() when start/end are specified.
                let ndims = dims.len();
                let dims = Tensor::from_vec(dims, ndims, xs.device())?;
                values.insert(node.output[0].clone(), dims);
            }
            "Conv" => {
                // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
                let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
                let groups = get_attr_opt::<i64>(node, "group")?.copied().unwrap_or(1);
                let _kernel_shape = get_attr_opt::<[i64]>(node, "kernel_shape")?;
                let pads = get_attr_opt::<[i64]>(node, "pads")?;
                let strides = get_attr_opt::<[i64]>(node, "strides")?;
                let auto_pad = get_attr_opt::<str>(node, "auto_pad")?;
                match auto_pad {
                    None | Some("NOTSET") => (),
                    Some(s) => bail!("unsupported auto_pad {s}"),
                };
                let xs = get(&node.input[0])?;
                let ws = get(&node.input[1])?;
                let ys = match ws.rank() {
                    3 => {
                        let (pads, xs) = match pads {
                            None => (0, xs.clone()),
                            Some([p]) => (*p as usize, xs.clone()),
                            Some([p1, p2]) => {
                                if p1 != p2 {
                                    (0usize, xs.pad_with_zeros(2, *p1 as usize, *p2 as usize)?)
                                } else {
                                    (*p1 as usize, xs.clone())
                                }
                            }
                            Some(pads) => {
                                bail!("more pads than expected in conv1d {pads:?} {}", node.name)
                            }
                        };
                        let strides = match strides {
                            None => 1,
                            Some([p]) => *p as usize,
                            Some(s) => {
                                bail!("more strides than expected in conv1d {s:?} {}", node.name)
                            }
                        };
                        let dilations = match dilations {
                            None => 1,
                            Some([p]) => *p as usize,
                            Some(s) => {
                                bail!("more dilations than expected in conv1d {s:?} {}", node.name)
                            }
                        };
                        xs.conv1d(ws, pads, strides, dilations, groups as usize)?
                    }
                    4 => {
                        let (pads, xs) = match pads {
                            None => (0, xs.clone()),
                            Some([p]) => (*p as usize, xs.clone()),
                            Some(&[p1, p2, p3, p4]) => {
                                let p1 = p1 as usize;
                                let p2 = p2 as usize;
                                let p3 = p3 as usize;
                                let p4 = p4 as usize;
                                if p1 != p2 || p1 != p3 || p1 != p4 {
                                    (0, xs.pad_with_zeros(2, p1, p3)?.pad_with_zeros(3, p2, p4)?)
                                } else {
                                    (p1, xs.clone())
                                }
                            }
                            Some(pads) => {
                                bail!("more pads than expected in conv2d {pads:?} {}", node.name)
                            }
                        };
                        let strides = match strides {
                            None => 1,
                            Some([p]) => *p as usize,
                            Some([p1, p2]) => {
                                if p1 != p2 {
                                    bail!(
                                        "strides have to be the same on both axis {strides:?} {}",
                                        node.name
                                    )
                                }
                                *p1 as usize
                            }
                            Some(s) => {
                                bail!("more strides than expected in conv2d {s:?} {}", node.name)
                            }
                        };
                        let dilations = match dilations {
                            None => 1,
                            Some([p]) => *p as usize,
                            Some([p1, p2]) => {
                                if p1 != p2 {
                                    bail!(
                                        "dilations have to be the same on both axis {dilations:?} {}",
                                        node.name
                                    )
                                }
                                *p1 as usize
                            }
                            Some(s) => {
                                bail!("more dilations than expected in conv2d {s:?} {}", node.name)
                            }
                        };
                        xs.conv2d(ws, pads, strides, dilations, groups as usize)?
                    }
                    rank => bail!(
                        "unsupported rank for weight matrix {rank} in conv {}",
                        node.name
                    ),
                };
                let ys = if node.input.len() > 2 {
                    let bs = get(&node.input[2])?;
                    let mut bs_shape = vec![1; ys.rank()];
                    bs_shape[1] = bs.elem_count();
                    ys.broadcast_add(&bs.reshape(bs_shape)?)?
                } else {
                    ys
                };
                values.insert(node.output[0].clone(), ys);
            }
            "Concat" => {
                // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Concat
                let inputs = node
                    .input
                    .iter()
                    .map(|n| Ok(get(n.as_str())?.clone()))
                    .collect::<Result<Vec<Value>>>()?;
                let axis: i64 = *get_attr(node, "axis")?;
                if inputs.is_empty() {
                    bail!("empty concat")
                };
                let axis = inputs[0].normalize_axis(axis)?;
                let output = Tensor::cat(&inputs, axis)?;
                values.insert(node.output[0].clone(), output);
            }
            "Abs" => {
                let input = get(&node.input[0])?;
                let output = input.abs()?;
                values.insert(node.output[0].clone(), output);
            }
            "Cos" => {
                let input = get(&node.input[0])?;
                let output = input.cos()?;
                values.insert(node.output[0].clone(), output);
            }
            "Sin" => {
                let input = get(&node.input[0])?;
                let output = input.sin()?;
                values.insert(node.output[0].clone(), output);
            }
            "Neg" => {
                let input = get(&node.input[0])?;
                let output = input.neg()?;
                values.insert(node.output[0].clone(), output);
            }
            "Erf" => {
                let input = get(&node.input[0])?;
                let output = input.erf()?;
                values.insert(node.output[0].clone(), output);
            }
            "Tanh" => {
                let input = get(&node.input[0])?;
                let output = input.tanh()?;
                values.insert(node.output[0].clone(), output);
            }
            "Sigmoid" => {
                let input = get(&node.input[0])?;
                let output = candle_nn::ops::sigmoid(input)?;
                values.insert(node.output[0].clone(), output);
            }
            "Gelu" => {
                let input = get(&node.input[0])?;
                let output = input.gelu_erf()?;
                values.insert(node.output[0].clone(), output);
            }
            "Relu" => {
                let input = get(&node.input[0])?;
                let output = input.relu()?;
                values.insert(node.output[0].clone(), output);
            }
            // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Constant
            "Constant" => {
                let value = match node.attribute.iter().find(|attr| attr.name == "value") {
                    None => {
                        // TODO: support sparse_value etc.
                        bail!("cannot find 'value' attr in 'Constant' for {}", node.name)
                    }
                    Some(value) => value,
                };
                let output = match value.r#type() {
                    AttributeType::Tensor => {
                        let t = value.t.as_ref().unwrap();
                        get_tensor(t, &node.name)?
                    }
                    rtype => bail!("unsupported 'value' type {rtype:?} for {}", node.name),
                };
                values.insert(node.output[0].clone(), output);
            }
            // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
            "Cast" => {
                let input = get(&node.input[0])?;
                let dt: i64 = *get_attr(node, "to")?;
                let dtype = match DataType::try_from(dt as i32) {
                    // Candle does not have an i32 dtype so i32 casts are widened to i64.
                    Ok(DataType::Int32) => DType::I64,
                    Ok(dt) => match dtype(dt) {
                        Some(dt) => dt,
                        None => {
                            bail!("unsupported 'to' value {dt:?} for cast {}", node.name)
                        }
                    },
                    Err(_) => {
                        bail!("unsupported 'to' value {dt:?} for cast {}", node.name)
                    }
                };
                let output = input.to_dtype(dtype)?;
                values.insert(node.output[0].clone(), output);
            }
            // https://github.com/onnx/onnx/blob/main/docs/Operators.md#CumSum
            "CumSum" => {
                let exclusive = get_attr_opt::<i64>(node, "exclusive")?
                    .copied()
                    .unwrap_or(0);
                let reverse = get_attr_opt::<i64>(node, "reverse")?.copied().unwrap_or(0);
                if exclusive != 0 {
                    bail!("only exclusive == 0 is supported in CumSum")
                }
                if reverse != 0 {
                    bail!("only reverse == 0 is supported in CumSum")
                }
                let input = get(&node.input[0])?;
                let axis = get(&node.input[1])?
                    .to_dtype(DType::U32)?
                    .to_vec0::<u32>()?;
                let output = input.cumsum(axis as usize)?;
                values.insert(node.output[0].clone(), output);
            }
            op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
        }
    }
    graph
        .output
        .iter()
        .map(|output| match values.remove(&output.name) {
            None => bail!("cannot find output {}", output.name),
            Some(value) => Ok((output.name.clone(), value)),
        })
        .collect()
}
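
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal sketch of driving `simple_eval` end to end, assuming the
    // prost-generated protos implement `Default`: a graph holding a single
    // `Add` node is evaluated on two cpu tensors. The value names used here
    // ("x", "y", "z") are arbitrary.
    #[test]
    fn simple_eval_add() -> Result<()> {
        let node = onnx::NodeProto {
            op_type: "Add".to_string(),
            input: vec!["x".to_string(), "y".to_string()],
            output: vec!["z".to_string()],
            ..Default::default()
        };
        let graph = onnx::GraphProto {
            node: vec![node],
            output: vec![onnx::ValueInfoProto {
                name: "z".to_string(),
                ..Default::default()
            }],
            ..Default::default()
        };
        let model = onnx::ModelProto {
            graph: Some(graph),
            ..Default::default()
        };
        let mut inputs = HashMap::new();
        inputs.insert("x".to_string(), Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?);
        inputs.insert("y".to_string(), Tensor::new(&[4f32, 5., 6.], &Device::Cpu)?);
        let outputs = simple_eval(&model, inputs)?;
        assert_eq!(outputs.get("z").unwrap().to_vec1::<f32>()?, [5., 7., 9.]);
        Ok(())
    }
}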