diff options
author | Lionel Touati <ltouati@gmail.com> | 2024-06-02 14:30:06 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-02 14:30:06 +0200 |
commit | 1ec3b2cc189fa6020018f2c8dad7b216b4512019 (patch) | |
tree | fc643da98f7649780798a2668279d39e3441c47f /candle-metal-kernels | |
parent | f7773d498a58fc5678784bd4843011974e11f953 (diff) | |
download | candle-1ec3b2cc189fa6020018f2c8dad7b216b4512019.tar.gz candle-1ec3b2cc189fa6020018f2c8dad7b216b4512019.tar.bz2 candle-1ec3b2cc189fa6020018f2c8dad7b216b4512019.zip |
add where_cond f32 for metal (#2236)
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/tests.rs | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 960ae1df..8c38e74a 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -1023,6 +1023,27 @@ fn where_cond() { ); assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]); } +#[test] +fn where_cond_u32_f32() { + let shape = vec![6]; + let cond = vec![0u32, 1, 0, 0, 1, 1]; + let cond_l = (vec![1], 0); + let left_true = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let left_l = (vec![1], 0); + let right_false = vec![-1.0f32, -2.0, -3.0, -4.0, -5.0, -6.0]; + let right_l = (vec![1], 0); + let results = run_where_cond( + &shape, + &cond, + cond_l, + &left_true, + left_l, + &right_false, + right_l, + "where_u32_f32", + ); + assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]); +} fn run_gemm<T: Clone>( (b, m, n, k): (usize, usize, usize, usize), |