diff options
author | YangNianYi <y790174683@163.com> | 2023-11-09 06:28:21 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-08 23:28:21 +0100 |
commit | 73d02f4f57c788c43f3e11991635bc15701c25c0 (patch) | |
tree | e8d6283de0de6bf13b397f4143326ea4a310ab6d /candle-onnx/src | |
parent | f772213e844fdfcc8dbaf662fc11819f4028dc78 (diff) | |
download | candle-73d02f4f57c788c43f3e11991635bc15701c25c0.tar.gz candle-73d02f4f57c788c43f3e11991635bc15701c25c0.tar.bz2 candle-73d02f4f57c788c43f3e11991635bc15701c25c0.zip |
fix: negative axis (#1296)
* fix: negative axis
* Use normalize_axis.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-onnx/src')
-rw-r--r-- | candle-onnx/src/eval.rs | 33 |
1 file changed, 4 insertions(+), 29 deletions(-)
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index b7e325e1..123e4c19 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -298,14 +298,7 @@ pub fn simple_eval( let output = match get_attr_opt::<i64>(node, "axis")? { None => candle_nn::ops::softmax_last_dim(input)?, Some(&axis) => { - let num_axis = input.rank() as i64; - let axis = if axis >= 0 { - axis as usize - } else if axis < -num_axis { - bail!("wrong axis in concat {axis} for shape {:?}", input.shape()) - } else { - (num_axis - axis) as usize - }; + let axis = input.normalize_axis(axis)?; candle_nn::ops::log_softmax(input, axis)? } }; @@ -316,14 +309,7 @@ pub fn simple_eval( let output = match get_attr_opt::<i64>(node, "axis")? { None => candle_nn::ops::softmax_last_dim(input)?, Some(&axis) => { - let num_axis = input.rank() as i64; - let axis = if axis >= 0 { - axis as usize - } else if axis < -num_axis { - bail!("wrong axis in concat {axis} for shape {:?}", input.shape()) - } else { - (num_axis - axis) as usize - }; + let axis = input.normalize_axis(axis)?; candle_nn::ops::softmax(input, axis)? } }; @@ -666,21 +652,10 @@ pub fn simple_eval( .map(|n| Ok(get(n.as_str())?.clone())) .collect::<Result<Vec<Value>>>()?; let axis: i64 = *get_attr(node, "axis")?; - let num_axis = if inputs.is_empty() { + if inputs.is_empty() { bail!("empty concat") - } else { - inputs[0].rank() as i64 - }; - let axis = if axis >= 0 { - axis as usize - } else if axis < -num_axis { - bail!( - "wrong axis in concat {axis} for shape {:?}", - inputs[0].shape() - ) - } else { - (num_axis - axis) as usize }; + let axis = inputs[0].normalize_axis(axis)?; let output = Tensor::cat(&inputs, axis)?; values.insert(node.output[0].clone(), output); } |