diff options
Diffstat (limited to 'candle-core/tests/tensor_tests.rs')
-rw-r--r-- | candle-core/tests/tensor_tests.rs | 9 |
1 files changed, 9 insertions, 0 deletions
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index 2e867b26..a50f3a6c 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -680,6 +680,15 @@ fn index_select(device: &Device) -> Result<()> { [3.0, 4.0, 5.0], ] ); + + // Test when selecting dim > 0 with ids size different from elem count of + // target dim in source/input. + let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?; + let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?; + assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]); + let hs = t.index_select(&ids, 1)?; + assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]); + Ok(()) } |