summaryrefslogtreecommitdiff
path: root/candle-nn/src/embedding.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-nn/src/embedding.rs')
-rw-r--r--candle-nn/src/embedding.rs4
1 files changed, 3 insertions, 1 deletions
diff --git a/candle-nn/src/embedding.rs b/candle-nn/src/embedding.rs
index f4ba88e7..918c1805 100644
--- a/candle-nn/src/embedding.rs
+++ b/candle-nn/src/embedding.rs
@@ -18,8 +18,10 @@ impl Embedding {
pub fn embeddings(&self) -> &Tensor {
&self.embeddings
}
+}
- pub fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
+impl crate::Module for Embedding {
+ fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
let mut final_dims = indexes.dims().to_vec();
final_dims.push(self.hidden_size);
let indexes = indexes.flatten_all()?;