From 73d02f4f57c788c43f3e11991635bc15701c25c0 Mon Sep 17 00:00:00 2001 From: YangNianYi Date: Thu, 9 Nov 2023 06:28:21 +0800 Subject: fix: negative axis (#1296) * fix: negative axis * Use normalize_axis. --------- Co-authored-by: Laurent --- candle-onnx/src/eval.rs | 33 ++++----------------------------- 1 file changed, 4 insertions(+), 29 deletions(-) (limited to 'candle-onnx/src') diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index b7e325e1..123e4c19 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -298,14 +298,7 @@ pub fn simple_eval( let output = match get_attr_opt::(node, "axis")? { None => candle_nn::ops::softmax_last_dim(input)?, Some(&axis) => { - let num_axis = input.rank() as i64; - let axis = if axis >= 0 { - axis as usize - } else if axis < -num_axis { - bail!("wrong axis in concat {axis} for shape {:?}", input.shape()) - } else { - (num_axis - axis) as usize - }; + let axis = input.normalize_axis(axis)?; candle_nn::ops::log_softmax(input, axis)? } }; @@ -316,14 +309,7 @@ pub fn simple_eval( let output = match get_attr_opt::(node, "axis")? { None => candle_nn::ops::softmax_last_dim(input)?, Some(&axis) => { - let num_axis = input.rank() as i64; - let axis = if axis >= 0 { - axis as usize - } else if axis < -num_axis { - bail!("wrong axis in concat {axis} for shape {:?}", input.shape()) - } else { - (num_axis - axis) as usize - }; + let axis = input.normalize_axis(axis)?; candle_nn::ops::softmax(input, axis)? } }; @@ -666,21 +652,10 @@ pub fn simple_eval( .map(|n| Ok(get(n.as_str())?.clone())) .collect::>>()?; let axis: i64 = *get_attr(node, "axis")?; - let num_axis = if inputs.is_empty() { + if inputs.is_empty() { bail!("empty concat") - } else { - inputs[0].rank() as i64 - }; - let axis = if axis >= 0 { - axis as usize - } else if axis < -num_axis { - bail!( - "wrong axis in concat {axis} for shape {:?}", - inputs[0].shape() - ) - } else { - (num_axis - axis) as usize }; + let axis = inputs[0].normalize_axis(axis)?; let output = Tensor::cat(&inputs, axis)?; values.insert(node.output[0].clone(), output); } -- cgit v1.2.3