diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-03-27 16:30:07 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-27 16:30:07 +0100 |
commit | ab86cd37c8fd944df351d8c7ca0e93376634a332 (patch) | |
tree | 15f9c8809c1c7c7bd449ab44fefcca24a92fb295 /candle-core/tests/tensor_tests.rs | |
parent | a9abde5f930914ef7ef2d504728f742f80468961 (diff) | |
download | candle-ab86cd37c8fd944df351d8c7ca0e93376634a332.tar.gz candle-ab86cd37c8fd944df351d8c7ca0e93376634a332.tar.bz2 candle-ab86cd37c8fd944df351d8c7ca0e93376634a332.zip |
Support i64 in index-select on metal. (#1951)
* Support i64 in index-select on metal.
* Add some testing of index-select for all dtypes.
Diffstat (limited to 'candle-core/tests/tensor_tests.rs')
-rw-r--r-- | candle-core/tests/tensor_tests.rs | 79 |
1 files changed, 42 insertions, 37 deletions
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index af28c1c1..8aacc05d 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -707,6 +707,8 @@ fn embeddings(device: &Device) -> Result<()> { assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]); let hs = t.index_select(&ids, 0)?; assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]); + let hs = t.index_select(&ids.to_dtype(DType::I64)?, 0)?; + assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]); Ok(()) } @@ -734,44 +736,47 @@ fn index_select(device: &Device) -> Result<()> { [9.0, 10.0, 11.0] ] ); - let hs = t.index_select(&ids, 1)?; - assert_eq!( - hs.to_vec2::<f32>()?, - &[ - [0.0, 2.0, 1.0], - [3.0, 5.0, 4.0], - [6.0, 8.0, 7.0], - [9.0, 11.0, 10.0] - ] - ); - let hs = t.index_select(&ids, 0)?; - assert_eq!( - hs.to_vec2::<f32>()?, - &[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]] - ); - // Prior to https://github.com/huggingface/candle/pull/1022 - // There would be a bug where the last values in the result tensor would be set to 0. - let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?; - let hs = t.index_select(&ids, 0)?; - assert_eq!( - hs.to_vec2::<f32>()?, - &[ - [0.0, 1.0, 2.0], - [6.0, 7.0, 8.0], - [3.0, 4.0, 5.0], - [0.0, 1.0, 2.0], - [6.0, 7.0, 8.0], - [3.0, 4.0, 5.0], - ] - ); + for dtype in [DType::U8, DType::U32, DType::I64] { + let ids = ids.to_dtype(dtype)?; + let hs = t.index_select(&ids, 1)?; + assert_eq!( + hs.to_vec2::<f32>()?, + &[ + [0.0, 2.0, 1.0], + [3.0, 5.0, 4.0], + [6.0, 8.0, 7.0], + [9.0, 11.0, 10.0] + ] + ); + let hs = t.index_select(&ids, 0)?; + assert_eq!( + hs.to_vec2::<f32>()?, + &[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]] + ); + // Prior to https://github.com/huggingface/candle/pull/1022 + // There would be a bug where the last values in the result tensor would be set to 0. + let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?; + let hs = t.index_select(&ids, 0)?; + assert_eq!( + hs.to_vec2::<f32>()?, + &[ + [0.0, 1.0, 2.0], + [6.0, 7.0, 8.0], + [3.0, 4.0, 5.0], + [0.0, 1.0, 2.0], + [6.0, 7.0, 8.0], + [3.0, 4.0, 5.0], + ] + ); - // Test when selecting dim > 0 with ids size different from elem count of - // target dim in source/input. - let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?; - let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?; - assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]); - let hs = t.index_select(&ids, 1)?; - assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]); + // Test when selecting dim > 0 with ids size different from elem count of + // target dim in source/input. + let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?; + let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?; + assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]); + let hs = t.index_select(&ids, 1)?; + assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]); + } Ok(()) } |