diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-11-04 15:10:14 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-04 15:10:14 +0100 |
commit | dc68c130e443bf5b7f78a90fb9b0bad4ce79ead3 (patch) | |
tree | 6c67e2fd21f3b78072b3a76da26d240a3d497a82 /candle-onnx/src | |
parent | bc9a1bf2399243c659b1e902b14e8572a12ec15b (diff) | |
download | candle-dc68c130e443bf5b7f78a90fb9b0bad4ce79ead3.tar.gz candle-dc68c130e443bf5b7f78a90fb9b0bad4ce79ead3.tar.bz2 candle-dc68c130e443bf5b7f78a90fb9b0bad4ce79ead3.zip |
Support more ONNX ops. (#1267)
* Add LogSoftmax.
* Support for Transpose.
Diffstat (limited to 'candle-onnx/src')
-rw-r--r-- | candle-onnx/src/eval.rs | 49 |
1 files changed, 49 insertions, 0 deletions
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 2a80f8c1..73376fbe 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -59,6 +59,26 @@ pub fn simple_eval( Ok(dt.i) } }; + let get_attr_is = |name: &str| match node.attribute.iter().find(|attr| attr.name == name) { + None => { + bail!( + "cannot find the '{name}' attribute in '{}' for {}", + node.op_type, + node.name + ) + } + Some(dt) => { + match dt.r#type() { + AttributeType::Ints => (), + rtype => bail!( + "unsupported type {rtype:?} for '{name}' attribute in '{}' for {}", + node.op_type, + node.name + ), + } + Ok(dt.ints.as_slice()) + } + }; // TODO: Validate node.input for each operator. match node.op_type.as_str() { "Add" => { @@ -114,6 +134,24 @@ pub fn simple_eval( let output = input0.reshape(input1)?; values.insert(node.output[0].clone(), output); } + "LogSoftmax" => { + let input = get(&node.input[0])?; + let output = match get_attr_i("axis") { + Err(_) => candle_nn::ops::softmax_last_dim(input)?, + Ok(axis) => { + let num_axis = input.rank() as i64; + let axis = if axis >= 0 { + axis as usize + } else if axis < -num_axis { + bail!("wrong axis in concat {axis} for shape {:?}", input.shape()) + } else { + (num_axis - axis) as usize + }; + candle_nn::ops::log_softmax(input, axis)? + } + }; + values.insert(node.output[0].clone(), output); + } "Softmax" => { let input = get(&node.input[0])?; let output = match get_attr_i("axis") { @@ -132,6 +170,17 @@ pub fn simple_eval( }; values.insert(node.output[0].clone(), output); } + "Transpose" => { + let input = get(&node.input[0])?; + let output = match get_attr_is("perm") { + Err(_) => input.t()?, + Ok(perm) => { + let perm = perm.iter().map(|&v| v as usize).collect::<Vec<_>>(); + input.permute(perm)? + } + }; + values.insert(node.output[0].clone(), output); + } "Concat" => { // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Concat let inputs = node |