blob: deeba01e1cea4a7f62c2949821a57491a0e8f2ad (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
|
use candle::{Result, Tensor};
/// A lookup-table layer that maps integer indexes to rows of an embedding matrix.
///
/// `forward` gathers rows of `embeddings` with [`Tensor::embedding`] and reshapes the
/// result so that each input index is replaced by a vector of length `hidden_size`.
#[derive(Clone, Debug)]
pub struct Embedding {
    /// The embedding table whose rows are selected by the indexes passed to `forward`.
    /// NOTE(review): presumably shaped `(vocab_size, hidden_size)` — this is not checked.
    embeddings: Tensor,
    /// Length of each embedding vector; appended as the trailing output dimension.
    hidden_size: usize,
}
impl Embedding {
    /// Builds an embedding layer from a pre-loaded table.
    ///
    /// `hidden_size` is stored as given and is not validated against the shape of
    /// `embeddings`. NOTE(review): callers presumably pass the table's second dimension.
    pub fn new(embeddings: Tensor, hidden_size: usize) -> Self {
        Self { embeddings, hidden_size }
    }

    /// Borrows the underlying embedding table.
    pub fn embeddings(&self) -> &Tensor {
        &self.embeddings
    }

    /// Looks up the embedding vector for every element of `indexes`.
    ///
    /// The output has the shape of `indexes` with `hidden_size` appended as a final
    /// dimension.
    ///
    /// # Errors
    ///
    /// Propagates any error from flattening `indexes`, from [`Tensor::embedding`], or
    /// from the final reshape (for instance when `hidden_size` does not agree with the
    /// number of elements the lookup produced).
    pub fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
        // Target shape: the index layout followed by one `hidden_size`-long axis.
        let out_shape = [indexes.dims(), &[self.hidden_size][..]].concat();
        // Flatten before the lookup; the reshape below restores the original layout.
        let flat_ids = indexes.flatten_all()?;
        Tensor::embedding(&flat_ids, &self.embeddings)?.reshape(out_shape)
    }
}
|