summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/tensor.rs7
-rw-r--r--candle-core/tests/tensor_tests.rs274
2 files changed, 278 insertions, 3 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 7dd24abf..e7355aad 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1520,14 +1520,15 @@ impl Tensor {
/// # Arguments
///
/// * `self` - The input tensor.
- /// * `indexes` - The indices of elements to gather, this should have the same shape as `self`
- /// but can have a different number of elements on the target dimension.
+ /// * `indexes` - The indices of elements to gather, this should have same number of dimensions as `self`
+ /// and indexes.dims()[d] <= self.dims()[d] for all dimensions d != dim
/// * `dim` - the target dimension.
///
/// The resulting tensor has the same shape as `indexes` and use values from `self` indexed on
/// dimension `dim` by the values in `indexes`.
pub fn gather<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "gather")?;
+
let self_dims = self.dims();
let indexes_dims = indexes.dims();
let mismatch = if indexes_dims.len() != self_dims.len() {
@@ -1535,7 +1536,7 @@ impl Tensor {
} else {
let mut mismatch = false;
for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() {
- if i != dim && d1 != d2 {
+ if i != dim && d1 < d2 {
mismatch = true;
break;
}
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index e0cea15c..e3246a33 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -1047,6 +1047,280 @@ fn gather(device: &Device) -> Result<()> {
let ids = Tensor::new(&[[0u32, 2u32, 0u32], [0u32, 1u32, 1u32]], device)?;
let hs = t.gather(&ids, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 7.0, 2.0], [0.0, 4.0, 5.0]]);
+
+ // Random data
+
+ // Dim: 0
+ let t = Tensor::new(
+ &[
+ [
+ [108_f32, -47., 16., -56., -83., -130., 210.],
+ [253., 95., 151., 228., -210., -123., -127.],
+ [-9., -217., 2., -78., 163., 245., -204.],
+ [-246., 79., -238., 88., -226., -184., 171.],
+ [8., -48., -153., 234., -34., 166., -153.],
+ [124., 0., -10., -61., -242., -15., -238.],
+ ],
+ [
+ [12., -64., -199., 244., -240., 156., -128.],
+ [173., -57., 4., -198., 233., -110., 238.],
+ [95., 82., 0., 240., 53., -211., 209.],
+ [-122., 167., -212., 227., -144., 61., 118.],
+ [-63., -146., 200., 244., 168., -167., 116.],
+ [-125., -147., 110., -253., -178., -250., -18.],
+ ],
+ [
+ [57., 86., -50., 56., 92., 205., -78.],
+ [-137., -156., -18., 248., -61., -239., 14.],
+ [-248., -30., -50., -70., -251., 250., -83.],
+ [-221., 67., 72., 59., -24., -154., 232.],
+ [-144., -23., -74., 5., 93., 171., 205.],
+ [46., -77., -38., -226., 246., 161., -17.],
+ ],
+ [
+ [-153., -231., -236., 161., 126., 2., -22.],
+ [-229., -41., 209., 164., 234., 160., 57.],
+ [223., 254., -186., -162., -46., -160., -102.],
+ [65., 30., 213., -253., 59., 224., -154.],
+ [-82., -203., -177., 17., 31., -256., -246.],
+ [176., -135., -65., 54., -56., 210., 76.],
+ ],
+ [
+ [-10., -245., 168., 124., -14., -33., -178.],
+ [25., -43., -39., 132., -89., 169., 179.],
+ [187., -215., 32., -133., 87., -7., -168.],
+ [-224., -215., -5., -230., -58., -162., 128.],
+ [158., -137., -122., -100., -202., -83., 136.],
+ [30., -185., -144., 250., 209., -40., 127.],
+ ],
+ [
+ [-196., 108., -245., 122., 146., -228., 62.],
+ [-1., -66., 160., 137., 13., -172., -21.],
+ [244., 199., -164., 28., 119., -175., 198.],
+ [-62., 253., -162., 195., -95., -230., -211.],
+ [123., -72., -26., -107., -139., 64., 245.],
+ [11., -126., -182., 108., -12., 184., -127.],
+ ],
+ [
+ [-159., 126., 176., 161., 73., -111., -138.],
+ [-187., 214., -217., -33., -223., -201., -212.],
+ [-61., -120., -166., -172., -95., 53., 196.],
+ [-33., 86., 134., -152., 154., -53., 74.],
+ [186., -28., -154., -174., 141., -109., 217.],
+ [82., 35., 252., 145., 181., 74., -87.],
+ ],
+ ],
+ device,
+ )?;
+
+ let ids = Tensor::new(
+ &[
+ [
+ [6_u32, 6, 4, 3, 4, 4, 6],
+ [3, 3, 2, 4, 4, 4, 6],
+ [3, 3, 0, 2, 4, 6, 4],
+ [2, 5, 1, 2, 6, 6, 1],
+ [2, 1, 6, 5, 3, 2, 3],
+ [6, 1, 0, 1, 0, 2, 6],
+ ],
+ [
+ [4, 6, 4, 3, 3, 3, 2],
+ [4, 3, 2, 4, 4, 4, 6],
+ [2, 3, 0, 2, 4, 6, 4],
+ [6, 5, 1, 2, 6, 6, 1],
+ [4, 1, 6, 5, 3, 2, 3],
+ [1, 1, 0, 1, 0, 2, 6],
+ ],
+ [
+ [3, 6, 4, 3, 3, 3, 2],
+ [2, 3, 2, 4, 4, 4, 6],
+ [4, 3, 0, 2, 4, 6, 4],
+ [0, 5, 1, 2, 6, 6, 1],
+ [6, 1, 6, 5, 3, 2, 3],
+ [4, 1, 0, 1, 0, 2, 6],
+ ],
+ [
+ [0, 6, 4, 3, 3, 3, 2],
+ [5, 3, 2, 4, 4, 4, 6],
+ [0, 3, 0, 2, 4, 6, 4],
+ [3, 5, 1, 2, 6, 6, 1],
+ [0, 1, 6, 5, 3, 2, 3],
+ [3, 1, 0, 1, 0, 2, 6],
+ ],
+ ],
+ device,
+ )?;
+
+ let hs = t.gather(&ids, 0)?;
+ assert_eq!(
+ hs.to_vec3::<f32>()?,
+ &[
+ [
+ [-159_f32, 126., 168., 161., -14., -33., -138.],
+ [-229., -41., -18., 132., -89., 169., -212.],
+ [223., 254., 2., -70., 87., 53., -168.],
+ [-221., 253., -212., 59., 154., -53., 118.],
+ [-144., -146., -154., -107., 31., 171., -246.],
+ [82., -147., -10., -253., -242., 161., -87.]
+ ],
+ [
+ [-10., 126., 168., 161., 126., 2., -78.],
+ [25., -41., -18., 132., -89., 169., -212.],
+ [-248., 254., 2., -70., 87., 53., -168.],
+ [-33., 253., -212., 59., 154., -53., 118.],
+ [158., -146., -154., -107., 31., 171., -246.],
+ [-125., -147., -10., -253., -242., 161., -87.]
+ ],
+ [
+ [-153., 126., 168., 161., 126., 2., -78.],
+ [-137., -41., -18., 132., -89., 169., -212.],
+ [187., 254., 2., -70., 87., 53., -168.],
+ [-246., 253., -212., 59., 154., -53., 118.],
+ [186., -146., -154., -107., 31., 171., -246.],
+ [30., -147., -10., -253., -242., 161., -87.]
+ ],
+ [
+ [108., 126., 168., 161., 126., 2., -78.],
+ [-1., -41., -18., 132., -89., 169., -212.],
+ [-9., 254., 2., -70., 87., 53., -168.],
+ [65., 253., -212., 59., 154., -53., 118.],
+ [8., -146., -154., -107., 31., 171., -246.],
+ [176., -147., -10., -253., -242., 161., -87.]
+ ]
+ ]
+ );
+
+ // Dim: 1
+ let t = Tensor::new(
+ &[
+ [
+ [-117_f32, -175., 69., -163.],
+ [200., 242., -21., -67.],
+ [179., 150., -126., -75.],
+ [-118., 38., -138., -13.],
+ [-221., 136., -185., 180.],
+ [58., 182., -204., -149.],
+ ],
+ [
+ [3., -148., -58., -154.],
+ [-43., 45., -108., 4.],
+ [-69., -249., -71., -21.],
+ [80., 110., -152., -235.],
+ [-88., 7., 92., -250.],
+ [-186., 207., -242., 98.],
+ ],
+ [
+ [238., 19., 64., -242.],
+ [-150., -97., 218., 58.],
+ [111., -233., 204., -212.],
+ [-242., -232., 83., 42.],
+ [153., 62., -251., 219.],
+ [-117., 36., -119., 10.],
+ ],
+ [
+ [215., 159., -169., -27.],
+ [-83., 101., -88., 169.],
+ [-205., 93., 225., -64.],
+ [-162., 240., 214., 23.],
+ [-112., 6., 21., 245.],
+ [-38., 113., 93., 215.],
+ ],
+ [
+ [91., -188., -148., 101.],
+ [74., 203., -35., 55.],
+ [-116., -130., -153., -96.],
+ [58., 22., -45., -194.],
+ [-221., -134., 73., 159.],
+ [-203., -254., 31., 235.],
+ ],
+ [
+ [105., -53., 61., 186.],
+ [-195., 234., 75., -1.],
+ [51., 139., 160., -108.],
+ [-173., -167., 161., 19.],
+ [83., -246., 156., -222.],
+ [109., 39., -149., 137.],
+ ],
+ ],
+ device,
+ )?;
+
+ let ids = Tensor::new(
+ &[
+ [[4_u32, 4, 4, 2]],
+ [[0, 4, 4, 3]],
+ [[1, 5, 3, 4]],
+ [[0, 3, 3, 2]],
+ [[1, 1, 5, 2]],
+ [[1, 4, 5, 4]],
+ ],
+ device,
+ )?;
+
+ let hs = t.gather(&ids, 1)?;
+ assert_eq!(
+ hs.to_vec3::<f32>()?,
+ &[
+ [[-221., 136., -185., -75.]],
+ [[3., 7., 92., -235.]],
+ [[-150., 36., 83., 219.]],
+ [[215., 240., 214., -64.]],
+ [[74., 203., 31., -96.]],
+ [[-195., -246., -149., -222.]]
+ ]
+ );
+
+ // Dim: 2
+ let t = Tensor::new(
+ &[
+ [[-162_f32, 202.], [-126., -39.], [35., -65.], [1., 80.]],
+ [[37., 248.], [-191., 89.], [117., -40.], [-217., 220.]],
+ ],
+ device,
+ )?;
+
+ let ids = Tensor::new(&[[[1_u32], [0], [1], [1]], [[0], [1], [0], [1]]], device)?;
+
+ let hs = t.gather(&ids, 2)?;
+ assert_eq!(
+ hs.to_vec3::<f32>()?,
+ &[
+ [[202.], [-126.], [-65.], [80.]],
+ [[37.], [89.], [117.], [220.]]
+ ]
+ );
+
+ let t = Tensor::new(
+ &[
+ [[-21_f32, -197.], [194., 122.]],
+ [[255., -106.], [-191., 250.]],
+ [[33., -117.], [43., 10.]],
+ [[-130., 238.], [-217., -92.]],
+ ],
+ device,
+ )?;
+
+ let ids = Tensor::new(
+ &[
+ [[0_u32, 1], [1, 0]],
+ [[1, 0], [0, 1]],
+ [[0, 1], [0, 1]],
+ [[1, 0], [1, 0]],
+ ],
+ device,
+ )?;
+
+ let hs = t.gather(&ids, 2)?;
+ assert_eq!(
+ hs.to_vec3::<f32>()?,
+ &[
+ [[-21., -197.], [122., 194.]],
+ [[-106., 255.], [-191., 250.]],
+ [[33., -117.], [43., 10.]],
+ [[238., -130.], [-92., -217.]]
+ ]
+ );
+
Ok(())
}