summaryrefslogtreecommitdiff
path: root/candle-core/tests/tensor_tests.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-03-27 16:30:07 +0100
committerGitHub <noreply@github.com>2024-03-27 16:30:07 +0100
commitab86cd37c8fd944df351d8c7ca0e93376634a332 (patch)
tree15f9c8809c1c7c7bd449ab44fefcca24a92fb295 /candle-core/tests/tensor_tests.rs
parenta9abde5f930914ef7ef2d504728f742f80468961 (diff)
downloadcandle-ab86cd37c8fd944df351d8c7ca0e93376634a332.tar.gz
candle-ab86cd37c8fd944df351d8c7ca0e93376634a332.tar.bz2
candle-ab86cd37c8fd944df351d8c7ca0e93376634a332.zip
Support i64 in index-select on metal. (#1951)
* Support i64 in index-select on metal. * Add some testing of index-select for all dtypes.
Diffstat (limited to 'candle-core/tests/tensor_tests.rs')
-rw-r--r--candle-core/tests/tensor_tests.rs79
1 files changed, 42 insertions, 37 deletions
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index af28c1c1..8aacc05d 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -707,6 +707,8 @@ fn embeddings(device: &Device) -> Result<()> {
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
let hs = t.index_select(&ids, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
+ let hs = t.index_select(&ids.to_dtype(DType::I64)?, 0)?;
+ assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
Ok(())
}
@@ -734,44 +736,47 @@ fn index_select(device: &Device) -> Result<()> {
[9.0, 10.0, 11.0]
]
);
- let hs = t.index_select(&ids, 1)?;
- assert_eq!(
- hs.to_vec2::<f32>()?,
- &[
- [0.0, 2.0, 1.0],
- [3.0, 5.0, 4.0],
- [6.0, 8.0, 7.0],
- [9.0, 11.0, 10.0]
- ]
- );
- let hs = t.index_select(&ids, 0)?;
- assert_eq!(
- hs.to_vec2::<f32>()?,
- &[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
- );
- // Prior to https://github.com/huggingface/candle/pull/1022
- // There would be a bug where the last values in the result tensor would be set to 0.
- let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
- let hs = t.index_select(&ids, 0)?;
- assert_eq!(
- hs.to_vec2::<f32>()?,
- &[
- [0.0, 1.0, 2.0],
- [6.0, 7.0, 8.0],
- [3.0, 4.0, 5.0],
- [0.0, 1.0, 2.0],
- [6.0, 7.0, 8.0],
- [3.0, 4.0, 5.0],
- ]
- );
+ for dtype in [DType::U8, DType::U32, DType::I64] {
+ let ids = ids.to_dtype(dtype)?;
+ let hs = t.index_select(&ids, 1)?;
+ assert_eq!(
+ hs.to_vec2::<f32>()?,
+ &[
+ [0.0, 2.0, 1.0],
+ [3.0, 5.0, 4.0],
+ [6.0, 8.0, 7.0],
+ [9.0, 11.0, 10.0]
+ ]
+ );
+ let hs = t.index_select(&ids, 0)?;
+ assert_eq!(
+ hs.to_vec2::<f32>()?,
+ &[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
+ );
+ // Prior to https://github.com/huggingface/candle/pull/1022
+ // There would be a bug where the last values in the result tensor would be set to 0.
+ let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
+ let hs = t.index_select(&ids, 0)?;
+ assert_eq!(
+ hs.to_vec2::<f32>()?,
+ &[
+ [0.0, 1.0, 2.0],
+ [6.0, 7.0, 8.0],
+ [3.0, 4.0, 5.0],
+ [0.0, 1.0, 2.0],
+ [6.0, 7.0, 8.0],
+ [3.0, 4.0, 5.0],
+ ]
+ );
- // Test when selecting dim > 0 with ids size different from elem count of
- // target dim in source/input.
- let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
- let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
- assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
- let hs = t.index_select(&ids, 1)?;
- assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
+ // Test when selecting dim > 0 with ids size different from elem count of
+ // target dim in source/input.
+ let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
+ let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
+ assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
+ let hs = t.index_select(&ids, 1)?;
+ assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
+ }
Ok(())
}