From dc68c130e443bf5b7f78a90fb9b0bad4ce79ead3 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 4 Nov 2023 15:10:14 +0100 Subject: Support more ONNX ops. (#1267) * Add LogSoftmax. * Support for Transpose. --- candle-onnx/src/eval.rs | 49 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) (limited to 'candle-onnx/src') diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 2a80f8c1..73376fbe 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -59,6 +59,26 @@ pub fn simple_eval( Ok(dt.i) } }; + let get_attr_is = |name: &str| match node.attribute.iter().find(|attr| attr.name == name) { + None => { + bail!( + "cannot find the '{name}' attribute in '{}' for {}", + node.op_type, + node.name + ) + } + Some(dt) => { + match dt.r#type() { + AttributeType::Ints => (), + rtype => bail!( + "unsupported type {rtype:?} for '{name}' attribute in '{}' for {}", + node.op_type, + node.name + ), + } + Ok(dt.ints.as_slice()) + } + }; // TODO: Validate node.input for each operator. match node.op_type.as_str() { "Add" => { @@ -114,6 +134,24 @@ pub fn simple_eval( let output = input0.reshape(input1)?; values.insert(node.output[0].clone(), output); } + "LogSoftmax" => { + let input = get(&node.input[0])?; + let output = match get_attr_i("axis") { + Err(_) => candle_nn::ops::softmax_last_dim(input)?, + Ok(axis) => { + let num_axis = input.rank() as i64; + let axis = if axis >= 0 { + axis as usize + } else if axis < -num_axis { + bail!("wrong axis in concat {axis} for shape {:?}", input.shape()) + } else { + (num_axis - axis) as usize + }; + candle_nn::ops::log_softmax(input, axis)? + } + }; + values.insert(node.output[0].clone(), output); + } "Softmax" => { let input = get(&node.input[0])?; let output = match get_attr_i("axis") { @@ -132,6 +170,17 @@ pub fn simple_eval( }; values.insert(node.output[0].clone(), output); } + "Transpose" => { + let input = get(&node.input[0])?; + let output = match get_attr_is("perm") { + Err(_) => input.t()?, + Ok(perm) => { + let perm = perm.iter().map(|&v| v as usize).collect::>(); + input.permute(perm)? + } + }; + values.insert(node.output[0].clone(), output); + } "Concat" => { // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Concat let inputs = node -- cgit v1.2.3