summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
authorLionel Touati <ltouati@gmail.com>2024-06-02 14:30:06 +0200
committerGitHub <noreply@github.com>2024-06-02 14:30:06 +0200
commit1ec3b2cc189fa6020018f2c8dad7b216b4512019 (patch)
treefc643da98f7649780798a2668279d39e3441c47f /candle-metal-kernels
parentf7773d498a58fc5678784bd4843011974e11f953 (diff)
downloadcandle-1ec3b2cc189fa6020018f2c8dad7b216b4512019.tar.gz
candle-1ec3b2cc189fa6020018f2c8dad7b216b4512019.tar.bz2
candle-1ec3b2cc189fa6020018f2c8dad7b216b4512019.zip
add where_cond f32 for metal (#2236)
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--candle-metal-kernels/src/tests.rs21
1 files changed, 21 insertions, 0 deletions
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 960ae1df..8c38e74a 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -1023,6 +1023,27 @@ fn where_cond() {
);
assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
}
+#[test]
+fn where_cond_u32_f32() {
+ let shape = vec![6];
+ let cond = vec![0u32, 1, 0, 0, 1, 1];
+ let cond_l = (vec![1], 0);
+ let left_true = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
+ let left_l = (vec![1], 0);
+ let right_false = vec![-1.0f32, -2.0, -3.0, -4.0, -5.0, -6.0];
+ let right_l = (vec![1], 0);
+ let results = run_where_cond(
+ &shape,
+ &cond,
+ cond_l,
+ &left_true,
+ left_l,
+ &right_false,
+ right_l,
+ "where_u32_f32",
+ );
+ assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
+}
fn run_gemm<T: Clone>(
(b, m, n, k): (usize, usize, usize, usize),