summaryrefslogtreecommitdiff
path: root/candle-onnx/src
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-11-04 15:10:14 +0100
committerGitHub <noreply@github.com>2023-11-04 15:10:14 +0100
commitdc68c130e443bf5b7f78a90fb9b0bad4ce79ead3 (patch)
tree6c67e2fd21f3b78072b3a76da26d240a3d497a82 /candle-onnx/src
parentbc9a1bf2399243c659b1e902b14e8572a12ec15b (diff)
downloadcandle-dc68c130e443bf5b7f78a90fb9b0bad4ce79ead3.tar.gz
candle-dc68c130e443bf5b7f78a90fb9b0bad4ce79ead3.tar.bz2
candle-dc68c130e443bf5b7f78a90fb9b0bad4ce79ead3.zip
Support more ONNX ops. (#1267)
* Add LogSoftmax. * Support for Transpose.
Diffstat (limited to 'candle-onnx/src')
-rw-r--r--candle-onnx/src/eval.rs49
1 files changed, 49 insertions, 0 deletions
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index 2a80f8c1..73376fbe 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -59,6 +59,26 @@ pub fn simple_eval(
Ok(dt.i)
}
};
+ let get_attr_is = |name: &str| match node.attribute.iter().find(|attr| attr.name == name) {
+ None => {
+ bail!(
+ "cannot find the '{name}' attribute in '{}' for {}",
+ node.op_type,
+ node.name
+ )
+ }
+ Some(dt) => {
+ match dt.r#type() {
+ AttributeType::Ints => (),
+ rtype => bail!(
+ "unsupported type {rtype:?} for '{name}' attribute in '{}' for {}",
+ node.op_type,
+ node.name
+ ),
+ }
+ Ok(dt.ints.as_slice())
+ }
+ };
// TODO: Validate node.input for each operator.
match node.op_type.as_str() {
"Add" => {
@@ -114,6 +134,24 @@ pub fn simple_eval(
let output = input0.reshape(input1)?;
values.insert(node.output[0].clone(), output);
}
+ "LogSoftmax" => {
+ let input = get(&node.input[0])?;
+ let output = match get_attr_i("axis") {
+ Err(_) => candle_nn::ops::softmax_last_dim(input)?,
+ Ok(axis) => {
+ let num_axis = input.rank() as i64;
+ let axis = if axis >= 0 {
+ axis as usize
+ } else if axis < -num_axis {
+ bail!("wrong axis in concat {axis} for shape {:?}", input.shape())
+ } else {
+ (num_axis - axis) as usize
+ };
+ candle_nn::ops::log_softmax(input, axis)?
+ }
+ };
+ values.insert(node.output[0].clone(), output);
+ }
"Softmax" => {
let input = get(&node.input[0])?;
let output = match get_attr_i("axis") {
@@ -132,6 +170,17 @@ pub fn simple_eval(
};
values.insert(node.output[0].clone(), output);
}
+ "Transpose" => {
+ let input = get(&node.input[0])?;
+ let output = match get_attr_is("perm") {
+ Err(_) => input.t()?,
+ Ok(perm) => {
+ let perm = perm.iter().map(|&v| v as usize).collect::<Vec<_>>();
+ input.permute(perm)?
+ }
+ };
+ values.insert(node.output[0].clone(), output);
+ }
"Concat" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Concat
let inputs = node