diff options
-rw-r--r-- | candle-onnx/src/eval.rs | 33 | ||||
-rw-r--r-- | candle-onnx/tests/ops.rs | 104 |
2 files changed, 64 insertions, 73 deletions
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 31eb62cd..e72002e6 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -1032,11 +1032,18 @@ pub fn simple_eval( let axis_i64: i64 = get_attr_opt(node, "axis")?.copied().unwrap_or(0); let rank_i64: i64 = input.rank().try_into().unwrap(); if axis_i64 < -rank_i64 || axis_i64 >= rank_i64 { - bail!("axis ({}) out of accepted range [-rank, rank-1] which was [{}, {}]", axis_i64, -rank_i64, rank_i64-1) + bail!( + "axis ({}) out of accepted range [-rank, rank-1] which was [{}, {}]", + axis_i64, + -rank_i64, + rank_i64 - 1 + ) } let axis = input.normalize_axis(axis_i64)?; let keepdims: i64 = get_attr_opt(node, "keepdims")?.copied().unwrap_or(1); - let select_last_index: i64 = get_attr_opt(node, "select_last_index")?.copied().unwrap_or(0); + let select_last_index: i64 = get_attr_opt(node, "select_last_index")? + .copied() + .unwrap_or(0); if select_last_index == 1 { bail!("select_last_index for ArgMin is currently not supported") } @@ -1044,7 +1051,8 @@ pub fn simple_eval( input.argmin_keepdim(axis)? } else { input.argmin(axis)? - }.to_dtype(DType::I64)?; + } + .to_dtype(DType::I64)?; values.insert(node.output[0].clone(), output); } "ArgMax" => { @@ -1052,11 +1060,18 @@ pub fn simple_eval( let axis_i64: i64 = get_attr_opt(node, "axis")?.copied().unwrap_or(0); let rank_i64: i64 = input.rank().try_into().unwrap(); if axis_i64 < -rank_i64 || axis_i64 >= rank_i64 { - bail!("axis ({}) out of accepted range [-rank, rank-1] which was [{}, {}]", axis_i64, -rank_i64, rank_i64-1) + bail!( + "axis ({}) out of accepted range [-rank, rank-1] which was [{}, {}]", + axis_i64, + -rank_i64, + rank_i64 - 1 + ) } let axis = input.normalize_axis(axis_i64)?; let keepdims: i64 = get_attr_opt(node, "keepdims")?.copied().unwrap_or(1); - let select_last_index: i64 = get_attr_opt(node, "select_last_index")?.copied().unwrap_or(0); + let select_last_index: i64 = get_attr_opt(node, "select_last_index")? + .copied() + .unwrap_or(0); if select_last_index == 1 { bail!("select_last_index for ArgMin is currently not supported") } @@ -1064,7 +1079,8 @@ pub fn simple_eval( input.argmax_keepdim(axis)? } else { input.argmax(axis)? - }.to_dtype(DType::I64)?; + } + .to_dtype(DType::I64)?; values.insert(node.output[0].clone(), output); } "LeakyRelu" => { @@ -1072,7 +1088,10 @@ pub fn simple_eval( let dt = input.dtype(); match dt { DType::U8 | DType::U32 | DType::I64 => { - bail!("unsupported dtype {}, only float types are allowed for LeakyRelu", dt.as_str()) + bail!( + "unsupported dtype {}, only float types are allowed for LeakyRelu", + dt.as_str() + ) } DType::BF16 | DType::F16 | DType::F32 | DType::F64 => {} } diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs index ffafd7a7..2e60d22c 100644 --- a/candle-onnx/tests/ops.rs +++ b/candle-onnx/tests/ops.rs @@ -2715,51 +2715,31 @@ fn test_argmin() -> Result<()> { // tests from https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-7 // default_axes_keepdims test( - &[ - [2u32, 1u32], - [3u32, 10u32] - ], + &[[2u32, 1u32], [3u32, 10u32]], None, Some(1), None, - &[ - [0i64, 0i64], - ], + &[[0i64, 0i64]], )?; // keepdims test( - &[ - [2u32, 1u32], - [3u32, 10u32] - ], + &[[2u32, 1u32], [3u32, 10u32]], Some(1), Some(1), None, - &[ - [1i64], - [0i64] - ], + &[[1i64], [0i64]], )?; // // negative_axis_keepdims test( - &[ - [2u32, 1u32], - [3u32, 10u32] - ], + &[[2u32, 1u32], [3u32, 10u32]], Some(-1), Some(1), None, - &[ - [1i64], - [0i64] - ], + &[[1i64], [0i64]], )?; // no_keepdims test( - &[ - [2u32, 1u32], - [3u32, 10u32] - ], + &[[2u32, 1u32], [3u32, 10u32]], None, Some(0), None, @@ -2771,7 +2751,7 @@ fn test_argmin() -> Result<()> { [0.1139, 0.2254, -0.1381, 0.3687], [1.0100, -1.1975, -0.0102, -0.4732], [-0.9240, 0.1207, -0.7506, -1.0213], - [1.7809, -1.2960, 0.9384, 0.1438] + [1.7809, -1.2960, 0.9384, 0.1438], ], Some(1), Some(0), @@ -2783,14 +2763,20 @@ fn test_argmin() -> Result<()> { [0.1139, 0.2254, -0.1381, 0.3687], [1.0100, -1.1975, -0.0102, -0.4732], [-0.9240, 0.1207, -0.7506, -1.0213], - [1.7809, -1.2960, 0.9384, 0.1438] + [1.7809, -1.2960, 0.9384, 0.1438], ], Some(1), None, None, &[[2i64], [1i64], [3i64], [1i64]], )?; - fn test(data: impl NdArray, axis: Option<i64>, keepdims: Option<i64>, select_last_index: Option<i64>, expected: impl NdArray) -> Result<()> { + fn test( + data: impl NdArray, + axis: Option<i64>, + keepdims: Option<i64>, + select_last_index: Option<i64>, + expected: impl NdArray, + ) -> Result<()> { let att_axis = AttributeProto { name: "axis".to_string(), ref_attr_name: "axis".to_string(), @@ -2911,51 +2897,31 @@ fn test_argmax() -> Result<()> { // tests from https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-6 // default_axes_keepdims test( - &[ - [2u32, 1u32], - [3u32, 10u32] - ], + &[[2u32, 1u32], [3u32, 10u32]], None, Some(1), None, - &[ - [1i64, 1i64], - ], + &[[1i64, 1i64]], )?; // keepdims test( - &[ - [2u32, 1u32], - [3u32, 10u32] - ], + &[[2u32, 1u32], [3u32, 10u32]], Some(1), Some(1), None, - &[ - [0i64], - [1i64] - ], + &[[0i64], [1i64]], )?; // // negative_axis_keepdims test( - &[ - [2u32, 1u32], - [3u32, 10u32] - ], + &[[2u32, 1u32], [3u32, 10u32]], Some(-1), Some(1), None, - &[ - [0i64], - [1i64] - ], + &[[0i64], [1i64]], )?; // no_keepdims test( - &[ - [2u32, 1u32], - [3u32, 10u32] - ], + &[[2u32, 1u32], [3u32, 10u32]], None, Some(0), None, @@ -2967,7 +2933,7 @@ fn test_argmax() -> Result<()> { [1.3398, 0.2663, -0.2686, 0.2450], [-0.7401, -0.8805, -0.3402, -1.1936], [0.4907, -1.3948, -1.0691, -0.3132], - [-1.6092, 0.5419, -0.2993, 0.3195] + [-1.6092, 0.5419, -0.2993, 0.3195], ], Some(1), Some(0), @@ -2979,14 +2945,20 @@ fn test_argmax() -> Result<()> { [1.3398, 0.2663, -0.2686, 0.2450], [-0.7401, -0.8805, -0.3402, -1.1936], [0.4907, -1.3948, -1.0691, -0.3132], - [-1.6092, 0.5419, -0.2993, 0.3195] + [-1.6092, 0.5419, -0.2993, 0.3195], ], Some(1), None, None, &[[0i64], [2i64], [0i64], [1i64]], )?; - fn test(data: impl NdArray, axis: Option<i64>, keepdims: Option<i64>, select_last_index: Option<i64>, expected: impl NdArray) -> Result<()> { + fn test( + data: impl NdArray, + axis: Option<i64>, + keepdims: Option<i64>, + select_last_index: Option<i64>, + expected: impl NdArray, + ) -> Result<()> { let att_axis = AttributeProto { name: "axis".to_string(), ref_attr_name: "axis".to_string(), @@ -3106,11 +3078,7 @@ fn test_argmax() -> Result<()> { fn test_leakyrelu() -> Result<()> { // tests from https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-80 // leakyrelu - test( - &[-1.0, 0.0, 1.0], - Some(0.1), - &[-0.1, 0.0, 1.0] - )?; + test(&[-1.0, 0.0, 1.0], Some(0.1), &[-0.1, 0.0, 1.0])?; fn test(data: impl NdArray, alpha: Option<f32>, expected: impl NdArray) -> Result<()> { let att_alpha = AttributeProto { name: "alpha".to_string(), @@ -3168,7 +3136,11 @@ fn test_leakyrelu() -> Result<()> { let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let expected = Tensor::new(expected, &Device::Cpu)?; - for both in z.to_vec1::<f64>()?.iter().zip(expected.to_vec1::<f64>()?.iter()) { + for both in z + .to_vec1::<f64>()? + .iter() + .zip(expected.to_vec1::<f64>()?.iter()) + { let (act, exp) = both; assert!(f64::abs(act - exp) < f32::EPSILON.into()); } |