summaryrefslogtreecommitdiff
path: root/candle-metal-kernels/src/tests.rs
diff options
context:
space:
mode:
authorIvar Flakstad <69173633+ivarflakstad@users.noreply.github.com>2024-01-12 07:19:58 +0100
committerIvar Flakstad <69173633+ivarflakstad@users.noreply.github.com>2024-01-12 07:19:58 +0100
commite63bb8661beb4ea139f4e7f1d85f56907d918b2b (patch)
tree2326f731957d56667ba4d432a68f8b37a2b79830 /candle-metal-kernels/src/tests.rs
parent87efb5d8eb6a6c3f17acf326aadcb11ad6900306 (diff)
parent41915184bb3e530cc8184fdd8841c66df9285684 (diff)
downloadcandle-e63bb8661beb4ea139f4e7f1d85f56907d918b2b.tar.gz
candle-e63bb8661beb4ea139f4e7f1d85f56907d918b2b.tar.bz2
candle-e63bb8661beb4ea139f4e7f1d85f56907d918b2b.zip
Merge branch 'main' into ivarflakstad/metal-prng
Diffstat (limited to 'candle-metal-kernels/src/tests.rs')
-rw-r--r--candle-metal-kernels/src/tests.rs154
1 files changed, 143 insertions, 11 deletions
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 067dece8..b15505f7 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -1,6 +1,6 @@
use super::*;
use half::{bf16, f16};
-use metal::{Device, MTLResourceOptions};
+use metal::{Buffer, Device, MTLResourceOptions};
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
let ptr = buffer.contents() as *const T;
@@ -248,6 +248,34 @@ fn binary_add_f32() {
assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]);
}
+#[test]
+fn binary_ops_bf16() {
+ let lhs: Vec<bf16> = [1.1f32, 2.2, 3.3].into_iter().map(bf16::from_f32).collect();
+ let rhs: Vec<bf16> = [4.2f32, 5.5f32, 6.91f32]
+ .into_iter()
+ .map(bf16::from_f32)
+ .collect();
+
+ macro_rules! binary_op {
+ ($opname:ident, $opexpr:expr) => {{
+ let results = run_binary(&lhs, &rhs, binary::contiguous::$opname::BFLOAT);
+ let expected: Vec<bf16> = lhs
+ .iter()
+ .zip(rhs.iter())
+ .map(|(x, y): (&bf16, &bf16)| $opexpr(*x, *y))
+ .collect();
+ assert_eq!(results, expected);
+ }};
+ }
+
+ binary_op!(add, |x, y| x + y);
+ binary_op!(sub, |x, y| x - y);
+ binary_op!(mul, |x, y| x * y);
+ binary_op!(div, |x, y| x / y);
+ binary_op!(min, |x: bf16, y| x.min(y));
+ binary_op!(max, |x: bf16, y| x.max(y));
+}
+
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
let device = device();
let fence = device.new_fence();
@@ -296,6 +324,89 @@ fn cast_u32_f32() {
assert_eq!(results, vec![1.0f32; 10_000]);
}
+#[test]
+fn it_cast_bf16_u32() {
+ let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
+
+ let output: Vec<u32> = cast(&input, "cast_bf16_u32");
+ let expected: Vec<u32> = (1..=3).map(|v| v as u32).collect();
+
+ assert_eq!(output, expected);
+}
+
+#[test]
+fn it_cast_bf16_f32() {
+ let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
+
+ let output: Vec<f32> = cast(&input, "cast_bf16_f32");
+ let expected: Vec<f32> = (1..=3).map(|v| v as f32).collect();
+
+ assert_eq!(output, expected);
+}
+
+#[test]
+fn it_cast_u8_bf16() {
+ let input: Vec<u8> = (1..=3).map(|v| v as u8).collect();
+
+ let output: Vec<bf16> = cast(&input, "cast_u8_bf16");
+ let expected: Vec<bf16> = input
+ .iter()
+ .map(|v| bf16::from_f32(*v as f32))
+ .collect::<Vec<_>>();
+
+ assert_eq!(output, expected);
+}
+
+#[test]
+fn it_cast_u32_bf16() {
+ let input: Vec<u32> = (1..=3).map(|v| v as u32).collect();
+
+ let output: Vec<bf16> = cast(&input, "cast_u32_bf16");
+ let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();
+
+ assert_eq!(output, expected);
+}
+
+#[test]
+fn it_cast_f32_bf16() {
+ let input: Vec<f32> = (1..=3).map(|v| v as f32).collect();
+
+ let output: Vec<bf16> = cast(&input, "cast_f32_bf16");
+ let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();
+
+ assert_eq!(output, expected);
+}
+
+#[test]
+fn it_cast_bf16_u8() {
+ let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
+
+ let output: Vec<u8> = cast(&input, "cast_bf16_u8");
+ let expected: Vec<u8> = input.iter().map(|v| v.to_f32() as u8).collect();
+
+ assert_eq!(output, expected);
+}
+
+#[test]
+fn it_cast_bf16_f16() {
+ let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
+
+ let output: Vec<f16> = cast(&input, "cast_bf16_f16");
+ let expected: Vec<f16> = input.iter().map(|v| f16::from_f32(v.to_f32())).collect();
+
+ assert_eq!(output, expected);
+}
+
+#[test]
+fn it_cast_f16_bf16() {
+ let input: Vec<f16> = (1..=3).map(|v| f16::from_f32(v as f32)).collect();
+
+ let output: Vec<bf16> = cast(&input, "cast_f16_bf16");
+ let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(v.to_f32())).collect();
+
+ assert_eq!(output, expected);
+}
+
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
let device = device();
let fence = device.new_fence();
@@ -396,14 +507,14 @@ fn index_select() {
let shape = [5, 2];
let ids = [0u32, 4, 2];
let dim = 0;
- let result = run_index_select(&embedding, &shape, &ids, dim);
+ let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]);
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [2, 5];
let ids = [0u32, 1, 0];
let dim = 0;
- let result = run_index_select(&embedding, &shape, &ids, dim);
+ let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
assert_eq!(
result,
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
@@ -419,7 +530,7 @@ fn index_select_f16() {
let shape = [5, 2];
let ids = [0u32, 4, 2];
let dim = 0;
- let result = run_index_select(&embedding, &shape, &ids, dim);
+ let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f16");
assert_eq!(
approx_f16(result, 4),
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
@@ -427,12 +538,38 @@ fn index_select_f16() {
}
#[test]
+fn index_select_is_u32_bf16() {
+ let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect();
+ let shape = [5, 2];
+ let ids = [0u32, 4, 2];
+ let dim = 0;
+ let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_bf16");
+ assert_eq!(
+ approx_bf16(result, 4),
+ vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
+ );
+}
+
+#[test]
+fn index_select_is_u8_bf16() {
+ let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect();
+ let shape = [5, 2];
+ let ids = [0u8, 4, 2];
+ let dim = 0;
+ let result = run_index_select(&embedding, &shape, &ids, dim, "is_u8_bf16");
+ assert_eq!(
+ approx_bf16(result, 4),
+ vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
+ );
+}
+
+#[test]
fn index_select_dim1() {
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [5, 2];
let ids = [0u32, 1, 0];
let dim = 1;
- let result = run_index_select(&embedding, &shape, &ids, dim);
+ let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
assert_eq!(
result,
vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0]
@@ -444,6 +581,7 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
shape: &[usize],
ids: &[I],
dim: usize,
+ name: &'static str,
) -> Vec<T> {
let device = Device::system_default().expect("no device found");
@@ -457,12 +595,6 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
let dst_el = ids.len() * left_size * right_size;
let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
- let name = match core::mem::size_of::<T>() {
- 4 => "is_u32_f32",
- 2 => "is_u32_f16",
- _ => unimplemented!(),
- };
-
let fence = device.new_fence();
let kernels = Kernels::new(fence);
call_index_select(