//! Embedding Layer. use candle::{Result, Tensor}; #[derive(Clone, Debug)] pub struct Embedding { embeddings: Tensor, hidden_size: usize, } impl Embedding { pub fn new(embeddings: Tensor, hidden_size: usize) -> Self { Self { embeddings, hidden_size, } } pub fn embeddings(&self) -> &Tensor { &self.embeddings } /// Get the hidden size of the embedding matrix pub fn hidden_size(&self) -> usize { self.hidden_size } } impl crate::Module for Embedding { fn forward(&self, indexes: &Tensor) -> Result { let mut final_dims = indexes.dims().to_vec(); final_dims.push(self.hidden_size); let indexes = indexes.flatten_all()?; let values = self.embeddings.index_select(&indexes, 0)?; let values = values.reshape(final_dims)?; Ok(values) } } pub fn embedding(in_size: usize, out_size: usize, vb: crate::VarBuilder) -> Result { let embeddings = vb.get_with_hints( (in_size, out_size), "weight", crate::Init::Randn { mean: 0., stdev: 1., }, )?; Ok(Embedding::new(embeddings, out_size)) }