summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
authorThomas Santerre <thomas@santerre.xyz>2024-03-17 15:55:11 -0400
committerGitHub <noreply@github.com>2024-03-17 20:55:11 +0100
commite316cb699743b5d45ab4a1067057b8f6d8687a02 (patch)
tree10e72eb4e86f7818dedd50148467c9054737d783 /candle-metal-kernels
parentce9fbc368211815ef2dddff01575ca1f9d4eccd5 (diff)
downloadcandle-e316cb699743b5d45ab4a1067057b8f6d8687a02.tar.gz
candle-e316cb699743b5d45ab4a1067057b8f6d8687a02.tar.bz2
candle-e316cb699743b5d45ab4a1067057b8f6d8687a02.zip
add support for casting between all datatypes (#1860)
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--candle-metal-kernels/src/cast.metal51
-rw-r--r--candle-metal-kernels/src/tests.rs256
2 files changed, 211 insertions, 96 deletions
diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal
index 9aead139..2af3fdce 100644
--- a/candle-metal-kernels/src/cast.metal
+++ b/candle-metal-kernels/src/cast.metal
@@ -72,27 +72,60 @@ kernel void FN_NAME_STRIDED( \
output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)])); \
} \
+// u32
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
+CAST(cast_u32_f16, cast_u32_f16_strided, uint32_t, half)
+#if __METAL_VERSION__ >= 220
+CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t)
+#endif
+#if defined(__HAVE_BFLOAT__)
+CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat)
+#endif
+
+// u8
CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
-CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
-CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
-
+CAST(cast_u8_f16, cast_u8_f16_strided, uint8_t, half)
#if __METAL_VERSION__ >= 220
CAST(cast_u8_i64, cast_u8_i64_strided, uint8_t, int64_t)
-CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t)
+#endif
+#if defined(__HAVE_BFLOAT__)
+CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat)
+#endif
+
+// f16
+CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
+CAST(cast_f16_u8, cast_f16_u8_strided, half, uint8_t)
+CAST(cast_f16_u32, cast_f16_u32_strided, half, uint32_t)
+CAST(cast_f16_i64, cast_f16_i64_strided, half, int64_t)
+#if defined(__HAVE_BFLOAT__)
+CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
+#endif
+
+// i64
CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
+CAST(cast_i64_u8, cast_i64_u8_strided, int64_t, uint8_t)
+CAST(cast_i64_u32, cast_i64_u32_strided, int64_t, uint32_t)
+CAST(cast_i64_f16, cast_i64_f16_strided, int64_t, half)
+#if defined(__HAVE_BFLOAT__)
+CAST_THROUGH(cast_i64_bf16, cast_i64_bf16_strided, int64_t, bfloat, float)
#endif
+// f32
+CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
+CAST(cast_f32_u32, cast_f32_u32_strided, float, uint32_t)
+CAST(cast_f32_u8, cast_f32_u8_strided, float, uint8_t)
+CAST(cast_f32_i64, cast_f32_i64_strided, float, int64_t)
#if defined(__HAVE_BFLOAT__)
-CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t)
-CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
-CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat)
-CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat)
CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
+#endif
+// bf16
+#if defined(__HAVE_BFLOAT__)
+CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t)
+CAST(cast_bf16_i64, cast_bf16_i64_strided, bfloat, int64_t)
+CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float)
CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float)
-CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
#endif \ No newline at end of file
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index b47fff6a..b2f1d723 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -292,7 +292,7 @@ fn binary_ops_bf16() {
binary_op!(max, |x: bf16, y| x.max(y));
}
-fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
+fn run_cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
@@ -319,107 +319,189 @@ fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
}
#[test]
-fn cast_u32_f32() {
- let v = vec![1u32, 2, 3];
- let results = cast(&v, "cast_u32_f32");
- let expected: Vec<_> = v.iter().map(|&v| v as f32).collect();
- assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]);
- assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]);
-
- let v = vec![1.0f32, 2.0, 3.0];
- let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
- let results: Vec<f32> = cast(&input, "cast_f16_f32");
- assert_eq!(results, vec![1.0f32, 2.0, 3.0]);
-
- let v = vec![1.0f32; 10_000];
- let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
- let results: Vec<f32> = cast(&input, "cast_f16_f32");
- assert_eq!(results.len(), 10_000);
- assert_eq!(&results[..10], vec![1.0f32; 10]);
- assert_eq!(results, vec![1.0f32; 10_000]);
+fn cast_f32() {
+ let v_f64 = vec![1.0f64, 2.0, 3.0];
+ let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
+ let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
+ let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
+ let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
+ let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
+ let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
+
+ // f32 -> f16
+ let results: Vec<half::f16> = run_cast(&v_f32, "cast_f32_f16");
+ assert_eq!(results, v_f16);
+
+ // f32 -> bf16
+ let results: Vec<bf16> = run_cast(&v_f32, "cast_f32_bf16");
+ assert_eq!(results, v_bf16);
+
+ // f32 -> u32
+ let results: Vec<u32> = run_cast(&v_f32, "cast_f32_u32");
+ assert_eq!(results, v_u32);
+
+ // f32 -> u8
+ let results: Vec<u8> = run_cast(&v_f32, "cast_f32_u8");
+ assert_eq!(results, v_u8);
+
+ // f32 -> i64
+ let results: Vec<i64> = run_cast(&v_f32, "cast_f32_i64");
+ assert_eq!(results, v_i64);
}
#[test]
-fn it_cast_bf16_u32() {
- let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
-
- let output: Vec<u32> = cast(&input, "cast_bf16_u32");
- let expected: Vec<u32> = (1..=3).map(|v| v as u32).collect();
-
- assert_eq!(output, expected);
+fn cast_f16() {
+ let v_f64 = vec![1.0f64, 2.0, 3.0];
+ let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
+ let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
+ let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
+ let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
+ let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
+ let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
+
+ // f16 -> f32
+ let results: Vec<f32> = run_cast(&v_f16, "cast_f16_f32");
+ assert_eq!(results, v_f32);
+
+ // f16 -> bf16
+ let results: Vec<bf16> = run_cast(&v_f16, "cast_f16_bf16");
+ assert_eq!(results, v_bf16);
+
+ // f16 -> u32
+ let results: Vec<u32> = run_cast(&v_f16, "cast_f16_u32");
+ assert_eq!(results, v_u32);
+
+ // f16 -> u8
+ let results: Vec<u8> = run_cast(&v_f16, "cast_f16_u8");
+ assert_eq!(results, v_u8);
+
+ // f16 -> i64
+ let results: Vec<i64> = run_cast(&v_f16, "cast_f16_i64");
+ assert_eq!(results, v_i64);
}
#[test]
-fn it_cast_bf16_f32() {
- let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
-
- let output: Vec<f32> = cast(&input, "cast_bf16_f32");
- let expected: Vec<f32> = (1..=3).map(|v| v as f32).collect();
-
- assert_eq!(output, expected);
+fn cast_bf16() {
+ let v_f64 = vec![1.0f64, 2.0, 3.0];
+ let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
+ let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
+ let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
+ let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
+ let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
+ let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
+
+ // bf16 -> f32
+ let results: Vec<f32> = run_cast(&v_bf16, "cast_bf16_f32");
+ assert_eq!(results, v_f32);
+
+ // bf16 -> f16
+ let results: Vec<f16> = run_cast(&v_bf16, "cast_bf16_f16");
+ assert_eq!(results, v_f16);
+
+ // bf16 -> u32
+ let results: Vec<u32> = run_cast(&v_bf16, "cast_bf16_u32");
+ assert_eq!(results, v_u32);
+
+ // bf16 -> u8
+ let results: Vec<u8> = run_cast(&v_bf16, "cast_bf16_u8");
+ assert_eq!(results, v_u8);
+
+ // bf16 -> i64
+ let results: Vec<i64> = run_cast(&v_bf16, "cast_bf16_i64");
+ assert_eq!(results, v_i64);
}
#[test]
-fn it_cast_u8_bf16() {
- let input: Vec<u8> = (1..=3).map(|v| v as u8).collect();
-
- let output: Vec<bf16> = cast(&input, "cast_u8_bf16");
- let expected: Vec<bf16> = input
- .iter()
- .map(|v| bf16::from_f32(*v as f32))
- .collect::<Vec<_>>();
-
- assert_eq!(output, expected);
+fn cast_u32() {
+ let v_f64 = vec![1.0f64, 2.0, 3.0];
+ let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
+ let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
+ let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
+ let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
+ let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
+ let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
+
+ // u32 -> f32
+ let results: Vec<f32> = run_cast(&v_u32, "cast_u32_f32");
+ assert_eq!(results, v_f32);
+
+ // u32 -> f16
+ let results: Vec<f16> = run_cast(&v_u32, "cast_u32_f16");
+ assert_eq!(results, v_f16);
+
+ // u32 -> bf16
+ let results: Vec<bf16> = run_cast(&v_u32, "cast_u32_bf16");
+ assert_eq!(results, v_bf16);
+
+ // u32 -> u8
+ let results: Vec<u8> = run_cast(&v_u32, "cast_u32_u8");
+ assert_eq!(results, v_u8);
+
+ // u32 -> i64
+ let results: Vec<i64> = run_cast(&v_u32, "cast_u32_i64");
+ assert_eq!(results, v_i64);
}
#[test]
-fn it_cast_u32_bf16() {
- let input: Vec<u32> = (1..=3).map(|v| v as u32).collect();
-
- let output: Vec<bf16> = cast(&input, "cast_u32_bf16");
- let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();
-
- assert_eq!(output, expected);
+fn cast_u8() {
+ let v_f64 = vec![1.0f64, 2.0, 3.0];
+ let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
+ let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
+ let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
+ let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
+ let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
+ let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
+
+ // u8 -> f32
+ let results: Vec<f32> = run_cast(&v_u8, "cast_u8_f32");
+ assert_eq!(results, v_f32);
+
+ // u8 -> f16
+ let results: Vec<f16> = run_cast(&v_u8, "cast_u8_f16");
+ assert_eq!(results, v_f16);
+
+ // u8 -> bf16
+ let results: Vec<bf16> = run_cast(&v_u8, "cast_u8_bf16");
+ assert_eq!(results, v_bf16);
+
+ // u8 -> u32
+ let results: Vec<u32> = run_cast(&v_u8, "cast_u8_u32");
+ assert_eq!(results, v_u32);
+
+ // u8 -> i64
+ let results: Vec<i64> = run_cast(&v_u8, "cast_u8_i64");
+ assert_eq!(results, v_i64);
}
#[test]
-fn it_cast_f32_bf16() {
- let input: Vec<f32> = (1..=3).map(|v| v as f32).collect();
-
- let output: Vec<bf16> = cast(&input, "cast_f32_bf16");
- let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();
-
- assert_eq!(output, expected);
-}
-
-#[test]
-fn it_cast_bf16_u8() {
- let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
-
- let output: Vec<u8> = cast(&input, "cast_bf16_u8");
- let expected: Vec<u8> = input.iter().map(|v| v.to_f32() as u8).collect();
-
- assert_eq!(output, expected);
-}
-
-#[test]
-fn it_cast_bf16_f16() {
- let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
-
- let output: Vec<f16> = cast(&input, "cast_bf16_f16");
- let expected: Vec<f16> = input.iter().map(|v| f16::from_f32(v.to_f32())).collect();
-
- assert_eq!(output, expected);
-}
-
-#[test]
-fn it_cast_f16_bf16() {
- let input: Vec<f16> = (1..=3).map(|v| f16::from_f32(v as f32)).collect();
-
- let output: Vec<bf16> = cast(&input, "cast_f16_bf16");
- let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(v.to_f32())).collect();
-
- assert_eq!(output, expected);
+fn cast_i64() {
+ let v_f64 = vec![1.0f64, 2.0, 3.0];
+ let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
+ let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
+ let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
+ let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
+ let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
+ let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
+
+ // i64 -> f32
+ let results: Vec<f32> = run_cast(&v_i64, "cast_i64_f32");
+ assert_eq!(results, v_f32);
+
+ // i64 -> f16
+ let results: Vec<f16> = run_cast(&v_i64, "cast_i64_f16");
+ assert_eq!(results, v_f16);
+
+ // i64 -> bf16
+ let results: Vec<bf16> = run_cast(&v_i64, "cast_i64_bf16");
+ assert_eq!(results, v_bf16);
+
+ // i64 -> u32
+ let results: Vec<u32> = run_cast(&v_i64, "cast_i64_u32");
+ assert_eq!(results, v_u32);
+
+ // i64 -> u8
+ let results: Vec<u8> = run_cast(&v_i64, "cast_i64_u8");
+ assert_eq!(results, v_u8);
}
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {