summaryrefslogtreecommitdiff
path: root/candle-onnx
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-11-17 22:03:40 +0000
committerGitHub <noreply@github.com>2023-11-17 22:03:40 +0000
commitd31f11035fb69d06bf22194d918a5efc45c8ab37 (patch)
tree85668896188de4035d9cbb1d898aff92afdc35aa /candle-onnx
parent9ab3f9729fc1444687578a9dc913760b0d8d9963 (diff)
downloadcandle-d31f11035fb69d06bf22194d918a5efc45c8ab37.tar.gz
candle-d31f11035fb69d06bf22194d918a5efc45c8ab37.tar.bz2
candle-d31f11035fb69d06bf22194d918a5efc45c8ab37.zip
Support for CumSum in ONNX models. (#1340)
Diffstat (limited to 'candle-onnx')
-rw-r--r--candle-onnx/src/eval.rs19
1 files changed, 19 insertions, 0 deletions
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index 123e4c19..684776c2 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -741,6 +741,25 @@ pub fn simple_eval(
let output = input.to_dtype(dtype)?;
values.insert(node.output[0].clone(), output);
}
+ // https://github.com/onnx/onnx/blob/main/docs/Operators.md#CumSum
+ "CumSum" => {
+ let exclusive = get_attr_opt::<i64>(node, "exclusive")?
+ .copied()
+ .unwrap_or(0);
+ let reverse = get_attr_opt::<i64>(node, "reverse")?.copied().unwrap_or(0);
+ if exclusive != 0 {
+ bail!("only exclusive == 0 is supported in CumSum")
+ }
+ if reverse != 0 {
+ bail!("only reverse == 0 is supported in CumSum")
+ }
+ let input = get(&node.input[0])?;
+ let axis = get(&node.input[1])?
+ .to_dtype(DType::U32)?
+ .to_vec0::<u32>()?;
+ let output = input.cumsum(axis as usize)?;
+ values.insert(node.output[0].clone(), output);
+ }
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
}
}