summary | refs | log | tree | commit | diff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
author Thomas Santerre <thomas@santerre.xyz> 2024-03-17 03:09:43 -0400
committer GitHub <noreply@github.com> 2024-03-17 08:09:43 +0100
commit db8b24ae92419377283821ee0a65fb224a4f3c4d (patch)
tree 95c6481d38f32b628192503ca9a250d033ebf38c /candle-metal-kernels
parent 74bf6994b172f364c6e8bea2ac6e1bfbc6ca0c25 (diff)
download candle-db8b24ae92419377283821ee0a65fb224a4f3c4d.tar.gz
candle-db8b24ae92419377283821ee0a65fb224a4f3c4d.tar.bz2
candle-db8b24ae92419377283821ee0a65fb224a4f3c4d.zip
Add support for index u8/i64 and input f16/bf16 scatter-add on metal (#1849)
* add support and tests for scatter add on metal
* add support for all datatypes
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- candle-metal-kernels/src/indexing.metal 13
-rw-r--r-- candle-metal-kernels/src/tests.rs 104
2 files changed, 115 insertions, 2 deletions
diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal
index 2a57bdbb..f6b81be0 100644
--- a/candle-metal-kernels/src/indexing.metal
+++ b/candle-metal-kernels/src/indexing.metal
@@ -167,11 +167,16 @@ kernel void NAME( \
INDEX_OP(is_u32_f32, uint, float)
INDEX_OP(is_u32_f16, uint, half)
+
GATHER_OP(gather_u32_f32, uint, float)
GATHER_OP(gather_u32_f16, uint, half)
-SCATTER_ADD_OP(sa_u32_f32, uint, float)
-SCATTER_ADD_OP(sa_u32_f16, uint, half)
+SCATTER_ADD_OP(sa_u32_f32, uint32_t, float)
+SCATTER_ADD_OP(sa_u8_f32, uint8_t, float)
+SCATTER_ADD_OP(sa_i64_f32, int64_t, float)
+SCATTER_ADD_OP(sa_u32_f16, uint32_t, half)
+SCATTER_ADD_OP(sa_u8_f16, uint8_t, half)
+SCATTER_ADD_OP(sa_i64_f16, int64_t, half)
#if defined(__HAVE_BFLOAT__)
INDEX_OP(is_u32_bf16, uint32_t, bfloat)
@@ -180,6 +185,10 @@ INDEX_OP(is_u8_bf16, uint8_t, bfloat)
INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
+
+SCATTER_ADD_OP(sa_u32_bf16, uint32_t, bfloat)
+SCATTER_ADD_OP(sa_u8_bf16, uint8_t, bfloat)
+SCATTER_ADD_OP(sa_i64_bf16, int64_t, bfloat)
#endif
INDEX_ADD_OP(ia_u32_f16, uint32_t, half)
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 459c8edb..b47fff6a 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -1066,3 +1066,107 @@ fn random() {
validate_random!(f16);
validate_random!(bf16);
}
+
+fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>(
+ input: &[T],
+ ids: &[I],
+ shape: &[usize],
+ dim: usize,
+ name: &'static str,
+) -> Vec<T> {
+ let device = device();
+ let kernels = Kernels::new();
+ let command_queue = device.new_command_queue();
+ let command_buffer = command_queue.new_command_buffer();
+ let options = MTLResourceOptions::StorageModeManaged;
+ let input_buffer = new_buffer(&device, input);
+ let ids_buffer = new_buffer(&device, ids);
+ let output = device.new_buffer(std::mem::size_of_val(input) as u64, options);
+ call_scatter_add(
+ &device,
+ command_buffer,
+ &kernels,
+ name,
+ shape,
+ shape,
+ dim,
+ &input_buffer,
+ 0,
+ &ids_buffer,
+ 0,
+ &output,
+ )
+ .unwrap();
+ command_buffer.commit();
+ command_buffer.wait_until_completed();
+ read_to_vec(&output, input.len())
+}
+
+#[test]
+fn scatter_add() {
+ let ids_u8 = [0u8, 0, 1, 0, 2, 2, 3, 3];
+ let ids_u32 = [0u32, 0, 1, 0, 2, 2, 3, 3];
+ let ids_i64 = [0i64, 0, 1, 0, 2, 2, 3, 3];
+
+ let input_f32 = [5.0f32, 1.0, 7.0, 2.0, 3.0, 2.0, 1.0, 3.0];
+ let input_f16 = input_f32
+ .iter()
+ .map(|v| f16::from_f32(*v))
+ .collect::<Vec<_>>();
+ let input_bf16 = input_f32
+ .iter()
+ .map(|v| bf16::from_f32(*v))
+ .collect::<Vec<_>>();
+
+ let output_dim1_f32 = vec![8.0, 7.0, 5.0, 4.0, 0.0, 0.0, 0.0, 0.0];
+ let output_dim1_f16 = output_dim1_f32
+ .iter()
+ .map(|v| f16::from_f32(*v))
+ .collect::<Vec<_>>();
+ let output_dim1_bf16 = output_dim1_f32
+ .iter()
+ .map(|v| bf16::from_f32(*v))
+ .collect::<Vec<_>>();
+
+ let output_dim2_f32 = vec![5.0, 3.0, 7.0, 0.0, 3.0, 2.0, 1.0, 3.0];
+ let output_dim2_f16 = output_dim2_f32
+ .iter()
+ .map(|v| f16::from_f32(*v))
+ .collect::<Vec<_>>();
+ let output_dim2_bf16 = output_dim2_f32
+ .iter()
+ .map(|v| bf16::from_f32(*v))
+ .collect::<Vec<_>>();
+
+ for (shape, output_f32, output_f16, output_bf16) in [
+ (vec![8], output_dim1_f32, output_dim1_f16, output_dim1_bf16),
+ (
+ vec![4, 2],
+ output_dim2_f32,
+ output_dim2_f16,
+ output_dim2_bf16,
+ ),
+ ] {
+ for results in [
+ run_scatter_add(&input_f32, &ids_u8, &shape, 0, "sa_u8_f32"),
+ run_scatter_add(&input_f32, &ids_u32, &shape, 0, "sa_u32_f32"),
+ run_scatter_add(&input_f32, &ids_i64, &shape, 0, "sa_i64_f32"),
+ ] {
+ assert_eq!(results, output_f32);
+ }
+ for results in [
+ run_scatter_add(&input_f16, &ids_u8, &shape, 0, "sa_u8_f16"),
+ run_scatter_add(&input_f16, &ids_u32, &shape, 0, "sa_u32_f16"),
+ run_scatter_add(&input_f16, &ids_i64, &shape, 0, "sa_i64_f16"),
+ ] {
+ assert_eq!(results, output_f16);
+ }
+ for results in [
+ run_scatter_add(&input_bf16, &ids_u8, &shape, 0, "sa_u8_bf16"),
+ run_scatter_add(&input_bf16, &ids_u32, &shape, 0, "sa_u32_bf16"),
+ run_scatter_add(&input_bf16, &ids_i64, &shape, 0, "sa_i64_bf16"),
+ ] {
+ assert_eq!(results, output_bf16);
+ }
+ }
+}