summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--  candle-core/src/metal_backend.rs          |  22
-rw-r--r--  candle-metal-kernels/src/indexing.metal   |  38
-rw-r--r--  candle-metal-kernels/src/tests.rs         | 116
3 files changed, 160 insertions(+), 16 deletions(-)
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index a6513b1c..3bee7657 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -1242,9 +1242,29 @@ impl BackendStorage for MetalStorage {
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
};
let name = match (ids.dtype, self.dtype) {
+ (DType::I64, DType::BF16) => "ia_i64_bf16",
+ (DType::I64, DType::F16) => "ia_i64_f16",
+ (DType::I64, DType::F32) => "ia_i64_f32",
+ (DType::I64, DType::I64) => "ia_i64_i64",
+ (DType::I64, DType::U32) => "ia_i64_u32",
+ (DType::I64, DType::U8) => "ia_i64_u8",
+
+ (DType::U32, DType::BF16) => "ia_u32_bf16",
+ (DType::U32, DType::F16) => "ia_u32_f16",
(DType::U32, DType::F32) => "ia_u32_f32",
+ (DType::U32, DType::I64) => "ia_u32_i64",
+ (DType::U32, DType::U32) => "ia_u32_u32",
+ (DType::U32, DType::U8) => "ia_u32_u8",
+
+ (DType::U8, DType::BF16) => "ia_u8_bf16",
+ (DType::U8, DType::F16) => "ia_u8_f16",
+ (DType::U8, DType::F32) => "ia_u8_f32",
+ (DType::U8, DType::I64) => "ia_u8_i64",
+ (DType::U8, DType::U32) => "ia_u8_u32",
+ (DType::U8, DType::U8) => "ia_u8_u8",
+
_ => Err(MetalError::UnexpectedDType {
- msg: "index-add ids should be u32",
+ msg: "index-add ids should be u8/u32/i64",
expected: DType::U32,
got: ids.dtype(),
})?,
diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal
index f6b81be0..65491759 100644
--- a/candle-metal-kernels/src/indexing.metal
+++ b/candle-metal-kernels/src/indexing.metal
@@ -167,6 +167,10 @@ kernel void NAME( \
INDEX_OP(is_u32_f32, uint, float)
INDEX_OP(is_u32_f16, uint, half)
+#if defined(__HAVE_BFLOAT__)
+INDEX_OP(is_u32_bf16, uint32_t, bfloat)
+INDEX_OP(is_u8_bf16, uint8_t, bfloat)
+#endif
GATHER_OP(gather_u32_f32, uint, float)
GATHER_OP(gather_u32_f16, uint, half)
@@ -177,34 +181,38 @@ SCATTER_ADD_OP(sa_i64_f32, int64_t, float)
SCATTER_ADD_OP(sa_u32_f16, uint32_t, half)
SCATTER_ADD_OP(sa_u8_f16, uint8_t, half)
SCATTER_ADD_OP(sa_i64_f16, int64_t, half)
-
#if defined(__HAVE_BFLOAT__)
-INDEX_OP(is_u32_bf16, uint32_t, bfloat)
-INDEX_OP(is_u8_bf16, uint8_t, bfloat)
-
-INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
-INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
-INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
-
SCATTER_ADD_OP(sa_u32_bf16, uint32_t, bfloat)
SCATTER_ADD_OP(sa_u8_bf16, uint8_t, bfloat)
SCATTER_ADD_OP(sa_i64_bf16, int64_t, bfloat)
#endif
-INDEX_ADD_OP(ia_u32_f16, uint32_t, half)
-INDEX_ADD_OP(ia_u8_f16, uint8_t, half)
-
+// i64
+INDEX_ADD_OP(ia_i64_f16, int64_t, half)
INDEX_ADD_OP(ia_i64_f32, int64_t, float)
-INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t)
INDEX_ADD_OP(ia_i64_i64, int64_t, int64_t)
INDEX_ADD_OP(ia_i64_u32, int64_t, uint32_t)
+INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t)
+#if defined(__HAVE_BFLOAT__)
+INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
+#endif
+// u32
+INDEX_ADD_OP(ia_u32_f16, uint32_t, half)
INDEX_ADD_OP(ia_u32_f32, uint32_t, float)
-INDEX_ADD_OP(ia_u32_u8, uint32_t, uint8_t)
INDEX_ADD_OP(ia_u32_i64, uint32_t, int64_t)
INDEX_ADD_OP(ia_u32_u32, uint32_t, uint32_t)
+INDEX_ADD_OP(ia_u32_u8, uint32_t, uint8_t)
+#if defined(__HAVE_BFLOAT__)
+INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
+#endif
+// u8
+INDEX_ADD_OP(ia_u8_f16, uint8_t, half)
INDEX_ADD_OP(ia_u8_f32, uint8_t, float)
-INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t)
-INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t)
INDEX_ADD_OP(ia_u8_i64, uint8_t, int64_t)
+INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t)
+INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t)
+#if defined(__HAVE_BFLOAT__)
+INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
+#endif
\ No newline at end of file
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index b2f1d723..a34882d3 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -1252,3 +1252,119 @@ fn scatter_add() {
}
}
}
+
+fn run_index_add<T: Clone, I: Clone + std::fmt::Debug>(
+ left: &[T],
+ right: &[T],
+ indices: &[I],
+ shape: &[usize],
+ dim: usize,
+ name: &'static str,
+) -> Vec<T> {
+ let device = device();
+ let kernels = Kernels::new();
+ let command_queue = device.new_command_queue();
+ let command_buffer = command_queue.new_command_buffer();
+ let input_buffer = new_buffer(&device, right);
+ let output = new_buffer(&device, left);
+ let indices_buffer = new_buffer(&device, indices);
+ call_index_add(
+ &device,
+ command_buffer,
+ &kernels,
+ name,
+ shape,
+ shape,
+ shape,
+ dim,
+ &input_buffer,
+ 0,
+ &indices_buffer,
+ 0,
+ &output,
+ )
+ .unwrap();
+ command_buffer.commit();
+ command_buffer.wait_until_completed();
+ read_to_vec(&output, left.len())
+}
+
+#[test]
+fn index_add() {
+ let left = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
+ let right = vec![1.0f32, 1.0, 1.0, 1.0, 1.0, 1.0];
+ let indices = vec![0u32, 1, 0, 1, 0, 1];
+ let shape = vec![6];
+
+ // u32, f32
+ {
+ let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u32_f32");
+ assert_eq!(results, vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);
+ }
+
+ // u32, f16
+ {
+ let left = left.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
+ let right = right.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
+ let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u32_f16");
+ assert_eq!(approx_f16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);
+ }
+
+ // u32, bf16
+ {
+ let left = left.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
+ let right = right.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
+ let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u32_bf16");
+ assert_eq!(approx_bf16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);
+ }
+
+ // u8, f32
+ {
+ let indices = indices.iter().map(|v| *v as u8).collect::<Vec<_>>();
+ let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u8_f32");
+ assert_eq!(results, vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);
+ }
+
+ // u8, f16
+ {
+ let indices = indices.iter().map(|v| *v as u8).collect::<Vec<_>>();
+ let left = left.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
+ let right = right.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
+ let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u8_f16");
+ assert_eq!(approx_f16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);
+ }
+
+ // u8, bf16
+ {
+ let indices = indices.iter().map(|v| *v as u8).collect::<Vec<_>>();
+ let left = left.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
+ let right = right.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
+ let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u8_bf16");
+ assert_eq!(approx_bf16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);
+ }
+
+ // i64, f32
+ {
+ let indices = indices.iter().map(|v| *v as i64).collect::<Vec<_>>();
+ let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_i64_f32");
+ assert_eq!(results, vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);
+ }
+
+ // i64, f16
+ {
+ let indices = indices.iter().map(|v| *v as i64).collect::<Vec<_>>();
+ let left = left.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
+ let right = right.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
+ let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_i64_f16");
+ assert_eq!(approx_f16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);
+ }
+
+ // i64, bf16
+ {
+ let indices = indices.iter().map(|v| *v as i64).collect::<Vec<_>>();
+ let left = left.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
+ let right = right.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
+ let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_i64_bf16");
+ assert_eq!(approx_bf16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]);
+ }
+}