diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-01 20:44:43 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-01 20:44:43 +0100 |
commit | cc76c63202ab936c08f6a6b9dcc2756c6a227f63 (patch) | |
tree | 5efcb4607b75f023861d2633811e5925aa2e7550 /candle-nn/src | |
parent | ff876c2103bc530f9ba3bc278c5e09148c124885 (diff) | |
download | candle-cc76c63202ab936c08f6a6b9dcc2756c6a227f63.tar.gz candle-cc76c63202ab936c08f6a6b9dcc2756c6a227f63.tar.bz2 candle-cc76c63202ab936c08f6a6b9dcc2756c6a227f63.zip |
Use index-select for the embeddings as it supports backprop. (#298)
Diffstat (limited to 'candle-nn/src')
-rw-r--r-- | candle-nn/src/embedding.rs | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/candle-nn/src/embedding.rs b/candle-nn/src/embedding.rs index 050123be..f4ba88e7 100644 --- a/candle-nn/src/embedding.rs +++ b/candle-nn/src/embedding.rs @@ -23,7 +23,7 @@ impl Embedding { let mut final_dims = indexes.dims().to_vec(); final_dims.push(self.hidden_size); let indexes = indexes.flatten_all()?; - let values = Tensor::embedding(&indexes, &self.embeddings)?; + let values = self.embeddings.index_select(&indexes, 0)?; let values = values.reshape(final_dims)?; Ok(values) } |