diff options
author | Thomas Santerre <thomas@santerre.xyz> | 2024-03-17 15:55:11 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-17 20:55:11 +0100 |
commit | e316cb699743b5d45ab4a1067057b8f6d8687a02 (patch) | |
tree | 10e72eb4e86f7818dedd50148467c9054737d783 /candle-metal-kernels | |
parent | ce9fbc368211815ef2dddff01575ca1f9d4eccd5 (diff) | |
download | candle-e316cb699743b5d45ab4a1067057b8f6d8687a02.tar.gz candle-e316cb699743b5d45ab4a1067057b8f6d8687a02.tar.bz2 candle-e316cb699743b5d45ab4a1067057b8f6d8687a02.zip |
add support for casting between all datatypes (#1860)
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/cast.metal | 51 | ||||
-rw-r--r-- | candle-metal-kernels/src/tests.rs | 256 |
2 files changed, 211 insertions, 96 deletions
diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal index 9aead139..2af3fdce 100644 --- a/candle-metal-kernels/src/cast.metal +++ b/candle-metal-kernels/src/cast.metal @@ -72,27 +72,60 @@ kernel void FN_NAME_STRIDED( \ output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)])); \ } \ +// u32 CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float) CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t) +CAST(cast_u32_f16, cast_u32_f16_strided, uint32_t, half) +#if __METAL_VERSION__ >= 220 +CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t) +#endif +#if defined(__HAVE_BFLOAT__) +CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat) +#endif + +// u8 CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t) CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float) -CAST(cast_f16_f32, cast_f16_f32_strided, half, float) -CAST(cast_f32_f16, cast_f32_f16_strided, float, half) - +CAST(cast_u8_f16, cast_u8_f16_strided, uint8_t, half) #if __METAL_VERSION__ >= 220 CAST(cast_u8_i64, cast_u8_i64_strided, uint8_t, int64_t) -CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t) +#endif +#if defined(__HAVE_BFLOAT__) +CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat) +#endif + +// f16 +CAST(cast_f16_f32, cast_f16_f32_strided, half, float) +CAST(cast_f16_u8, cast_f16_u8_strided, half, uint8_t) +CAST(cast_f16_u32, cast_f16_u32_strided, half, uint32_t) +CAST(cast_f16_i64, cast_f16_i64_strided, half, int64_t) +#if defined(__HAVE_BFLOAT__) +CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float) +#endif + +// i64 CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float) +CAST(cast_i64_u8, cast_i64_u8_strided, int64_t, uint8_t) +CAST(cast_i64_u32, cast_i64_u32_strided, int64_t, uint32_t) +CAST(cast_i64_f16, cast_i64_f16_strided, int64_t, half) +#if defined(__HAVE_BFLOAT__) +CAST_THROUGH(cast_i64_bf16, cast_i64_bf16_strided, int64_t, bfloat, float) #endif +// f32 +CAST(cast_f32_f16, cast_f32_f16_strided, float, half) +CAST(cast_f32_u32, cast_f32_u32_strided, float, uint32_t) +CAST(cast_f32_u8, cast_f32_u8_strided, float, uint8_t) +CAST(cast_f32_i64, cast_f32_i64_strided, float, int64_t) #if defined(__HAVE_BFLOAT__) -CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t) -CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float) -CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat) -CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat) CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat) +#endif +// bf16 +#if defined(__HAVE_BFLOAT__) +CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t) +CAST(cast_bf16_i64, cast_bf16_i64_strided, bfloat, int64_t) +CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float) CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float) CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float) -CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float) #endif
\ No newline at end of file diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index b47fff6a..b2f1d723 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -292,7 +292,7 @@ fn binary_ops_bf16() { binary_op!(max, |x: bf16, y| x.max(y)); } -fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> { +fn run_cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> { let device = device(); let kernels = Kernels::new(); let command_queue = device.new_command_queue(); @@ -319,107 +319,189 @@ fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> { } #[test] -fn cast_u32_f32() { - let v = vec![1u32, 2, 3]; - let results = cast(&v, "cast_u32_f32"); - let expected: Vec<_> = v.iter().map(|&v| v as f32).collect(); - assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]); - assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]); - - let v = vec![1.0f32, 2.0, 3.0]; - let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect(); - let results: Vec<f32> = cast(&input, "cast_f16_f32"); - assert_eq!(results, vec![1.0f32, 2.0, 3.0]); - - let v = vec![1.0f32; 10_000]; - let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect(); - let results: Vec<f32> = cast(&input, "cast_f16_f32"); - assert_eq!(results.len(), 10_000); - assert_eq!(&results[..10], vec![1.0f32; 10]); - assert_eq!(results, vec![1.0f32; 10_000]); +fn cast_f32() { + let v_f64 = vec![1.0f64, 2.0, 3.0]; + let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect(); + let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); + let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); + let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect(); + let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect(); + let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect(); + + // f32 -> f16 + let results: Vec<half::f16> = run_cast(&v_f32, "cast_f32_f16"); + assert_eq!(results, v_f16); + + // f32 -> bf16 + let results: Vec<bf16> = run_cast(&v_f32, "cast_f32_bf16"); + assert_eq!(results, v_bf16); + + // f32 -> u32 + let results: Vec<u32> = run_cast(&v_f32, "cast_f32_u32"); + assert_eq!(results, v_u32); + + // f32 -> u8 + let results: Vec<u8> = run_cast(&v_f32, "cast_f32_u8"); + assert_eq!(results, v_u8); + + // f32 -> i64 + let results: Vec<i64> = run_cast(&v_f32, "cast_f32_i64"); + assert_eq!(results, v_i64); } #[test] -fn it_cast_bf16_u32() { - let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); - - let output: Vec<u32> = cast(&input, "cast_bf16_u32"); - let expected: Vec<u32> = (1..=3).map(|v| v as u32).collect(); - - assert_eq!(output, expected); +fn cast_f16() { + let v_f64 = vec![1.0f64, 2.0, 3.0]; + let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect(); + let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); + let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); + let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect(); + let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect(); + let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect(); + + // f16 -> f32 + let results: Vec<f32> = run_cast(&v_f16, "cast_f16_f32"); + assert_eq!(results, v_f32); + + // f16 -> bf16 + let results: Vec<bf16> = run_cast(&v_f16, "cast_f16_bf16"); + assert_eq!(results, v_bf16); + + // f16 -> u32 + let results: Vec<u32> = run_cast(&v_f16, "cast_f16_u32"); + assert_eq!(results, v_u32); + + // f16 -> u8 + let results: Vec<u8> = run_cast(&v_f16, "cast_f16_u8"); + assert_eq!(results, v_u8); + + // f16 -> i64 + let results: Vec<i64> = run_cast(&v_f16, "cast_f16_i64"); + assert_eq!(results, v_i64); } #[test] -fn it_cast_bf16_f32() { - let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); - - let output: Vec<f32> = cast(&input, "cast_bf16_f32"); - let expected: Vec<f32> = (1..=3).map(|v| v as f32).collect(); - - assert_eq!(output, expected); +fn cast_bf16() { + let v_f64 = vec![1.0f64, 2.0, 3.0]; + let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect(); + let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); + let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); + let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect(); + let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect(); + let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect(); + + // bf16 -> f32 + let results: Vec<f32> = run_cast(&v_bf16, "cast_bf16_f32"); + assert_eq!(results, v_f32); + + // bf16 -> f16 + let results: Vec<f16> = run_cast(&v_bf16, "cast_bf16_f16"); + assert_eq!(results, v_f16); + + // bf16 -> u32 + let results: Vec<u32> = run_cast(&v_bf16, "cast_bf16_u32"); + assert_eq!(results, v_u32); + + // bf16 -> u8 + let results: Vec<u8> = run_cast(&v_bf16, "cast_bf16_u8"); + assert_eq!(results, v_u8); + + // bf16 -> i64 + let results: Vec<i64> = run_cast(&v_bf16, "cast_bf16_i64"); + assert_eq!(results, v_i64); } #[test] -fn it_cast_u8_bf16() { - let input: Vec<u8> = (1..=3).map(|v| v as u8).collect(); - - let output: Vec<bf16> = cast(&input, "cast_u8_bf16"); - let expected: Vec<bf16> = input - .iter() - .map(|v| bf16::from_f32(*v as f32)) - .collect::<Vec<_>>(); - - assert_eq!(output, expected); +fn cast_u32() { + let v_f64 = vec![1.0f64, 2.0, 3.0]; + let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect(); + let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); + let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); + let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect(); + let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect(); + let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect(); + + // u32 -> f32 + let results: Vec<f32> = run_cast(&v_u32, "cast_u32_f32"); + assert_eq!(results, v_f32); + + // u32 -> f16 + let results: Vec<f16> = run_cast(&v_u32, "cast_u32_f16"); + assert_eq!(results, v_f16); + + // u32 -> bf16 + let results: Vec<bf16> = run_cast(&v_u32, "cast_u32_bf16"); + assert_eq!(results, v_bf16); + + // u32 -> u8 + let results: Vec<u8> = run_cast(&v_u32, "cast_u32_u8"); + assert_eq!(results, v_u8); + + // u32 -> i64 + let results: Vec<i64> = run_cast(&v_u32, "cast_u32_i64"); + assert_eq!(results, v_i64); } #[test] -fn it_cast_u32_bf16() { - let input: Vec<u32> = (1..=3).map(|v| v as u32).collect(); - - let output: Vec<bf16> = cast(&input, "cast_u32_bf16"); - let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect(); - - assert_eq!(output, expected); +fn cast_u8() { + let v_f64 = vec![1.0f64, 2.0, 3.0]; + let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect(); + let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); + let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); + let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect(); + let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect(); + let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect(); + + // u8 -> f32 + let results: Vec<f32> = run_cast(&v_u8, "cast_u8_f32"); + assert_eq!(results, v_f32); + + // u8 -> f16 + let results: Vec<f16> = run_cast(&v_u8, "cast_u8_f16"); + assert_eq!(results, v_f16); + + // u8 -> bf16 + let results: Vec<bf16> = run_cast(&v_u8, "cast_u8_bf16"); + assert_eq!(results, v_bf16); + + // u8 -> u32 + let results: Vec<u32> = run_cast(&v_u8, "cast_u8_u32"); + assert_eq!(results, v_u32); + + // u8 -> i64 + let results: Vec<i64> = run_cast(&v_u8, "cast_u8_i64"); + assert_eq!(results, v_i64); } #[test] -fn it_cast_f32_bf16() { - let input: Vec<f32> = (1..=3).map(|v| v as f32).collect(); - - let output: Vec<bf16> = cast(&input, "cast_f32_bf16"); - let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect(); - - assert_eq!(output, expected); -} - -#[test] -fn it_cast_bf16_u8() { - let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); - - let output: Vec<u8> = cast(&input, "cast_bf16_u8"); - let expected: Vec<u8> = input.iter().map(|v| v.to_f32() as u8).collect(); - - assert_eq!(output, expected); -} - -#[test] -fn it_cast_bf16_f16() { - let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); - - let output: Vec<f16> = cast(&input, "cast_bf16_f16"); - let expected: Vec<f16> = input.iter().map(|v| f16::from_f32(v.to_f32())).collect(); - - assert_eq!(output, expected); -} - -#[test] -fn it_cast_f16_bf16() { - let input: Vec<f16> = (1..=3).map(|v| f16::from_f32(v as f32)).collect(); - - let output: Vec<bf16> = cast(&input, "cast_f16_bf16"); - let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(v.to_f32())).collect(); - - assert_eq!(output, expected); +fn cast_i64() { + let v_f64 = vec![1.0f64, 2.0, 3.0]; + let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect(); + let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); + let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); + let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect(); + let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect(); + let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect(); + + // i64 -> f32 + let results: Vec<f32> = run_cast(&v_i64, "cast_i64_f32"); + assert_eq!(results, v_f32); + + // i64 -> f16 + let results: Vec<f16> = run_cast(&v_i64, "cast_i64_f16"); + assert_eq!(results, v_f16); + + // i64 -> bf16 + let results: Vec<bf16> = run_cast(&v_i64, "cast_i64_bf16"); + assert_eq!(results, v_bf16); + + // i64 -> u32 + let results: Vec<u32> = run_cast(&v_i64, "cast_i64_u32"); + assert_eq!(results, v_u32); + + // i64 -> u8 + let results: Vec<u8> = run_cast(&v_i64, "cast_i64_u8"); + assert_eq!(results, v_u8); } fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> { |