diff options
author | Thomas Santerre <thomas@santerre.xyz> | 2024-03-17 03:09:43 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-17 08:09:43 +0100 |
commit | db8b24ae92419377283821ee0a65fb224a4f3c4d (patch) | |
tree | 95c6481d38f32b628192503ca9a250d033ebf38c /candle-metal-kernels | |
parent | 74bf6994b172f364c6e8bea2ac6e1bfbc6ca0c25 (diff) | |
download | candle-db8b24ae92419377283821ee0a65fb224a4f3c4d.tar.gz candle-db8b24ae92419377283821ee0a65fb224a4f3c4d.tar.bz2 candle-db8b24ae92419377283821ee0a65fb224a4f3c4d.zip |
Add support for index u8/i64 and input f16/bf16 scatter-add on metal (#1849)
* add support and tests for scatter add on metal
* add support for all datatypes
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/indexing.metal | 13 | ||||
-rw-r--r-- | candle-metal-kernels/src/tests.rs | 104 |
2 files changed, 115 insertions, 2 deletions
diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index 2a57bdbb..f6b81be0 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -167,11 +167,16 @@ kernel void NAME( \ INDEX_OP(is_u32_f32, uint, float) INDEX_OP(is_u32_f16, uint, half) + GATHER_OP(gather_u32_f32, uint, float) GATHER_OP(gather_u32_f16, uint, half) -SCATTER_ADD_OP(sa_u32_f32, uint, float) -SCATTER_ADD_OP(sa_u32_f16, uint, half) +SCATTER_ADD_OP(sa_u32_f32, uint32_t, float) +SCATTER_ADD_OP(sa_u8_f32, uint8_t, float) +SCATTER_ADD_OP(sa_i64_f32, int64_t, float) +SCATTER_ADD_OP(sa_u32_f16, uint32_t, half) +SCATTER_ADD_OP(sa_u8_f16, uint8_t, half) +SCATTER_ADD_OP(sa_i64_f16, int64_t, half) #if defined(__HAVE_BFLOAT__) INDEX_OP(is_u32_bf16, uint32_t, bfloat) @@ -180,6 +185,10 @@ INDEX_OP(is_u8_bf16, uint8_t, bfloat) INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat) INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat) INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat) + +SCATTER_ADD_OP(sa_u32_bf16, uint32_t, bfloat) +SCATTER_ADD_OP(sa_u8_bf16, uint8_t, bfloat) +SCATTER_ADD_OP(sa_i64_bf16, int64_t, bfloat) #endif INDEX_ADD_OP(ia_u32_f16, uint32_t, half) diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 459c8edb..b47fff6a 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -1066,3 +1066,107 @@ fn random() { validate_random!(f16); validate_random!(bf16); } + +fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>( + input: &[T], + ids: &[I], + shape: &[usize], + dim: usize, + name: &'static str, +) -> Vec<T> { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; + let input_buffer = new_buffer(&device, input); + let ids_buffer = new_buffer(&device, ids); + let output = device.new_buffer(std::mem::size_of_val(input) as u64, options); + call_scatter_add( + &device, + command_buffer, + &kernels, + name, + shape, + shape, + dim, + &input_buffer, + 0, + &ids_buffer, + 0, + &output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + read_to_vec(&output, input.len()) +} + +#[test] +fn scatter_add() { + let ids_u8 = [0u8, 0, 1, 0, 2, 2, 3, 3]; + let ids_u32 = [0u32, 0, 1, 0, 2, 2, 3, 3]; + let ids_i64 = [0i64, 0, 1, 0, 2, 2, 3, 3]; + + let input_f32 = [5.0f32, 1.0, 7.0, 2.0, 3.0, 2.0, 1.0, 3.0]; + let input_f16 = input_f32 + .iter() + .map(|v| f16::from_f32(*v)) + .collect::<Vec<_>>(); + let input_bf16 = input_f32 + .iter() + .map(|v| bf16::from_f32(*v)) + .collect::<Vec<_>>(); + + let output_dim1_f32 = vec![8.0, 7.0, 5.0, 4.0, 0.0, 0.0, 0.0, 0.0]; + let output_dim1_f16 = output_dim1_f32 + .iter() + .map(|v| f16::from_f32(*v)) + .collect::<Vec<_>>(); + let output_dim1_bf16 = output_dim1_f32 + .iter() + .map(|v| bf16::from_f32(*v)) + .collect::<Vec<_>>(); + + let output_dim2_f32 = vec![5.0, 3.0, 7.0, 0.0, 3.0, 2.0, 1.0, 3.0]; + let output_dim2_f16 = output_dim2_f32 + .iter() + .map(|v| f16::from_f32(*v)) + .collect::<Vec<_>>(); + let output_dim2_bf16 = output_dim2_f32 + .iter() + .map(|v| bf16::from_f32(*v)) + .collect::<Vec<_>>(); + + for (shape, output_f32, output_f16, output_bf16) in [ + (vec![8], output_dim1_f32, output_dim1_f16, output_dim1_bf16), + ( + vec![4, 2], + output_dim2_f32, + output_dim2_f16, + output_dim2_bf16, + ), + ] { + for results in [ + run_scatter_add(&input_f32, &ids_u8, &shape, 0, "sa_u8_f32"), + run_scatter_add(&input_f32, &ids_u32, &shape, 0, "sa_u32_f32"), + run_scatter_add(&input_f32, &ids_i64, &shape, 0, "sa_i64_f32"), + ] { + assert_eq!(results, output_f32); + } + for results in [ + run_scatter_add(&input_f16, &ids_u8, &shape, 0, "sa_u8_f16"), + run_scatter_add(&input_f16, &ids_u32, &shape, 0, "sa_u32_f16"), + run_scatter_add(&input_f16, &ids_i64, &shape, 0, "sa_i64_f16"), + ] { + assert_eq!(results, output_f16); + } + for results in [ + run_scatter_add(&input_bf16, &ids_u8, &shape, 0, "sa_u8_bf16"), + run_scatter_add(&input_bf16, &ids_u32, &shape, 0, "sa_u32_bf16"), + run_scatter_add(&input_bf16, &ids_i64, &shape, 0, "sa_i64_bf16"), + ] { + assert_eq!(results, output_bf16); + } + } +} |