summary refs log tree commit diff
path: root/candle-onnx
diff options
context:
space:
mode:
Diffstat (limited to 'candle-onnx')
-rw-r--r--  candle-onnx/src/eval.rs   30
-rw-r--r--  candle-onnx/tests/ops.rs  131
2 files changed, 152 insertions, 9 deletions
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index 15cadf1d..f7cae31c 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -508,17 +508,33 @@ pub fn simple_eval(
values.insert(node.output[0].clone(), xs);
}
"Gather" => {
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gather
let xs = get(&node.input[0])?;
let indices = get(&node.input[1])?;
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
let axis = xs.normalize_axis(axis)?;
- // TODO: Provide an op to handle the ONNX generalized gather op ideally in a
- // differentiable way.
- let xs = if indices.rank() == 0 {
- let index = indices.to_vec0::<i64>()? as usize;
- xs.narrow(axis, index, 1)?.squeeze(axis)?
- } else {
- todo!("implement gather for {xs:?} {indices:?} axis {axis}")
+
+ // In Pytorch or Numpy this can be done by indexing the xs tensor using the indices
+ // tensor directly, but candle does not support tensor indexing at the moment, so
+ // some workarounds must be done.
+ let xs = match indices.dims() {
+ [] => {
+ let index = indices.to_vec0::<i64>()? as usize;
+ xs.narrow(axis, index, 1)?.squeeze(axis)?
+ }
+ [_] => xs.index_select(indices, axis)?,
+ [first, _] => {
+ let mut v = Vec::with_capacity(*first);
+ for i in 0..*first {
+ v.push(xs.index_select(&indices.get(i)?, axis)?)
+ }
+ Tensor::stack(&v, axis)?
+ }
+ _ => {
+ // TODO: Provide an op to handle the ONNX generalized gather op ideally in a
+ // differentiable way.
+ todo!("implement gather for {xs:?} {indices:?} axis {axis}")
+ }
};
values.insert(node.output[0].clone(), xs);
}
diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs
index a686f198..18cd53c9 100644
--- a/candle-onnx/tests/ops.rs
+++ b/candle-onnx/tests/ops.rs
@@ -4,7 +4,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
-use candle::{Device, Result, Tensor};
+use candle::{Device, NdArray, Result, Tensor};
use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
use std::collections::HashMap;
@@ -829,7 +829,134 @@ fn test_flatten_operation() -> Result<()> {
// #[test]
// "Gather"
-// #[test]
+#[test]
+fn test_gather_operation() -> Result<()> {
+ // test taken from https://onnx.ai/onnx/operators/onnx__Gather.html#summary.
+ test(
+ &[[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]],
+ &[[0i64, 1], [1, 2]],
+ 0,
+ &[[[1.0, 1.2], [2.3, 3.4]], [[2.3, 3.4], [4.5, 5.7]]],
+ )?;
+
+ // test taken from https://onnx.ai/onnx/operators/onnx__Gather.html#summary.
+ test(
+ &[[1.0, 1.2, 1.9], [2.3, 3.4, 3.9], [4.5, 5.7, 5.9]],
+ &[[0i64, 2]],
+ 1,
+ &[[[1.0, 1.9]], [[2.3, 3.9]], [[4.5, 5.9]]],
+ )?;
+
+ // all the tests below are generated from numpy.take, which works like
+ // onnx's Gather operation.
+ test(&[1.0, 2.0, 3.0, 4.0], 3i64, 0, 4.0)?;
+
+ test(&[[1.0, 2.0, 3.0, 4.0]], 3i64, 1, &[4.0])?;
+
+ test(
+ &[[1.0], [2.0], [3.0], [4.0]],
+ &[3i64, 2],
+ 0,
+ &[[4.0], [3.0]],
+ )?;
+
+ test(
+ &[
+ [[1.0, 2.0], [3.0, 4.0]],
+ [[5.0, 6.0], [7.0, 8.0]],
+ [[9.0, 10.0], [11.0, 12.0]],
+ [[13.0, 14.0], [15.0, 16.0]],
+ ],
+ 1i64,
+ 0,
+ &[[5.0, 6.0], [7.0, 8.0]],
+ )?;
+
+ test(
+ &[
+ [[1.0, 2.0], [3.0, 4.0]],
+ [[5.0, 6.0], [7.0, 8.0]],
+ [[9.0, 10.0], [11.0, 12.0]],
+ [[13.0, 14.0], [15.0, 16.0]],
+ ],
+ &[1i64, 0],
+ 0,
+ &[[[5.0, 6.0], [7.0, 8.0]], [[1.0, 2.0], [3.0, 4.0]]],
+ )?;
+
+ fn test(
+ data: impl NdArray,
+ indices: impl NdArray,
+ axis: i64,
+ expected: impl NdArray,
+ ) -> Result<()> {
+ let att_axis = AttributeProto {
+ name: "axis".to_string(),
+ ref_attr_name: "axis".to_string(),
+ i: axis,
+ doc_string: "axis".to_string(),
+ r#type: 2,
+ f: 0.0,
+ s: vec![],
+ t: None,
+ g: None,
+ sparse_tensor: None,
+ tp: None,
+ floats: vec![],
+ ints: vec![],
+ strings: vec![],
+ tensors: vec![],
+ graphs: vec![],
+ sparse_tensors: vec![],
+ type_protos: vec![],
+ };
+
+ let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+ node: vec![NodeProto {
+ op_type: "Gather".to_string(),
+ domain: "".to_string(),
+ attribute: vec![att_axis],
+ input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
+ output: vec![OUTPUT_Z.to_string()],
+ name: "".to_string(),
+ doc_string: "".to_string(),
+ }],
+ name: "".to_string(),
+ initializer: vec![],
+ input: vec![],
+ output: vec![ValueInfoProto {
+ name: OUTPUT_Z.to_string(),
+ doc_string: "".to_string(),
+ r#type: None,
+ }],
+ value_info: vec![],
+ doc_string: "".to_string(),
+ sparse_initializer: vec![],
+ quantization_annotation: vec![],
+ }));
+
+ let mut inputs: HashMap<String, Tensor> = HashMap::new();
+ inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
+ inputs.insert(INPUT_Y.to_string(), Tensor::new(indices, &Device::Cpu)?);
+
+ let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+ assert_eq!(eval.len(), 1);
+
+ let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+
+ let expected = Tensor::new(expected, &Device::Cpu)?;
+ match expected.dims().len() {
+ 0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
+ 1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
+ 2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
+ 3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
+ _ => unreachable!(),
+ };
+
+ Ok(())
+ }
+ Ok(())
+}
// "Shape"
#[test]