summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/metal_backend.rs4
-rw-r--r--candle-core/tests/tensor_tests.rs79
-rw-r--r--candle-metal-kernels/src/indexing.metal8
3 files changed, 53 insertions, 38 deletions
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index b9e761f6..fed7db13 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -1391,6 +1391,10 @@ impl BackendStorage for MetalStorage {
(DType::U32, DType::F16) => "is_u32_f16",
(DType::U32, DType::BF16) => "is_u32_bf16",
+ (DType::I64, DType::F32) => "is_i64_f32",
+ (DType::I64, DType::F16) => "is_i64_f16",
+ (DType::I64, DType::BF16) => "is_i64_bf16",
+
(left, right) => {
crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented")
}
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index af28c1c1..8aacc05d 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -707,6 +707,8 @@ fn embeddings(device: &Device) -> Result<()> {
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
let hs = t.index_select(&ids, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
+ let hs = t.index_select(&ids.to_dtype(DType::I64)?, 0)?;
+ assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
Ok(())
}
@@ -734,44 +736,47 @@ fn index_select(device: &Device) -> Result<()> {
[9.0, 10.0, 11.0]
]
);
- let hs = t.index_select(&ids, 1)?;
- assert_eq!(
- hs.to_vec2::<f32>()?,
- &[
- [0.0, 2.0, 1.0],
- [3.0, 5.0, 4.0],
- [6.0, 8.0, 7.0],
- [9.0, 11.0, 10.0]
- ]
- );
- let hs = t.index_select(&ids, 0)?;
- assert_eq!(
- hs.to_vec2::<f32>()?,
- &[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
- );
- // Prior to https://github.com/huggingface/candle/pull/1022
- // There would be a bug where the last values in the result tensor would be set to 0.
- let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
- let hs = t.index_select(&ids, 0)?;
- assert_eq!(
- hs.to_vec2::<f32>()?,
- &[
- [0.0, 1.0, 2.0],
- [6.0, 7.0, 8.0],
- [3.0, 4.0, 5.0],
- [0.0, 1.0, 2.0],
- [6.0, 7.0, 8.0],
- [3.0, 4.0, 5.0],
- ]
- );
+ for dtype in [DType::U8, DType::U32, DType::I64] {
+ let ids = ids.to_dtype(dtype)?;
+ let hs = t.index_select(&ids, 1)?;
+ assert_eq!(
+ hs.to_vec2::<f32>()?,
+ &[
+ [0.0, 2.0, 1.0],
+ [3.0, 5.0, 4.0],
+ [6.0, 8.0, 7.0],
+ [9.0, 11.0, 10.0]
+ ]
+ );
+ let hs = t.index_select(&ids, 0)?;
+ assert_eq!(
+ hs.to_vec2::<f32>()?,
+ &[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
+ );
+ // Prior to https://github.com/huggingface/candle/pull/1022
+ // There would be a bug where the last values in the result tensor would be set to 0.
+ let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
+ let hs = t.index_select(&ids, 0)?;
+ assert_eq!(
+ hs.to_vec2::<f32>()?,
+ &[
+ [0.0, 1.0, 2.0],
+ [6.0, 7.0, 8.0],
+ [3.0, 4.0, 5.0],
+ [0.0, 1.0, 2.0],
+ [6.0, 7.0, 8.0],
+ [3.0, 4.0, 5.0],
+ ]
+ );
- // Test when selecting dim > 0 with ids size different from elem count of
- // target dim in source/input.
- let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
- let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
- assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
- let hs = t.index_select(&ids, 1)?;
- assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
+ // Test when selecting dim > 0 with ids size different from elem count of
+ // target dim in source/input.
+ let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
+ let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
+ assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
+ let hs = t.index_select(&ids, 1)?;
+ assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
+ }
Ok(())
}
diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal
index ad4a8605..762b42be 100644
--- a/candle-metal-kernels/src/indexing.metal
+++ b/candle-metal-kernels/src/indexing.metal
@@ -187,6 +187,12 @@ kernel void NAME( \
}
+INDEX_OP(is_i64_f32, int64_t, float)
+INDEX_OP(is_i64_f16, int64_t, half)
+#if defined(__HAVE_BFLOAT__)
+INDEX_OP(is_i64_bf16, int64_t, bfloat)
+#endif
+
INDEX_OP(is_u32_f32, uint32_t, float)
INDEX_OP(is_u32_f16, uint32_t, half)
#if defined(__HAVE_BFLOAT__)
@@ -242,4 +248,4 @@ INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t)
INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t)
#if defined(__HAVE_BFLOAT__)
INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
-#endif \ No newline at end of file
+#endif