summaryrefslogtreecommitdiff
path: root/candle-nn/src
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-01 20:44:43 +0100
committerGitHub <noreply@github.com>2023-08-01 20:44:43 +0100
commitcc76c63202ab936c08f6a6b9dcc2756c6a227f63 (patch)
tree5efcb4607b75f023861d2633811e5925aa2e7550 /candle-nn/src
parentff876c2103bc530f9ba3bc278c5e09148c124885 (diff)
downloadcandle-cc76c63202ab936c08f6a6b9dcc2756c6a227f63.tar.gz
candle-cc76c63202ab936c08f6a6b9dcc2756c6a227f63.tar.bz2
candle-cc76c63202ab936c08f6a6b9dcc2756c6a227f63.zip
Use index-select for the embeddings as it supports backprop. (#298)
Diffstat (limited to 'candle-nn/src')
-rw-r--r--candle-nn/src/embedding.rs2
1 files changed, 1 insertions, 1 deletions
diff --git a/candle-nn/src/embedding.rs b/candle-nn/src/embedding.rs
index 050123be..f4ba88e7 100644
--- a/candle-nn/src/embedding.rs
+++ b/candle-nn/src/embedding.rs
@@ -23,7 +23,7 @@ impl Embedding {
let mut final_dims = indexes.dims().to_vec();
final_dims.push(self.hidden_size);
let indexes = indexes.flatten_all()?;
- let values = Tensor::embedding(&indexes, &self.embeddings)?;
+ let values = self.embeddings.index_select(&indexes, 0)?;
let values = values.reshape(final_dims)?;
Ok(values)
}