summaryrefslogtreecommitdiff
path: root/candle-onnx/src
diff options
context:
space:
mode:
authorYangNianYi <y790174683@163.com>2023-11-09 06:28:21 +0800
committerGitHub <noreply@github.com>2023-11-08 23:28:21 +0100
commit73d02f4f57c788c43f3e11991635bc15701c25c0 (patch)
treee8d6283de0de6bf13b397f4143326ea4a310ab6d /candle-onnx/src
parentf772213e844fdfcc8dbaf662fc11819f4028dc78 (diff)
downloadcandle-73d02f4f57c788c43f3e11991635bc15701c25c0.tar.gz
candle-73d02f4f57c788c43f3e11991635bc15701c25c0.tar.bz2
candle-73d02f4f57c788c43f3e11991635bc15701c25c0.zip
fix: negative axis (#1296)
* fix: negative axis * Use normalize_axis. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-onnx/src')
-rw-r--r--candle-onnx/src/eval.rs33
1 files changed, 4 insertions, 29 deletions
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index b7e325e1..123e4c19 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -298,14 +298,7 @@ pub fn simple_eval(
let output = match get_attr_opt::<i64>(node, "axis")? {
None => candle_nn::ops::softmax_last_dim(input)?,
Some(&axis) => {
- let num_axis = input.rank() as i64;
- let axis = if axis >= 0 {
- axis as usize
- } else if axis < -num_axis {
- bail!("wrong axis in concat {axis} for shape {:?}", input.shape())
- } else {
- (num_axis - axis) as usize
- };
+ let axis = input.normalize_axis(axis)?;
candle_nn::ops::log_softmax(input, axis)?
}
};
@@ -316,14 +309,7 @@ pub fn simple_eval(
let output = match get_attr_opt::<i64>(node, "axis")? {
None => candle_nn::ops::softmax_last_dim(input)?,
Some(&axis) => {
- let num_axis = input.rank() as i64;
- let axis = if axis >= 0 {
- axis as usize
- } else if axis < -num_axis {
- bail!("wrong axis in concat {axis} for shape {:?}", input.shape())
- } else {
- (num_axis - axis) as usize
- };
+ let axis = input.normalize_axis(axis)?;
candle_nn::ops::softmax(input, axis)?
}
};
@@ -666,21 +652,10 @@ pub fn simple_eval(
.map(|n| Ok(get(n.as_str())?.clone()))
.collect::<Result<Vec<Value>>>()?;
let axis: i64 = *get_attr(node, "axis")?;
- let num_axis = if inputs.is_empty() {
+ if inputs.is_empty() {
bail!("empty concat")
- } else {
- inputs[0].rank() as i64
- };
- let axis = if axis >= 0 {
- axis as usize
- } else if axis < -num_axis {
- bail!(
- "wrong axis in concat {axis} for shape {:?}",
- inputs[0].shape()
- )
- } else {
- (num_axis - axis) as usize
};
+ let axis = inputs[0].normalize_axis(axis)?;
let output = Tensor::cat(&inputs, axis)?;
values.insert(node.output[0].clone(), output);
}