diff options
Diffstat (limited to 'candle-onnx/src/eval.rs')
-rw-r--r-- | candle-onnx/src/eval.rs | 30 |
1 files changed, 23 insertions, 7 deletions
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 15cadf1d..f7cae31c 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -508,17 +508,33 @@ pub fn simple_eval( values.insert(node.output[0].clone(), xs); } "Gather" => { + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gather let xs = get(&node.input[0])?; let indices = get(&node.input[1])?; let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0); let axis = xs.normalize_axis(axis)?; - // TODO: Provide an op to handle the ONNX generalized gather op ideally in a - // differentiable way. - let xs = if indices.rank() == 0 { - let index = indices.to_vec0::<i64>()? as usize; - xs.narrow(axis, index, 1)?.squeeze(axis)? - } else { - todo!("implement gather for {xs:?} {indices:?} axis {axis}") + + // In Pytorch or Numpy this can be done by indexing the xs tensor using the indices + // tensor directly, but candle does not support tensor indexing at the moment, so + // some workarounds must be done. + let xs = match indices.dims() { + [] => { + let index = indices.to_vec0::<i64>()? as usize; + xs.narrow(axis, index, 1)?.squeeze(axis)? + } + [_] => xs.index_select(indices, axis)?, + [first, _] => { + let mut v = Vec::with_capacity(*first); + for i in 0..*first { + v.push(xs.index_select(&indices.get(i)?, axis)?) + } + Tensor::stack(&v, axis)? + } + _ => { + // TODO: Provide an op to handle the ONNX generalized gather op ideally in a + // differentiable way. + todo!("implement gather for {xs:?} {indices:?} axis {axis}") + } }; values.insert(node.output[0].clone(), xs); } |