summaryrefslogtreecommitdiff
path: root/candle-core/tests/tensor_tests.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/tests/tensor_tests.rs')
-rw-r--r--candle-core/tests/tensor_tests.rs9
1 files changed, 9 insertions, 0 deletions
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 2e867b26..a50f3a6c 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -680,6 +680,15 @@ fn index_select(device: &Device) -> Result<()> {
[3.0, 4.0, 5.0],
]
);
+
+ // Test when selecting dim > 0 with ids size different from elem count of
+ // target dim in source/input.
+ let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
+ let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
+ assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
+ let hs = t.index_select(&ids, 1)?;
+ assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
+
Ok(())
}