1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
|
from .module import Module
from typing import Optional, Tuple, Any
from candle import Tensor
import candle
class Embedding(Module):
"""A simple lookup table that stores embeddings of a fixed dictionary and size.
This module is often used to store word embeddings and retrieve them using indices.
The input to the module is a list of indices, and the output is the corresponding
word embeddings.
Args:
num_embeddings (int): size of the dictionary of embeddings
embedding_dim (int): the size of each embedding vector
Attributes:
weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
initialized from :math:`\mathcal{N}(0, 1)`
Shape:
- Input: :math:`(*)`, IntTensor or LongTensor of arbitrary shape containing the indices to extract
- Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`
"""
def __init__(self, num_embeddings: int, embedding_dim: int, device=None) -> None:
factory_kwargs = {"device": device}
super().__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.weight = candle.randn((num_embeddings, embedding_dim), **factory_kwargs)
def forward(self, indexes: Tensor) -> Tensor:
final_dims = list(indexes.shape)
final_dims.append(self.embedding_dim)
indexes = indexes.flatten_all()
values = self.weight.index_select(indexes, 0)
return values.reshape(final_dims)
|