diff options
author | Gabriel <45515538+gabotechs@users.noreply.github.com> | 2024-04-08 14:06:14 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-08 14:06:14 +0200 |
commit | 798e0335cd2c4661f0fd0429cdf06abe3b45f4ea (patch) | |
tree | 680dde236758afe9dc864bc245be61dc53459da7 /candle-onnx | |
parent | 718671a0d5b751458033fb6425fb518ca4dc3b5f (diff) | |
download | candle-798e0335cd2c4661f0fd0429cdf06abe3b45f4ea.tar.gz candle-798e0335cd2c4661f0fd0429cdf06abe3b45f4ea.tar.bz2 candle-798e0335cd2c4661f0fd0429cdf06abe3b45f4ea.zip |
Handle more tensor shapes in onnx "Gather" operation (#2026)
* Handle more tensor shapes in onnx "Gather" operation
* Add more tests
* Add comment
* Fix typo
Diffstat (limited to 'candle-onnx')
-rw-r--r-- | candle-onnx/src/eval.rs | 30 | ||||
-rw-r--r-- | candle-onnx/tests/ops.rs | 131 |
2 files changed, 152 insertions, 9 deletions
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 15cadf1d..f7cae31c 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -508,17 +508,33 @@ pub fn simple_eval( values.insert(node.output[0].clone(), xs); } "Gather" => { + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gather let xs = get(&node.input[0])?; let indices = get(&node.input[1])?; let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0); let axis = xs.normalize_axis(axis)?; - // TODO: Provide an op to handle the ONNX generalized gather op ideally in a - // differentiable way. - let xs = if indices.rank() == 0 { - let index = indices.to_vec0::<i64>()? as usize; - xs.narrow(axis, index, 1)?.squeeze(axis)? - } else { - todo!("implement gather for {xs:?} {indices:?} axis {axis}") + + // In Pytorch or Numpy this can be done by indexing the xs tensor using the indices + // tensor directly, but candle does not support tensor indexing at the moment, so + // some workarounds must be done. + let xs = match indices.dims() { + [] => { + let index = indices.to_vec0::<i64>()? as usize; + xs.narrow(axis, index, 1)?.squeeze(axis)? + } + [_] => xs.index_select(indices, axis)?, + [first, _] => { + let mut v = Vec::with_capacity(*first); + for i in 0..*first { + v.push(xs.index_select(&indices.get(i)?, axis)?) + } + Tensor::stack(&v, axis)? + } + _ => { + // TODO: Provide an op to handle the ONNX generalized gather op ideally in a + // differentiable way. + todo!("implement gather for {xs:?} {indices:?} axis {axis}") + } }; values.insert(node.output[0].clone(), xs); } diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs index a686f198..18cd53c9 100644 --- a/candle-onnx/tests/ops.rs +++ b/candle-onnx/tests/ops.rs @@ -4,7 +4,7 @@ extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; -use candle::{Device, Result, Tensor}; +use candle::{Device, NdArray, Result, Tensor}; use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto}; use std::collections::HashMap; @@ -829,7 +829,134 @@ fn test_flatten_operation() -> Result<()> { // #[test] // "Gather" -// #[test] +#[test] +fn test_gather_operation() -> Result<()> { + // test taken from https://onnx.ai/onnx/operators/onnx__Gather.html#summary. + test( + &[[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]], + &[[0i64, 1], [1, 2]], + 0, + &[[[1.0, 1.2], [2.3, 3.4]], [[2.3, 3.4], [4.5, 5.7]]], + )?; + + // test taken from https://onnx.ai/onnx/operators/onnx__Gather.html#summary. + test( + &[[1.0, 1.2, 1.9], [2.3, 3.4, 3.9], [4.5, 5.7, 5.9]], + &[[0i64, 2]], + 1, + &[[[1.0, 1.9]], [[2.3, 3.9]], [[4.5, 5.9]]], + )?; + + // all the tests below are generated from numpy.take, which works like + // onnx's Gather operation. + test(&[1.0, 2.0, 3.0, 4.0], 3i64, 0, 4.0)?; + + test(&[[1.0, 2.0, 3.0, 4.0]], 3i64, 1, &[4.0])?; + + test( + &[[1.0], [2.0], [3.0], [4.0]], + &[3i64, 2], + 0, + &[[4.0], [3.0]], + )?; + + test( + &[ + [[1.0, 2.0], [3.0, 4.0]], + [[5.0, 6.0], [7.0, 8.0]], + [[9.0, 10.0], [11.0, 12.0]], + [[13.0, 14.0], [15.0, 16.0]], + ], + 1i64, + 0, + &[[5.0, 6.0], [7.0, 8.0]], + )?; + + test( + &[ + [[1.0, 2.0], [3.0, 4.0]], + [[5.0, 6.0], [7.0, 8.0]], + [[9.0, 10.0], [11.0, 12.0]], + [[13.0, 14.0], [15.0, 16.0]], + ], + &[1i64, 0], + 0, + &[[[5.0, 6.0], [7.0, 8.0]], [[1.0, 2.0], [3.0, 4.0]]], + )?; + + fn test( + data: impl NdArray, + indices: impl NdArray, + axis: i64, + expected: impl NdArray, + ) -> Result<()> { + let att_axis = AttributeProto { + name: "axis".to_string(), + ref_attr_name: "axis".to_string(), + i: axis, + doc_string: "axis".to_string(), + r#type: 2, + f: 0.0, + s: vec![], + t: None, + g: None, + sparse_tensor: None, + tp: None, + floats: vec![], + ints: vec![], + strings: vec![], + tensors: vec![], + graphs: vec![], + sparse_tensors: vec![], + type_protos: vec![], + }; + + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Gather".to_string(), + domain: "".to_string(), + attribute: vec![att_axis], + input: vec![INPUT_X.to_string(), INPUT_Y.to_string()], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + + let mut inputs: HashMap<String, Tensor> = HashMap::new(); + inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?); + inputs.insert(INPUT_Y.to_string(), Tensor::new(indices, &Device::Cpu)?); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); + + let expected = Tensor::new(expected, &Device::Cpu)?; + match expected.dims().len() { + 0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?), + 1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?), + 2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?), + 3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?), + _ => unreachable!(), + }; + + Ok(()) + } + Ok(()) +} // "Shape" #[test] |