diff options
Diffstat (limited to 'candle-nn/src/embedding.rs')
-rw-r--r-- | candle-nn/src/embedding.rs | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/candle-nn/src/embedding.rs b/candle-nn/src/embedding.rs index f4ba88e7..918c1805 100644 --- a/candle-nn/src/embedding.rs +++ b/candle-nn/src/embedding.rs @@ -18,8 +18,10 @@ impl Embedding { pub fn embeddings(&self) -> &Tensor { &self.embeddings } +} - pub fn forward(&self, indexes: &Tensor) -> Result<Tensor> { +impl crate::Module for Embedding { + fn forward(&self, indexes: &Tensor) -> Result<Tensor> { let mut final_dims = indexes.dims().to_vec(); final_dims.push(self.hidden_size); let indexes = indexes.flatten_all()?; |