summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/tests/tensor_tests.rs2
-rw-r--r--candle-nn/src/embedding.rs2
2 files changed, 3 insertions, 1 deletions
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 38336ecf..a8702df7 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -536,6 +536,8 @@ fn embeddings(device: &Device) -> Result<()> {
let t = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], device)?;
let hs = Tensor::embedding(&ids, &t)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
+ let hs = t.index_select(&ids, 0)?;
+ assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
Ok(())
}
diff --git a/candle-nn/src/embedding.rs b/candle-nn/src/embedding.rs
index 050123be..f4ba88e7 100644
--- a/candle-nn/src/embedding.rs
+++ b/candle-nn/src/embedding.rs
@@ -23,7 +23,7 @@ impl Embedding {
let mut final_dims = indexes.dims().to_vec();
final_dims.push(self.hidden_size);
let indexes = indexes.flatten_all()?;
- let values = Tensor::embedding(&indexes, &self.embeddings)?;
+ let values = self.embeddings.index_select(&indexes, 0)?;
let values = values.reshape(final_dims)?;
Ok(values)
}