author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-11-06 22:44:58 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-06 22:44:58 +0100 |
commit | a773a4b22b88d9955f51de552d72717441d49729 (patch) | |
tree | c09ea4f12a02a0d7cf68f6efe51045e125b76367 /candle-onnx/src | |
parent | 5a363dbc263abaa8a71f5bf606351d8d458ce70a (diff) | |
[ONNX] Support a couple more ops. (#1284)
* Support the shape op in ONNX.
* Share the axis normalization bits.
* Add some limited support for gather.
* Unsqueeze.
* Comparison with broadcasting.
* Add Not + handle i32.
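
The "axis normalization" shared by these ops follows the ONNX convention that a negative axis counts back from the end of the shape. Below is a minimal standalone sketch of that rule; the free function `normalize_axis` is an illustrative stand-in for the `Tensor::normalize_axis` helper used in the diff below (whose definition is not part of this diff), with the error type simplified to `String`:

```rust
// Sketch of ONNX-style axis normalization: a negative axis counts from the
// end of the shape, so axis -1 on a rank-3 tensor resolves to index 2.
fn normalize_axis(axis: i64, rank: usize) -> Result<usize, String> {
    let rank = rank as i64;
    let idx = if axis < 0 { rank + axis } else { axis };
    if idx < 0 || idx >= rank {
        Err(format!("axis {axis} is out of range for rank {rank}"))
    } else {
        Ok(idx as usize)
    }
}

fn main() {
    assert_eq!(normalize_axis(-1, 3), Ok(2));
    assert_eq!(normalize_axis(1, 3), Ok(1));
    assert!(normalize_axis(3, 3).is_err());
}
```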
Diffstat (limited to 'candle-onnx/src')
-rw-r--r-- | candle-onnx/src/eval.rs | 135 |
1 file changed, 109 insertions, 26 deletions
```diff
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index 54fae6c1..51e2aa0c 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -101,6 +101,18 @@ fn get_attr_opt<'a, T: Attr + ?Sized>(
 fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
     let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
     match DataType::try_from(t.data_type) {
+        Ok(DataType::Int32) => {
+            if t.int32_data.is_empty() {
+                let len = t.raw_data.len() / 4;
+                let data: &[i32] =
+                    unsafe { std::slice::from_raw_parts(t.raw_data.as_ptr() as *const i32, len) };
+                let data = data.iter().map(|v| *v as i64).collect::<Vec<_>>();
+                Tensor::from_vec(data, len, &Device::Cpu)
+            } else {
+                let data = t.int32_data.iter().map(|v| *v as i64).collect::<Vec<_>>();
+                Tensor::from_vec(data, t.int32_data.len(), &Device::Cpu)
+            }
+        }
         Ok(dt) => match dtype(dt) {
             Some(dt) => {
                 if dt == DType::F32 && !t.float_data.is_empty() {
@@ -173,18 +185,34 @@ pub fn simple_eval(
                 },
                 type_ => bail!("unsupported input type {type_:?}"),
             };
-            let shape = match &tensor_type.shape {
+            match &tensor_type.shape {
                 None => continue,
-                Some(shape) => shape
-                    .dim
-                    .iter()
-                    .map(|dim| match dim.value.as_ref().expect("no dim value") {
-                        onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize),
-                        onnx::tensor_shape_proto::dimension::Value::DimParam(_) => {
-                            bail!("DimParam is unsupported for input {}", input.name)
+                Some(shape) => {
+                    if shape.dim.len() != tensor.rank() {
+                        bail!(
+                            "unexpected rank for {}, got {:?}, expected {:?}",
+                            input.name,
+                            shape.dim,
+                            tensor.shape()
+                        )
+                    }
+                    for (idx, (d, &dim)) in shape.dim.iter().zip(tensor.dims().iter()).enumerate() {
+                        match &d.value {
+                            Some(onnx::tensor_shape_proto::dimension::Value::DimValue(v)) => {
+                                if *v as usize != dim {
+                                    bail!(
+                                        "unexpected dim {idx} for {}, got {:?}, expected {:?}",
+                                        input.name,
+                                        shape.dim,
+                                        tensor.shape()
+                                    )
+                                }
+                            }
+                            // We do not check equality constraints for the DimParam dimensions for now.
+                            Some(onnx::tensor_shape_proto::dimension::Value::DimParam(_)) | None => (),
                         }
-                    })
-                    .collect::<Result<Vec<usize>>>()?,
+                    }
+                }
             };
             if dt != tensor.dtype() {
                 bail!(
@@ -193,13 +221,6 @@ pub fn simple_eval(
                     tensor.dtype()
                 )
             }
-            if shape.as_slice() != tensor.dims() {
-                bail!(
-                    "unexpected shape for {}, got {:?}, expected {shape:?}",
-                    input.name,
-                    tensor.dims()
-                )
-            }
         }
         // The nodes are topologically sorted so we can just process them in order.
         for node in graph.node.iter() {
@@ -236,9 +257,14 @@ pub fn simple_eval(
             "Equal" => {
                 let input0 = get(&node.input[0])?;
                 let input1 = get(&node.input[1])?;
-                let output = input0.eq(input1)?;
+                let output = input0.broadcast_eq(input1)?;
                 values.insert(node.output[0].clone(), output);
             }
+            "Not" => {
+                let xs = get(&node.input[0])?;
+                let xs = xs.eq(&xs.zeros_like()?)?;
+                values.insert(node.output[0].clone(), xs);
+            }
             "MatMul" => {
                 let input0 = get(&node.input[0])?;
                 let input1 = get(&node.input[1])?;
@@ -430,14 +456,8 @@ pub fn simple_eval(
                     get(&node.input[1])?
                         .to_vec1::<i64>()?
                         .iter()
-                        .map(|&i| {
-                            if i < 0 {
-                                (xs.rank() as i64 + i) as usize
-                            } else {
-                                i as usize
-                            }
-                        })
-                        .collect::<Vec<_>>()
+                        .map(|&i| xs.normalize_axis(i))
+                        .collect::<Result<Vec<_>>>()?
                 };
                 axes.sort();
                 let mut xs = xs.clone();
@@ -446,6 +466,39 @@ pub fn simple_eval(
                 }
                 values.insert(node.output[0].clone(), xs);
             }
+            "ConstantOfShape" => {
+                let dims = get(&node.input[0])?;
+                let shape = dims
+                    .to_vec1::<i64>()?
+                    .into_iter()
+                    .map(|v| v as usize)
+                    .collect::<Vec<_>>();
+                let xs = Tensor::zeros(shape, DType::F32, dims.device())?;
+                values.insert(node.output[0].clone(), xs);
+            }
+            "Unsqueeze" => {
+                let xs = get(&node.input[0])?;
+                let axes = match get_attr_opt::<[i64]>(node, "axes")? {
+                    Some(axis) => axis.to_vec(),
+                    None => get(&node.input[1])?.to_vec1::<i64>()?,
+                };
+                let mut axes = axes
+                    .iter()
+                    .map(|&i| {
+                        if i == xs.rank() as i64 {
+                            Ok(xs.rank())
+                        } else {
+                            xs.normalize_axis(i)
+                        }
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+                axes.sort();
+                let mut xs = xs.clone();
+                for &axis in axes.iter().rev() {
+                    xs = xs.unsqueeze(axis)?
+                }
+                values.insert(node.output[0].clone(), xs);
+            }
             "Clip" => {
                 let xs = get(&node.input[0])?;
                 let xs = if node.input.len() >= 2 {
@@ -462,6 +515,35 @@ pub fn simple_eval(
                 };
                 values.insert(node.output[0].clone(), xs);
             }
+            "Gather" => {
+                let xs = get(&node.input[0])?;
+                let indices = get(&node.input[1])?;
+                let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
+                let axis = xs.normalize_axis(axis)?;
+                // TODO: Provide an op to handle the ONNX generalized gather op ideally in a
+                // differentiable way.
+                let xs = if indices.rank() == 0 {
+                    let index = indices.to_vec0::<i64>()? as usize;
+                    xs.narrow(axis, index, 1)?.squeeze(axis)?
+                } else {
+                    todo!("implement gather for {xs:?} {indices:?} axis {axis}")
+                };
+                values.insert(node.output[0].clone(), xs);
+            }
+            "Shape" => {
+                // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape
+                let xs = get(&node.input[0])?;
+                let start = get_attr_opt::<i64>(node, "start")?.copied().unwrap_or(0);
+                let end = get_attr_opt::<i64>(node, "end")?.copied().unwrap_or(-1);
+                let start = xs.normalize_axis(start)?;
+                let end = xs.normalize_axis(end)?;
+                let mut dims = vec![];
+                for idx in start..=end {
+                    dims.push(xs.dim(idx)? as i64)
+                }
+                let dims = Tensor::from_vec(dims, xs.rank(), xs.device())?;
+                values.insert(node.output[0].clone(), dims);
+            }
             "Conv" => {
                 // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
                 let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
@@ -670,6 +752,7 @@ pub fn simple_eval(
                 let input = get(&node.input[0])?;
                 let dt: i64 = *get_attr(node, "to")?;
                 let dtype = match DataType::try_from(dt as i32) {
+                    Ok(DataType::Int32) => DType::I64,
                     Ok(dt) => match dtype(dt) {
                         Some(dt) => dt,
                         None => {
```
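
The new `Not` arm implements logical negation as a comparison with zeros: `eq` against a zeros-like tensor maps 0 to 1 and any non-zero value to 0. A minimal sketch of that identity using candle's tensor API directly (assuming `candle-core` as a dependency; the input values are made up for illustration):

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // ONNX booleans end up as 0/1 valued integer tensors on the candle side.
    let xs = Tensor::new(&[1u8, 0, 1], &dev)?;
    // eq(zeros) flips each element: 0 -> 1, non-zero -> 0, i.e. logical Not.
    let not_xs = xs.eq(&xs.zeros_like()?)?;
    assert_eq!(not_xs.to_vec1::<u8>()?, vec![0, 1, 0]);
    Ok(())
}
```

Similarly, the rank-0 `Gather` path reduces to a `narrow` (keep a length-1 slice along the axis) followed by a `squeeze` (drop the now-unit axis). A sketch of the same sequence, again assuming `candle-core` and illustrative values:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A 2x3 input; gathering scalar index 1 on axis 0 selects the second row.
    let xs = Tensor::new(&[[1i64, 2, 3], [4, 5, 6]], &dev)?;
    let (axis, index) = (0usize, 1usize);
    // narrow keeps the slice [index, index + 1) along `axis`, squeeze removes
    // that axis, matching the indices.rank() == 0 arm of "Gather" above.
    let ys = xs.narrow(axis, index, 1)?.squeeze(axis)?;
    assert_eq!(ys.to_vec1::<i64>()?, vec![4, 5, 6]);
    Ok(())
}
```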
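Note on the design: for non-scalar indices the `Gather` arm deliberately hits a `todo!` rather than an incorrect fallback, and the accompanying TODO comment flags that a generalized, ideally differentiable gather op would belong in candle itself rather than in the ONNX evaluator.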