summary refs log tree commit diff
path: root/candle-core/src
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src')
-rw-r--r-- candle-core/src/metal_backend.rs 49
1 file changed, 42 insertions, 7 deletions
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index b26477fc..21a8967b 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -951,14 +951,49 @@ impl BackendStorage for MetalStorage {
fn index_add(
&self,
- _: &Layout,
- _: &Self,
- _: &Layout,
- _: &Self,
- _: &Layout,
- _: usize,
+ l: &Layout,
+ ids: &Self,
+ ids_l: &Layout,
+ src: &Self,
+ src_l: &Layout,
+ dim: usize,
) -> Result<Self> {
- crate::bail!("index_add metal")
+ let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
+ self.copy_strided_src(&mut acc, 0, l)?;
+ let (ids_offset, _) = match ids_l.contiguous_offsets() {
+ Some(o12) => o12,
+ None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
+ };
+ let src_offset = match src_l.contiguous_offsets() {
+ Some((o1, _)) => o1,
+ None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
+ };
+ let name = match (ids.dtype, self.dtype) {
+ (DType::U32, DType::F32) => "ia_u32_f32",
+ _ => Err(MetalError::UnexpectedDType {
+ msg: "index-add ids should be u8/u32/i64",
+ expected: DType::U32,
+ got: ids.dtype(),
+ })?,
+ };
+ let command_buffer = self.device.command_buffer()?;
+ candle_metal_kernels::call_index_add(
+ &self.device.device,
+ &command_buffer,
+ &self.device.kernels,
+ name,
+ src_l.dims(),
+ l.dims(),
+ ids_l.dims(),
+ dim,
+ &src.buffer,
+ src_offset * src.dtype.size_in_bytes(),
+ &ids.buffer,
+ ids_offset * ids.dtype.size_in_bytes(),
+ &acc.buffer,
+ )
+ .map_err(MetalError::from)?;
+ Ok(acc)
}
fn matmul(
&self,