summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-onnx/src/eval.rs33
-rw-r--r--candle-onnx/tests/ops.rs104
2 files changed, 64 insertions, 73 deletions
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index 31eb62cd..e72002e6 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -1032,11 +1032,18 @@ pub fn simple_eval(
let axis_i64: i64 = get_attr_opt(node, "axis")?.copied().unwrap_or(0);
let rank_i64: i64 = input.rank().try_into().unwrap();
if axis_i64 < -rank_i64 || axis_i64 >= rank_i64 {
- bail!("axis ({}) out of accepted range [-rank, rank-1] which was [{}, {}]", axis_i64, -rank_i64, rank_i64-1)
+ bail!(
+ "axis ({}) out of accepted range [-rank, rank-1] which was [{}, {}]",
+ axis_i64,
+ -rank_i64,
+ rank_i64 - 1
+ )
}
let axis = input.normalize_axis(axis_i64)?;
let keepdims: i64 = get_attr_opt(node, "keepdims")?.copied().unwrap_or(1);
- let select_last_index: i64 = get_attr_opt(node, "select_last_index")?.copied().unwrap_or(0);
+ let select_last_index: i64 = get_attr_opt(node, "select_last_index")?
+ .copied()
+ .unwrap_or(0);
if select_last_index == 1 {
bail!("select_last_index for ArgMin is currently not supported")
}
@@ -1044,7 +1051,8 @@ pub fn simple_eval(
input.argmin_keepdim(axis)?
} else {
input.argmin(axis)?
- }.to_dtype(DType::I64)?;
+ }
+ .to_dtype(DType::I64)?;
values.insert(node.output[0].clone(), output);
}
"ArgMax" => {
@@ -1052,11 +1060,18 @@ pub fn simple_eval(
let axis_i64: i64 = get_attr_opt(node, "axis")?.copied().unwrap_or(0);
let rank_i64: i64 = input.rank().try_into().unwrap();
if axis_i64 < -rank_i64 || axis_i64 >= rank_i64 {
- bail!("axis ({}) out of accepted range [-rank, rank-1] which was [{}, {}]", axis_i64, -rank_i64, rank_i64-1)
+ bail!(
+ "axis ({}) out of accepted range [-rank, rank-1] which was [{}, {}]",
+ axis_i64,
+ -rank_i64,
+ rank_i64 - 1
+ )
}
let axis = input.normalize_axis(axis_i64)?;
let keepdims: i64 = get_attr_opt(node, "keepdims")?.copied().unwrap_or(1);
- let select_last_index: i64 = get_attr_opt(node, "select_last_index")?.copied().unwrap_or(0);
+ let select_last_index: i64 = get_attr_opt(node, "select_last_index")?
+ .copied()
+ .unwrap_or(0);
if select_last_index == 1 {
bail!("select_last_index for ArgMin is currently not supported")
}
@@ -1064,7 +1079,8 @@ pub fn simple_eval(
input.argmax_keepdim(axis)?
} else {
input.argmax(axis)?
- }.to_dtype(DType::I64)?;
+ }
+ .to_dtype(DType::I64)?;
values.insert(node.output[0].clone(), output);
}
"LeakyRelu" => {
@@ -1072,7 +1088,10 @@ pub fn simple_eval(
let dt = input.dtype();
match dt {
DType::U8 | DType::U32 | DType::I64 => {
- bail!("unsupported dtype {}, only float types are allowed for LeakyRelu", dt.as_str())
+ bail!(
+ "unsupported dtype {}, only float types are allowed for LeakyRelu",
+ dt.as_str()
+ )
}
DType::BF16 | DType::F16 | DType::F32 | DType::F64 => {}
}
diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs
index ffafd7a7..2e60d22c 100644
--- a/candle-onnx/tests/ops.rs
+++ b/candle-onnx/tests/ops.rs
@@ -2715,51 +2715,31 @@ fn test_argmin() -> Result<()> {
// tests from https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-7
// default_axes_keepdims
test(
- &[
- [2u32, 1u32],
- [3u32, 10u32]
- ],
+ &[[2u32, 1u32], [3u32, 10u32]],
None,
Some(1),
None,
- &[
- [0i64, 0i64],
- ],
+ &[[0i64, 0i64]],
)?;
// keepdims
test(
- &[
- [2u32, 1u32],
- [3u32, 10u32]
- ],
+ &[[2u32, 1u32], [3u32, 10u32]],
Some(1),
Some(1),
None,
- &[
- [1i64],
- [0i64]
- ],
+ &[[1i64], [0i64]],
)?;
// // negative_axis_keepdims
test(
- &[
- [2u32, 1u32],
- [3u32, 10u32]
- ],
+ &[[2u32, 1u32], [3u32, 10u32]],
Some(-1),
Some(1),
None,
- &[
- [1i64],
- [0i64]
- ],
+ &[[1i64], [0i64]],
)?;
// no_keepdims
test(
- &[
- [2u32, 1u32],
- [3u32, 10u32]
- ],
+ &[[2u32, 1u32], [3u32, 10u32]],
None,
Some(0),
None,
@@ -2771,7 +2751,7 @@ fn test_argmin() -> Result<()> {
[0.1139, 0.2254, -0.1381, 0.3687],
[1.0100, -1.1975, -0.0102, -0.4732],
[-0.9240, 0.1207, -0.7506, -1.0213],
- [1.7809, -1.2960, 0.9384, 0.1438]
+ [1.7809, -1.2960, 0.9384, 0.1438],
],
Some(1),
Some(0),
@@ -2783,14 +2763,20 @@ fn test_argmin() -> Result<()> {
[0.1139, 0.2254, -0.1381, 0.3687],
[1.0100, -1.1975, -0.0102, -0.4732],
[-0.9240, 0.1207, -0.7506, -1.0213],
- [1.7809, -1.2960, 0.9384, 0.1438]
+ [1.7809, -1.2960, 0.9384, 0.1438],
],
Some(1),
None,
None,
&[[2i64], [1i64], [3i64], [1i64]],
)?;
- fn test(data: impl NdArray, axis: Option<i64>, keepdims: Option<i64>, select_last_index: Option<i64>, expected: impl NdArray) -> Result<()> {
+ fn test(
+ data: impl NdArray,
+ axis: Option<i64>,
+ keepdims: Option<i64>,
+ select_last_index: Option<i64>,
+ expected: impl NdArray,
+ ) -> Result<()> {
let att_axis = AttributeProto {
name: "axis".to_string(),
ref_attr_name: "axis".to_string(),
@@ -2911,51 +2897,31 @@ fn test_argmax() -> Result<()> {
// tests from https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-6
// default_axes_keepdims
test(
- &[
- [2u32, 1u32],
- [3u32, 10u32]
- ],
+ &[[2u32, 1u32], [3u32, 10u32]],
None,
Some(1),
None,
- &[
- [1i64, 1i64],
- ],
+ &[[1i64, 1i64]],
)?;
// keepdims
test(
- &[
- [2u32, 1u32],
- [3u32, 10u32]
- ],
+ &[[2u32, 1u32], [3u32, 10u32]],
Some(1),
Some(1),
None,
- &[
- [0i64],
- [1i64]
- ],
+ &[[0i64], [1i64]],
)?;
// // negative_axis_keepdims
test(
- &[
- [2u32, 1u32],
- [3u32, 10u32]
- ],
+ &[[2u32, 1u32], [3u32, 10u32]],
Some(-1),
Some(1),
None,
- &[
- [0i64],
- [1i64]
- ],
+ &[[0i64], [1i64]],
)?;
// no_keepdims
test(
- &[
- [2u32, 1u32],
- [3u32, 10u32]
- ],
+ &[[2u32, 1u32], [3u32, 10u32]],
None,
Some(0),
None,
@@ -2967,7 +2933,7 @@ fn test_argmax() -> Result<()> {
[1.3398, 0.2663, -0.2686, 0.2450],
[-0.7401, -0.8805, -0.3402, -1.1936],
[0.4907, -1.3948, -1.0691, -0.3132],
- [-1.6092, 0.5419, -0.2993, 0.3195]
+ [-1.6092, 0.5419, -0.2993, 0.3195],
],
Some(1),
Some(0),
@@ -2979,14 +2945,20 @@ fn test_argmax() -> Result<()> {
[1.3398, 0.2663, -0.2686, 0.2450],
[-0.7401, -0.8805, -0.3402, -1.1936],
[0.4907, -1.3948, -1.0691, -0.3132],
- [-1.6092, 0.5419, -0.2993, 0.3195]
+ [-1.6092, 0.5419, -0.2993, 0.3195],
],
Some(1),
None,
None,
&[[0i64], [2i64], [0i64], [1i64]],
)?;
- fn test(data: impl NdArray, axis: Option<i64>, keepdims: Option<i64>, select_last_index: Option<i64>, expected: impl NdArray) -> Result<()> {
+ fn test(
+ data: impl NdArray,
+ axis: Option<i64>,
+ keepdims: Option<i64>,
+ select_last_index: Option<i64>,
+ expected: impl NdArray,
+ ) -> Result<()> {
let att_axis = AttributeProto {
name: "axis".to_string(),
ref_attr_name: "axis".to_string(),
@@ -3106,11 +3078,7 @@ fn test_argmax() -> Result<()> {
fn test_leakyrelu() -> Result<()> {
// tests from https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-80
// leakyrelu
- test(
- &[-1.0, 0.0, 1.0],
- Some(0.1),
- &[-0.1, 0.0, 1.0]
- )?;
+ test(&[-1.0, 0.0, 1.0], Some(0.1), &[-0.1, 0.0, 1.0])?;
fn test(data: impl NdArray, alpha: Option<f32>, expected: impl NdArray) -> Result<()> {
let att_alpha = AttributeProto {
name: "alpha".to_string(),
@@ -3168,7 +3136,11 @@ fn test_leakyrelu() -> Result<()> {
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let expected = Tensor::new(expected, &Device::Cpu)?;
- for both in z.to_vec1::<f64>()?.iter().zip(expected.to_vec1::<f64>()?.iter()) {
+ for both in z
+ .to_vec1::<f64>()?
+ .iter()
+ .zip(expected.to_vec1::<f64>()?.iter())
+ {
let (act, exp) = both;
assert!(f64::abs(act - exp) < f32::EPSILON.into());
}