diff --stat
 candle-core/src/metal_backend.rs        |  22 +++++++++-
 candle-metal-kernels/src/indexing.metal |  38 ++++++++-------
 candle-metal-kernels/src/tests.rs       | 116 ++++++++++++++++++++
 3 files changed, 160 insertions(+), 16 deletions(-)
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index a6513b1c..3bee7657 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -1242,9 +1242,29 @@ impl BackendStorage for MetalStorage { None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?, }; let name = match (ids.dtype, self.dtype) { + (DType::I64, DType::BF16) => "ia_i64_bf16", + (DType::I64, DType::F16) => "ia_i64_f16", + (DType::I64, DType::F32) => "ia_i64_f32", + (DType::I64, DType::I64) => "ia_i64_i64", + (DType::I64, DType::U32) => "ia_i64_u32", + (DType::I64, DType::U8) => "ia_i64_u8", + + (DType::U32, DType::BF16) => "ia_u32_bf16", + (DType::U32, DType::F16) => "ia_u32_f16", (DType::U32, DType::F32) => "ia_u32_f32", + (DType::U32, DType::I64) => "ia_u32_i64", + (DType::U32, DType::U32) => "ia_u32_u32", + (DType::U32, DType::U8) => "ia_u32_u8", + + (DType::U8, DType::BF16) => "ia_u8_bf16", + (DType::U8, DType::F16) => "ia_u8_f16", + (DType::U8, DType::F32) => "ia_u8_f32", + (DType::U8, DType::I64) => "ia_u8_i64", + (DType::U8, DType::U32) => "ia_u8_u32", + (DType::U8, DType::U8) => "ia_u8_u8", + _ => Err(MetalError::UnexpectedDType { - msg: "index-add ids should be u32", + msg: "index-add ids should be u8/u32/i64", expected: DType::U32, got: ids.dtype(), })?, diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index f6b81be0..65491759 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -167,6 +167,10 @@ kernel void NAME( \ INDEX_OP(is_u32_f32, uint, float) INDEX_OP(is_u32_f16, uint, half) +#if defined(__HAVE_BFLOAT__) +INDEX_OP(is_u32_bf16, uint32_t, bfloat) +INDEX_OP(is_u8_bf16, uint8_t, bfloat) +#endif GATHER_OP(gather_u32_f32, uint, float) GATHER_OP(gather_u32_f16, uint, half) @@ -177,34 +181,38 @@ SCATTER_ADD_OP(sa_i64_f32, int64_t, float) SCATTER_ADD_OP(sa_u32_f16, uint32_t, half) SCATTER_ADD_OP(sa_u8_f16, uint8_t, half) 
SCATTER_ADD_OP(sa_i64_f16, int64_t, half) - #if defined(__HAVE_BFLOAT__) -INDEX_OP(is_u32_bf16, uint32_t, bfloat) -INDEX_OP(is_u8_bf16, uint8_t, bfloat) - -INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat) -INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat) -INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat) - SCATTER_ADD_OP(sa_u32_bf16, uint32_t, bfloat) SCATTER_ADD_OP(sa_u8_bf16, uint8_t, bfloat) SCATTER_ADD_OP(sa_i64_bf16, int64_t, bfloat) #endif -INDEX_ADD_OP(ia_u32_f16, uint32_t, half) -INDEX_ADD_OP(ia_u8_f16, uint8_t, half) - +// i64 +INDEX_ADD_OP(ia_i64_f16, int64_t, half) INDEX_ADD_OP(ia_i64_f32, int64_t, float) -INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t) INDEX_ADD_OP(ia_i64_i64, int64_t, int64_t) INDEX_ADD_OP(ia_i64_u32, int64_t, uint32_t) +INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t) +#if defined(__HAVE_BFLOAT__) +INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat) +#endif +// u32 +INDEX_ADD_OP(ia_u32_f16, uint32_t, half) INDEX_ADD_OP(ia_u32_f32, uint32_t, float) -INDEX_ADD_OP(ia_u32_u8, uint32_t, uint8_t) INDEX_ADD_OP(ia_u32_i64, uint32_t, int64_t) INDEX_ADD_OP(ia_u32_u32, uint32_t, uint32_t) +INDEX_ADD_OP(ia_u32_u8, uint32_t, uint8_t) +#if defined(__HAVE_BFLOAT__) +INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat) +#endif +// u8 +INDEX_ADD_OP(ia_u8_f16, uint8_t, half) INDEX_ADD_OP(ia_u8_f32, uint8_t, float) -INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t) -INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t) INDEX_ADD_OP(ia_u8_i64, uint8_t, int64_t) +INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t) +INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t) +#if defined(__HAVE_BFLOAT__) +INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat) +#endif
\ No newline at end of file diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index b2f1d723..a34882d3 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -1252,3 +1252,119 @@ fn scatter_add() { } } } + +fn run_index_add<T: Clone, I: Clone + std::fmt::Debug>( + left: &[T], + right: &[T], + indices: &[I], + shape: &[usize], + dim: usize, + name: &'static str, +) -> Vec<T> { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let input_buffer = new_buffer(&device, right); + let output = new_buffer(&device, left); + let indices_buffer = new_buffer(&device, indices); + call_index_add( + &device, + command_buffer, + &kernels, + name, + shape, + shape, + shape, + dim, + &input_buffer, + 0, + &indices_buffer, + 0, + &output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + read_to_vec(&output, left.len()) +} + +#[test] +fn index_add() { + let left = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let right = vec![1.0f32, 1.0, 1.0, 1.0, 1.0, 1.0]; + let indices = vec![0u32, 1, 0, 1, 0, 1]; + let shape = vec![6]; + + // u32, f32 + { + let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u32_f32"); + assert_eq!(results, vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]); + } + + // u32, f16 + { + let left = left.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>(); + let right = right.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>(); + let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u32_f16"); + assert_eq!(approx_f16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]); + } + + // u32, bf16 + { + let left = left.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>(); + let right = right.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>(); + let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u32_bf16"); + 
assert_eq!(approx_bf16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]); + } + + // u8, f32 + { + let indices = indices.iter().map(|v| *v as u8).collect::<Vec<_>>(); + let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u8_f32"); + assert_eq!(results, vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]); + } + + // u8, f16 + { + let indices = indices.iter().map(|v| *v as u8).collect::<Vec<_>>(); + let left = left.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>(); + let right = right.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>(); + let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u8_f16"); + assert_eq!(approx_f16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]); + } + + // u8, bf16 + { + let indices = indices.iter().map(|v| *v as u8).collect::<Vec<_>>(); + let left = left.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>(); + let right = right.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>(); + let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_u8_bf16"); + assert_eq!(approx_bf16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]); + } + + // i64, f32 + { + let indices = indices.iter().map(|v| *v as i64).collect::<Vec<_>>(); + let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_i64_f32"); + assert_eq!(results, vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]); + } + + // i64, f16 + { + let indices = indices.iter().map(|v| *v as i64).collect::<Vec<_>>(); + let left = left.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>(); + let right = right.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>(); + let results = run_index_add(&left, &right, &indices, &shape, 0, "ia_i64_f16"); + assert_eq!(approx_f16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]); + } + + // i64, bf16 + { + let indices = indices.iter().map(|v| *v as i64).collect::<Vec<_>>(); + let left = left.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>(); + let right = right.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>(); + let results = 
run_index_add(&left, &right, &indices, &shape, 0, "ia_i64_bf16"); + assert_eq!(approx_bf16(results, 4), vec![4.0, 5.0, 3.0, 4.0, 5.0, 6.0]); + } +}