summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicolas Patry <patry.nicolas@protonmail.com>2023-12-18 10:46:01 +0100
committerNicolas Patry <patry.nicolas@protonmail.com>2023-12-18 10:46:01 +0100
commit8bd3d6b94bb2449d056d38d5a42c8e9c762f2d7e (patch)
tree57c07c91d8924e26c867883d39cf26eebd535b97
parent6a3ca7da0cfb06e80d5c75ee98a1291843092e06 (diff)
downloadcandle-8bd3d6b94bb2449d056d38d5a42c8e9c762f2d7e.tar.gz
candle-8bd3d6b94bb2449d056d38d5a42c8e9c762f2d7e.tar.bz2
candle-8bd3d6b94bb2449d056d38d5a42c8e9c762f2d7e.zip
Index add.
-rw-r--r--candle-core/src/metal_backend.rs49
-rw-r--r--candle-metal-kernels/src/indexing.metal111
-rw-r--r--candle-metal-kernels/src/lib.rs54
3 files changed, 151 insertions, 63 deletions
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index b26477fc..21a8967b 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -951,14 +951,49 @@ impl BackendStorage for MetalStorage {
fn index_add(
&self,
- _: &Layout,
- _: &Self,
- _: &Layout,
- _: &Self,
- _: &Layout,
- _: usize,
+ l: &Layout,
+ ids: &Self,
+ ids_l: &Layout,
+ src: &Self,
+ src_l: &Layout,
+ dim: usize,
) -> Result<Self> {
- crate::bail!("index_add metal")
+ let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
+ self.copy_strided_src(&mut acc, 0, l)?;
+ let (ids_offset, _) = match ids_l.contiguous_offsets() {
+ Some(o12) => o12,
+ None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
+ };
+ let src_offset = match src_l.contiguous_offsets() {
+ Some((o1, _)) => o1,
+ None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
+ };
+ let name = match (ids.dtype, self.dtype) {
+ (DType::U32, DType::F32) => "ia_u32_f32",
+ _ => Err(MetalError::UnexpectedDType {
+ msg: "index-add ids should be u8/u32/i64",
+ expected: DType::U32,
+ got: ids.dtype(),
+ })?,
+ };
+ let command_buffer = self.device.command_buffer()?;
+ candle_metal_kernels::call_index_add(
+ &self.device.device,
+ &command_buffer,
+ &self.device.kernels,
+ name,
+ src_l.dims(),
+ l.dims(),
+ ids_l.dims(),
+ dim,
+ &src.buffer,
+ src_offset * src.dtype.size_in_bytes(),
+ &ids.buffer,
+ ids_offset * ids.dtype.size_in_bytes(),
+ &acc.buffer,
+ )
+ .map_err(MetalError::from)?;
+ Ok(acc)
}
fn matmul(
&self,
diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal
index 72a3a348..63357428 100644
--- a/candle-metal-kernels/src/indexing.metal
+++ b/candle-metal-kernels/src/indexing.metal
@@ -122,48 +122,47 @@ kernel void NAME( \
scatter_add<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, dst_dim_size, input, input_ids, output, tid); \
}
-
-template <typename T, typename I>
-void index_add(
- device I *ids [[buffer(0)]],
- device T *inp [[buffer(1)]],
- device T *out [[buffer(2)]],
-
- constant uint &ids_dim_size,
- constant uint &left_size,
- constant uint &dst_dim_size,
- constant uint &right_size,
-
- uint gid [[ thread_position_in_grid ]] \
-) {
-
- if (gid >= left_size * right_size) {
- return;
- }
-
- const uint i = gid;
- const uint pre = i / right_size;
- const uint post = i % right_size;
-
- for (uint j = 0; j < ids_dim_size; j++) {
- const uint idx = ids[j];
- const uint src_i = (pre * ids_dim_size + j) * right_size + post;
- const uint dst_i = (pre * dst_dim_size + idx) * right_size + post;
- out[dst_i] += inp[src_i];
+template<typename TYPENAME, typename INDEX_TYPENAME>
+METAL_FUNC void index_add(
+ constant size_t &dst_size,
+ constant size_t &left_size,
+ constant size_t &src_dim_size,
+ constant size_t &right_size,
+ constant size_t &dst_dim_size,
+ constant size_t &ids_dim_size,
+ const device TYPENAME *input,
+ const device INDEX_TYPENAME *input_ids,
+ device TYPENAME *output,
+ uint tid [[ thread_position_in_grid ]]
+) {
+ if (tid >= dst_size) {
+ return;
+ }
+ const size_t right_rank_i = tid % right_size;
+ const size_t left_rank_i = tid / right_size;
+ for (unsigned int j = 0; j < ids_dim_size; ++j) {
+ const INDEX_TYPENAME idx = input_ids[j];
+ const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
+ const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
+ output[dst_i] += input[src_i];
}
}
-#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
-kernel void FN_NAME( \
- device INDEX_TYPENAME *ids [[buffer(0)]], \
- device TYPENAME *inp [[buffer(1)]], \
- device TYPENAME *out [[buffer(2)]], \
- constant uint &ids_dim_size, \
- constant uint &left_size, \
- constant uint &dst_dim_size, \
- constant uint &right_size, \
- uint gid [[ thread_position_in_grid ]] \
-) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, gid); } \
+# define INDEX_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \
+kernel void NAME( \
+ constant size_t &dst_size, \
+ constant size_t &left_size, \
+ constant size_t &src_dim_size, \
+ constant size_t &right_size, \
+ constant size_t &dst_dim_size, \
+ constant size_t &ids_dim_size, \
+ const device TYPENAME *input, \
+ const device INDEX_TYPENAME *input_ids, \
+ device TYPENAME *output, \
+ uint tid [[ thread_position_in_grid ]] \
+) { \
+ index_add<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, dst_dim_size, ids_dim_size, input, input_ids, output, tid); \
+}
INDEX_OP(is_u32_f32, uint, float)
@@ -175,25 +174,25 @@ SCATTER_ADD_OP(sa_u32_f16, uint, half)
#if __METAL_VERSION__ >= 310
-IA_OP(bfloat, int64_t, ia_i64_bf16)
-IA_OP(bfloat, uint32_t, ia_u32_bf16)
-IA_OP(bfloat, uint8_t, ia_u8_bf16)
+INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
+INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
+INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
#endif
-IA_OP(half, uint32_t, ia_u32_f16)
-IA_OP(half, uint8_t, ia_u8_f16)
+INDEX_ADD_OP(ia_u32_f16, uint32_t, half)
+INDEX_ADD_OP(ia_u8_f16, uint8_t, half)
-IA_OP(float, int64_t, ia_i64_f32)
-IA_OP(uint8_t, int64_t, ia_i64_u8)
-IA_OP(int64_t, int64_t, ia_i64_i64)
-IA_OP(uint32_t, int64_t, ia_i64_u32)
+INDEX_ADD_OP(ia_i64_f32, int64_t, float)
+INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t)
+INDEX_ADD_OP(ia_i64_i64, int64_t, int64_t)
+INDEX_ADD_OP(ia_i64_u32, int64_t, uint32_t)
-IA_OP(float, uint32_t, ia_u32_f32)
-IA_OP(uint8_t, uint32_t, ia_u32_u8)
-IA_OP(int64_t, uint32_t, ia_u32_i64)
-IA_OP(uint32_t, uint32_t, ia_u32_u32)
+INDEX_ADD_OP(ia_u32_f32, uint32_t, float)
+INDEX_ADD_OP(ia_u32_u8, uint32_t, uint8_t)
+INDEX_ADD_OP(ia_u32_i64, uint32_t, int64_t)
+INDEX_ADD_OP(ia_u32_u32, uint32_t, uint32_t)
-IA_OP(float, uint8_t, ia_u8_f32)
-IA_OP(uint8_t, uint8_t, ia_u8_u8)
-IA_OP(uint32_t, uint8_t, ia_u8_u32)
-IA_OP(int64_t, uint8_t, ia_u8_i64)
+INDEX_ADD_OP(ia_u8_f32, uint8_t, float)
+INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t)
+INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t)
+INDEX_ADD_OP(ia_u8_i64, uint8_t, int64_t)
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index ddc04d05..0bd7d8cb 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1114,6 +1114,60 @@ pub fn call_scatter_add(
Ok(())
}
+pub fn call_index_add(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: &'static str,
+ src_shape: &[usize],
+ dst_shape: &[usize],
+ ids_shape: &[usize],
+ dim: usize,
+ input: &Buffer,
+ input_offset: usize,
+ ids: &Buffer,
+ ids_offset: usize,
+ output: &Buffer,
+) -> Result<(), MetalKernelError> {
+ let left_size: usize = src_shape[..dim].iter().product();
+ let right_size: usize = src_shape[dim + 1..].iter().product();
+ let src_dim_size = src_shape[dim];
+ let dst_el = left_size * right_size;
+ let dst_dim_size = dst_shape[dim];
+ let ids_dim_size = ids_shape[0];
+
+ let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
+ let encoder = command_buffer.new_compute_command_encoder();
+
+ encoder.wait_for_fence(&kernels.fence);
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(
+ encoder,
+ (
+ dst_el,
+ left_size,
+ src_dim_size,
+ right_size,
+ dst_dim_size,
+ ids_dim_size,
+ (input, input_offset),
+ (ids, ids_offset),
+ output
+ )
+ );
+
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
+
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(ids, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.update_fence(&kernels.fence);
+ encoder.end_encoding();
+ Ok(())
+}
+
#[derive(Debug, PartialEq)]
pub enum Value {
USize(usize),