summaryrefslogtreecommitdiff
path: root/candle-onnx/src
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-11-06 22:44:58 +0100
committerGitHub <noreply@github.com>2023-11-06 22:44:58 +0100
commita773a4b22b88d9955f51de552d72717441d49729 (patch)
treec09ea4f12a02a0d7cf68f6efe51045e125b76367 /candle-onnx/src
parent5a363dbc263abaa8a71f5bf606351d8d458ce70a (diff)
downloadcandle-a773a4b22b88d9955f51de552d72717441d49729.tar.gz
candle-a773a4b22b88d9955f51de552d72717441d49729.tar.bz2
candle-a773a4b22b88d9955f51de552d72717441d49729.zip
[ONNX] Support a couple more ops. (#1284)
* Support the shape op in ONNX. * Share the axis normalization bits. * Add some limited support for gather. * Unsqueeze. * Comparison with broadcasting. * Add Not + handle i32.
Diffstat (limited to 'candle-onnx/src')
-rw-r--r--candle-onnx/src/eval.rs135
1 files changed, 109 insertions, 26 deletions
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index 54fae6c1..51e2aa0c 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -101,6 +101,18 @@ fn get_attr_opt<'a, T: Attr + ?Sized>(
fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
match DataType::try_from(t.data_type) {
+ Ok(DataType::Int32) => {
+ if t.int32_data.is_empty() {
+ let len = t.raw_data.len() / 4;
+ let data: &[i32] =
+ unsafe { std::slice::from_raw_parts(t.raw_data.as_ptr() as *const i32, len) };
+ let data = data.iter().map(|v| *v as i64).collect::<Vec<_>>();
+ Tensor::from_vec(data, len, &Device::Cpu)
+ } else {
+ let data = t.int32_data.iter().map(|v| *v as i64).collect::<Vec<_>>();
+ Tensor::from_vec(data, t.int32_data.len(), &Device::Cpu)
+ }
+ }
Ok(dt) => match dtype(dt) {
Some(dt) => {
if dt == DType::F32 && !t.float_data.is_empty() {
@@ -173,18 +185,34 @@ pub fn simple_eval(
},
type_ => bail!("unsupported input type {type_:?}"),
};
- let shape = match &tensor_type.shape {
+ match &tensor_type.shape {
None => continue,
- Some(shape) => shape
- .dim
- .iter()
- .map(|dim| match dim.value.as_ref().expect("no dim value") {
- onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize),
- onnx::tensor_shape_proto::dimension::Value::DimParam(_) => {
- bail!("DimParam is unsupported for input {}", input.name)
+ Some(shape) => {
+ if shape.dim.len() != tensor.rank() {
+ bail!(
+ "unexpected rank for {}, got {:?}, expected {:?}",
+ input.name,
+ shape.dim,
+ tensor.shape()
+ )
+ }
+ for (idx, (d, &dim)) in shape.dim.iter().zip(tensor.dims().iter()).enumerate() {
+ match &d.value {
+ Some(onnx::tensor_shape_proto::dimension::Value::DimValue(v)) => {
+ if *v as usize != dim {
+ bail!(
+ "unexpected dim {idx} for {}, got {:?}, expected {:?}",
+ input.name,
+ shape.dim,
+ tensor.shape()
+ )
+ }
+ }
+ // We do not check equality constraints for the DimParam dimensions for now.
+ Some(onnx::tensor_shape_proto::dimension::Value::DimParam(_)) | None => (),
}
- })
- .collect::<Result<Vec<usize>>>()?,
+ }
+ }
};
if dt != tensor.dtype() {
bail!(
@@ -193,13 +221,6 @@ pub fn simple_eval(
tensor.dtype()
)
}
- if shape.as_slice() != tensor.dims() {
- bail!(
- "unexpected shape for {}, got {:?}, expected {shape:?}",
- input.name,
- tensor.dims()
- )
- }
}
// The nodes are topologically sorted so we can just process them in order.
for node in graph.node.iter() {
@@ -236,9 +257,14 @@ pub fn simple_eval(
"Equal" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
- let output = input0.eq(input1)?;
+ let output = input0.broadcast_eq(input1)?;
values.insert(node.output[0].clone(), output);
}
+ "Not" => {
+ let xs = get(&node.input[0])?;
+ let xs = xs.eq(&xs.zeros_like()?)?;
+ values.insert(node.output[0].clone(), xs);
+ }
"MatMul" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
@@ -430,14 +456,8 @@ pub fn simple_eval(
get(&node.input[1])?
.to_vec1::<i64>()?
.iter()
- .map(|&i| {
- if i < 0 {
- (xs.rank() as i64 + i) as usize
- } else {
- i as usize
- }
- })
- .collect::<Vec<_>>()
+ .map(|&i| xs.normalize_axis(i))
+ .collect::<Result<Vec<_>>>()?
};
axes.sort();
let mut xs = xs.clone();
@@ -446,6 +466,39 @@ pub fn simple_eval(
}
values.insert(node.output[0].clone(), xs);
}
+ "ConstantOfShape" => {
+ let dims = get(&node.input[0])?;
+ let shape = dims
+ .to_vec1::<i64>()?
+ .into_iter()
+ .map(|v| v as usize)
+ .collect::<Vec<_>>();
+ let xs = Tensor::zeros(shape, DType::F32, dims.device())?;
+ values.insert(node.output[0].clone(), xs);
+ }
+ "Unsqueeze" => {
+ let xs = get(&node.input[0])?;
+ let axes = match get_attr_opt::<[i64]>(node, "axes")? {
+ Some(axis) => axis.to_vec(),
+ None => get(&node.input[1])?.to_vec1::<i64>()?,
+ };
+ let mut axes = axes
+ .iter()
+ .map(|&i| {
+ if i == xs.rank() as i64 {
+ Ok(xs.rank())
+ } else {
+ xs.normalize_axis(i)
+ }
+ })
+ .collect::<Result<Vec<_>>>()?;
+ axes.sort();
+ let mut xs = xs.clone();
+ for &axis in axes.iter().rev() {
+ xs = xs.unsqueeze(axis)?
+ }
+ values.insert(node.output[0].clone(), xs);
+ }
"Clip" => {
let xs = get(&node.input[0])?;
let xs = if node.input.len() >= 2 {
@@ -462,6 +515,35 @@ pub fn simple_eval(
};
values.insert(node.output[0].clone(), xs);
}
+ "Gather" => {
+ let xs = get(&node.input[0])?;
+ let indices = get(&node.input[1])?;
+ let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
+ let axis = xs.normalize_axis(axis)?;
+ // TODO: Provide an op to handle the ONNX generalized gather op ideally in a
+ // differentiable way.
+ let xs = if indices.rank() == 0 {
+ let index = indices.to_vec0::<i64>()? as usize;
+ xs.narrow(axis, index, 1)?.squeeze(axis)?
+ } else {
+ todo!("implement gather for {xs:?} {indices:?} axis {axis}")
+ };
+ values.insert(node.output[0].clone(), xs);
+ }
+ "Shape" => {
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape
+ let xs = get(&node.input[0])?;
+ let start = get_attr_opt::<i64>(node, "start")?.copied().unwrap_or(0);
+ let end = get_attr_opt::<i64>(node, "end")?.copied().unwrap_or(-1);
+ let start = xs.normalize_axis(start)?;
+ let end = xs.normalize_axis(end)?;
+ let mut dims = vec![];
+ for idx in start..=end {
+ dims.push(xs.dim(idx)? as i64)
+ }
+ let dims = Tensor::from_vec(dims, xs.rank(), xs.device())?;
+ values.insert(node.output[0].clone(), dims);
+ }
"Conv" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
@@ -670,6 +752,7 @@ pub fn simple_eval(
let input = get(&node.input[0])?;
let dt: i64 = *get_attr(node, "to")?;
let dtype = match DataType::try_from(dt as i32) {
+ Ok(DataType::Int32) => DType::I64,
Ok(dt) => match dtype(dt) {
Some(dt) => dt,
None => {