summaryrefslogtreecommitdiff
path: root/candle-onnx/src/eval.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-onnx/src/eval.rs')
-rw-r--r--candle-onnx/src/eval.rs30
1 files changed, 23 insertions, 7 deletions
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index 15cadf1d..f7cae31c 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -508,17 +508,33 @@ pub fn simple_eval(
values.insert(node.output[0].clone(), xs);
}
"Gather" => {
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gather
let xs = get(&node.input[0])?;
let indices = get(&node.input[1])?;
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
let axis = xs.normalize_axis(axis)?;
- // TODO: Provide an op to handle the ONNX generalized gather op ideally in a
- // differentiable way.
- let xs = if indices.rank() == 0 {
- let index = indices.to_vec0::<i64>()? as usize;
- xs.narrow(axis, index, 1)?.squeeze(axis)?
- } else {
- todo!("implement gather for {xs:?} {indices:?} axis {axis}")
+
+ // In Pytorch or Numpy this can be done by indexing the xs tensor using the indices
+ // tensor directly, but candle does not support tensor indexing at the moment, so
+ // some workarounds must be done.
+ let xs = match indices.dims() {
+ [] => {
+ let index = indices.to_vec0::<i64>()? as usize;
+ xs.narrow(axis, index, 1)?.squeeze(axis)?
+ }
+ [_] => xs.index_select(indices, axis)?,
+ [first, _] => {
+ let mut v = Vec::with_capacity(*first);
+ for i in 0..*first {
+ v.push(xs.index_select(&indices.get(i)?, axis)?)
+ }
+ Tensor::stack(&v, axis)?
+ }
+ _ => {
+ // TODO: Provide an op to handle the ONNX generalized gather op ideally in a
+ // differentiable way.
+ todo!("implement gather for {xs:?} {indices:?} axis {axis}")
+ }
};
values.insert(node.output[0].clone(), xs);
}