summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
authorIvar Flakstad <69173633+ivarflakstad@users.noreply.github.com>2024-01-12 07:19:58 +0100
committerIvar Flakstad <69173633+ivarflakstad@users.noreply.github.com>2024-01-12 07:19:58 +0100
commite63bb8661beb4ea139f4e7f1d85f56907d918b2b (patch)
tree2326f731957d56667ba4d432a68f8b37a2b79830 /candle-metal-kernels
parent87efb5d8eb6a6c3f17acf326aadcb11ad6900306 (diff)
parent41915184bb3e530cc8184fdd8841c66df9285684 (diff)
downloadcandle-e63bb8661beb4ea139f4e7f1d85f56907d918b2b.tar.gz
candle-e63bb8661beb4ea139f4e7f1d85f56907d918b2b.tar.bz2
candle-e63bb8661beb4ea139f4e7f1d85f56907d918b2b.zip
Merge branch 'main' into ivarflakstad/metal-prng
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--candle-metal-kernels/Cargo.toml9
-rw-r--r--candle-metal-kernels/src/affine.metal2
-rw-r--r--candle-metal-kernels/src/binary.metal2
-rw-r--r--candle-metal-kernels/src/cast.metal42
-rw-r--r--candle-metal-kernels/src/indexing.metal5
-rw-r--r--candle-metal-kernels/src/lib.rs4
-rw-r--r--candle-metal-kernels/src/reduce.metal2
-rw-r--r--candle-metal-kernels/src/tests.rs154
-rw-r--r--candle-metal-kernels/src/unary.metal10
9 files changed, 206 insertions, 24 deletions
diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml
index 441d2e88..187cb4fd 100644
--- a/candle-metal-kernels/Cargo.toml
+++ b/candle-metal-kernels/Cargo.toml
@@ -9,12 +9,17 @@ keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT OR Apache-2.0"
+
[dependencies]
-metal = { version = "0.27.0", features = ["mps"]}
+metal = { version = "0.27.0", features = ["mps"] }
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"
[dev-dependencies]
-half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
+half = { version = "2.3.1", features = [
+ "num-traits",
+ "use-intrinsics",
+ "rand_distr",
+] }
rand = "0.8.5"
diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal
index 4166d811..3d8e7f0d 100644
--- a/candle-metal-kernels/src/affine.metal
+++ b/candle-metal-kernels/src/affine.metal
@@ -117,7 +117,7 @@ ELU(elu_f32, float)
ELU(elu_f16, half)
-#if __METAL_VERSION__ >= 310
+#if defined(__HAVE_BFLOAT__)
AFFINE(affine_bf16, bfloat);
POWF(powf_bf16, bfloat);
ELU(elu_bf16, bfloat);
diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal
index cdc8fef8..eb560f16 100644
--- a/candle-metal-kernels/src/binary.metal
+++ b/candle-metal-kernels/src/binary.metal
@@ -105,7 +105,7 @@ INT64_BINARY_OP_OUT(ge, x >= y)
INT64_BINARY_OP_OUT(gt, x > y)
#endif
-#if __METAL_VERSION__ >= 310
+#if defined(__HAVE_BFLOAT__)
BFLOAT_BINARY_OP(x + y, add)
BFLOAT_BINARY_OP(x - y, sub)
BFLOAT_BINARY_OP(x * y, mul)
diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal
index e9ab17b1..9aead139 100644
--- a/candle-metal-kernels/src/cast.metal
+++ b/candle-metal-kernels/src/cast.metal
@@ -28,7 +28,7 @@ kernel void FN_NAME( \
if (tid >= dim) { \
return; \
} \
- output[tid] = RIGHT_TYPENAME(input[tid]); \
+ output[tid] = static_cast<RIGHT_TYPENAME>(input[tid]); \
} \
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
@@ -42,7 +42,34 @@ kernel void FN_NAME_STRIDED( \
if (tid >= dim) { \
return; \
} \
- output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \
+ output[tid] = static_cast<RIGHT_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)]); \
+} \
+
+#define CAST_THROUGH(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME, IR_TYPENAME) \
+kernel void FN_NAME( \
+ constant size_t &dim, \
+ device const LEFT_TYPENAME *input, \
+ device RIGHT_TYPENAME *output, \
+ uint tid [[ thread_position_in_grid ]] \
+) { \
+ if (tid >= dim) { \
+ return; \
+ } \
+ output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[tid])); \
+} \
+kernel void FN_NAME_STRIDED( \
+ constant size_t &dim, \
+ constant size_t &num_dims, \
+ constant size_t *dims, \
+ constant size_t *strides, \
+ device const LEFT_TYPENAME *input, \
+ device RIGHT_TYPENAME *output, \
+ uint tid [[ thread_position_in_grid ]] \
+) { \
+ if (tid >= dim) { \
+ return; \
+ } \
+ output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)])); \
} \
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
@@ -58,7 +85,14 @@ CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t)
CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
#endif
-#if __METAL_VERSION__ >= 310
+#if defined(__HAVE_BFLOAT__)
+CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t)
CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
+CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat)
+CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat)
CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
-#endif
+
+CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float)
+CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float)
+CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
+#endif \ No newline at end of file
diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal
index 63357428..2a57bdbb 100644
--- a/candle-metal-kernels/src/indexing.metal
+++ b/candle-metal-kernels/src/indexing.metal
@@ -173,7 +173,10 @@ SCATTER_ADD_OP(sa_u32_f32, uint, float)
SCATTER_ADD_OP(sa_u32_f16, uint, half)
-#if __METAL_VERSION__ >= 310
+#if defined(__HAVE_BFLOAT__)
+INDEX_OP(is_u32_bf16, uint32_t, bfloat)
+INDEX_OP(is_u8_bf16, uint8_t, bfloat)
+
INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 75f0286d..c427a690 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -178,8 +178,8 @@ macro_rules! ops{
pub mod unary {
ops!(
- cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, round, erf, gelu_erf, tanh,
- recip
+ cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
+ tanh, recip
);
}
pub mod binary {
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
index 83a56f0a..93dac662 100644
--- a/candle-metal-kernels/src/reduce.metal
+++ b/candle-metal-kernels/src/reduce.metal
@@ -295,7 +295,7 @@ ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX)
ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
#endif
-#if __METAL_VERSION__ >= 310
+#if defined(__HAVE_BFLOAT__)
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 067dece8..b15505f7 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -1,6 +1,6 @@
use super::*;
use half::{bf16, f16};
-use metal::{Device, MTLResourceOptions};
+use metal::{Buffer, Device, MTLResourceOptions};
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
let ptr = buffer.contents() as *const T;
@@ -248,6 +248,34 @@ fn binary_add_f32() {
assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]);
}
+#[test]
+fn binary_ops_bf16() {
+ let lhs: Vec<bf16> = [1.1f32, 2.2, 3.3].into_iter().map(bf16::from_f32).collect();
+ let rhs: Vec<bf16> = [4.2f32, 5.5f32, 6.91f32]
+ .into_iter()
+ .map(bf16::from_f32)
+ .collect();
+
+ macro_rules! binary_op {
+ ($opname:ident, $opexpr:expr) => {{
+ let results = run_binary(&lhs, &rhs, binary::contiguous::$opname::BFLOAT);
+ let expected: Vec<bf16> = lhs
+ .iter()
+ .zip(rhs.iter())
+ .map(|(x, y): (&bf16, &bf16)| $opexpr(*x, *y))
+ .collect();
+ assert_eq!(results, expected);
+ }};
+ }
+
+ binary_op!(add, |x, y| x + y);
+ binary_op!(sub, |x, y| x - y);
+ binary_op!(mul, |x, y| x * y);
+ binary_op!(div, |x, y| x / y);
+ binary_op!(min, |x: bf16, y| x.min(y));
+ binary_op!(max, |x: bf16, y| x.max(y));
+}
+
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
let device = device();
let fence = device.new_fence();
@@ -296,6 +324,89 @@ fn cast_u32_f32() {
assert_eq!(results, vec![1.0f32; 10_000]);
}
+#[test]
+fn it_cast_bf16_u32() {
+ let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
+
+ let output: Vec<u32> = cast(&input, "cast_bf16_u32");
+ let expected: Vec<u32> = (1..=3).map(|v| v as u32).collect();
+
+ assert_eq!(output, expected);
+}
+
+#[test]
+fn it_cast_bf16_f32() {
+ let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
+
+ let output: Vec<f32> = cast(&input, "cast_bf16_f32");
+ let expected: Vec<f32> = (1..=3).map(|v| v as f32).collect();
+
+ assert_eq!(output, expected);
+}
+
+#[test]
+fn it_cast_u8_bf16() {
+ let input: Vec<u8> = (1..=3).map(|v| v as u8).collect();
+
+ let output: Vec<bf16> = cast(&input, "cast_u8_bf16");
+ let expected: Vec<bf16> = input
+ .iter()
+ .map(|v| bf16::from_f32(*v as f32))
+ .collect::<Vec<_>>();
+
+ assert_eq!(output, expected);
+}
+
+#[test]
+fn it_cast_u32_bf16() {
+ let input: Vec<u32> = (1..=3).map(|v| v as u32).collect();
+
+ let output: Vec<bf16> = cast(&input, "cast_u32_bf16");
+ let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();
+
+ assert_eq!(output, expected);
+}
+
+#[test]
+fn it_cast_f32_bf16() {
+ let input: Vec<f32> = (1..=3).map(|v| v as f32).collect();
+
+ let output: Vec<bf16> = cast(&input, "cast_f32_bf16");
+ let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();
+
+ assert_eq!(output, expected);
+}
+
+#[test]
+fn it_cast_bf16_u8() {
+ let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
+
+ let output: Vec<u8> = cast(&input, "cast_bf16_u8");
+ let expected: Vec<u8> = input.iter().map(|v| v.to_f32() as u8).collect();
+
+ assert_eq!(output, expected);
+}
+
+#[test]
+fn it_cast_bf16_f16() {
+ let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
+
+ let output: Vec<f16> = cast(&input, "cast_bf16_f16");
+ let expected: Vec<f16> = input.iter().map(|v| f16::from_f32(v.to_f32())).collect();
+
+ assert_eq!(output, expected);
+}
+
+#[test]
+fn it_cast_f16_bf16() {
+ let input: Vec<f16> = (1..=3).map(|v| f16::from_f32(v as f32)).collect();
+
+ let output: Vec<bf16> = cast(&input, "cast_f16_bf16");
+ let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(v.to_f32())).collect();
+
+ assert_eq!(output, expected);
+}
+
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
let device = device();
let fence = device.new_fence();
@@ -396,14 +507,14 @@ fn index_select() {
let shape = [5, 2];
let ids = [0u32, 4, 2];
let dim = 0;
- let result = run_index_select(&embedding, &shape, &ids, dim);
+ let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]);
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [2, 5];
let ids = [0u32, 1, 0];
let dim = 0;
- let result = run_index_select(&embedding, &shape, &ids, dim);
+ let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
assert_eq!(
result,
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
@@ -419,7 +530,7 @@ fn index_select_f16() {
let shape = [5, 2];
let ids = [0u32, 4, 2];
let dim = 0;
- let result = run_index_select(&embedding, &shape, &ids, dim);
+ let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f16");
assert_eq!(
approx_f16(result, 4),
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
@@ -427,12 +538,38 @@ fn index_select_f16() {
}
#[test]
+fn index_select_is_u32_bf16() {
+ let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect();
+ let shape = [5, 2];
+ let ids = [0u32, 4, 2];
+ let dim = 0;
+ let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_bf16");
+ assert_eq!(
+ approx_bf16(result, 4),
+ vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
+ );
+}
+
+#[test]
+fn index_select_is_u8_bf16() {
+ let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect();
+ let shape = [5, 2];
+ let ids = [0u8, 4, 2];
+ let dim = 0;
+ let result = run_index_select(&embedding, &shape, &ids, dim, "is_u8_bf16");
+ assert_eq!(
+ approx_bf16(result, 4),
+ vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
+ );
+}
+
+#[test]
fn index_select_dim1() {
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [5, 2];
let ids = [0u32, 1, 0];
let dim = 1;
- let result = run_index_select(&embedding, &shape, &ids, dim);
+ let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
assert_eq!(
result,
vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0]
@@ -444,6 +581,7 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
shape: &[usize],
ids: &[I],
dim: usize,
+ name: &'static str,
) -> Vec<T> {
let device = Device::system_default().expect("no device found");
@@ -457,12 +595,6 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
let dst_el = ids.len() * left_size * right_size;
let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
- let name = match core::mem::size_of::<T>() {
- 4 => "is_u32_f32",
- 2 => "is_u32_f16",
- _ => unimplemented!(),
- };
-
let fence = device.new_fence();
let kernels = Kernels::new(fence);
call_index_select(
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index 7fbb613d..dcf803d8 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -58,6 +58,12 @@ template <typename T> METAL_FUNC T gelu(T x) {
T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha);
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
}
+template <typename T> METAL_FUNC T relu(T in){
+ if (in < 0) {
+ return 0;
+ }
+ return in;
+}
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
@@ -110,6 +116,7 @@ UNARY_OP(gelu_erf)
UNARY_OP(erf)
UNARY_OP(tanh)
UNARY_OP(recip)
+UNARY_OP(relu)
UNARY(id, float, copy_f32, copy_f32_strided)
UNARY(id, half, copy_f16, copy_f16_strided)
@@ -120,7 +127,7 @@ UNARY(id, uint32_t, copy_u32, copy_u32_strided)
UNARY(id, int64_t, copy_i64, copy_i64_strided)
#endif
-#if __METAL_VERSION__ >= 310
+#if defined(__HAVE_BFLOAT__)
BFLOAT_UNARY_OP(cos)
BFLOAT_UNARY_OP(sin)
BFLOAT_UNARY_OP(sqr)
@@ -136,6 +143,7 @@ BFLOAT_UNARY_OP(gelu_erf)
BFLOAT_UNARY_OP(erf)
BFLOAT_UNARY_OP(tanh)
BFLOAT_UNARY_OP(recip)
+BFLOAT_UNARY_OP(relu)
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
#endif