author | Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> | 2023-10-06 20:01:07 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-06 19:01:07 +0100 |
commit | 904bbdae65d69aac0c54c29eef744ca5e69c6733 | |
tree | 8e191c2cb8cac91d76d2bb9875a60d4ccfe9dbf5 /candle-pyo3/py_src/candle/nn/sparse.py | |
parent | b0442eff8a696d1faba10e23ba645eb11e385116 | |
Make the Python Wrapper more Hackable and simplify Quantization (#1010)
* Some first `Module` implementations
* Add `state_dict` and `load_state_dict` functionality
* Move modules around and create `candle.nn.Linear`
* Add `nn.Embedding` and `nn.LayerNorm`
* Add BERT implementation
* Batch q-matmul
* Automatically dequantize `QTensors` if a `Tensor` is expected
* Add Module `.to()`, `.cuda()`, `.cpu()` and `.type()` functionality
* Unit tests for `Module`, `Tensor` and `candle.utils`
* Add `pytorch`-like slicing to `Tensor` (see the sketch after this list)
* Cleanup and BERT fixes
* `black` formatting + unit-test for `nn.Linear`
* Refactor slicing implementation
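Taken together, these changes give the Python wrapper a PyTorch-style surface. A minimal sketch of how the pieces named above might fit together (the `candle.nn` import path, the `Linear(in_features, out_features)` signature, and the exact slicing forms are assumptions inferred from the bullet points, not code taken from this commit):

```python
import candle
from candle import nn  # assumed import path for the new module namespace

# Construct a layer and round-trip its parameters through the new
# `state_dict` / `load_state_dict` support.
linear = nn.Linear(4, 2)  # assumed (in_features, out_features) signature
params = linear.state_dict()
linear.load_state_dict(params)

# PyTorch-like slicing on a `Tensor` (`candle.randn` appears in the diff below).
t = candle.randn((3, 4))
first_row = t[0]         # shape (4,)
second_column = t[:, 1]  # shape (3,)
```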
Diffstat (limited to 'candle-pyo3/py_src/candle/nn/sparse.py')
-rw-r--r-- | candle-pyo3/py_src/candle/nn/sparse.py | 39 |
1 file changed, 39 insertions, 0 deletions
diff --git a/candle-pyo3/py_src/candle/nn/sparse.py b/candle-pyo3/py_src/candle/nn/sparse.py
new file mode 100644
index 00000000..386f8081
--- /dev/null
+++ b/candle-pyo3/py_src/candle/nn/sparse.py
@@ -0,0 +1,39 @@
+from .module import Module
+from typing import Optional, Tuple, Any
+from candle import Tensor
+import candle
+
+
+class Embedding(Module):
+    """A simple lookup table that stores embeddings of a fixed dictionary and size.
+
+    This module is often used to store word embeddings and retrieve them using indices.
+    The input to the module is a list of indices, and the output is the corresponding
+    word embeddings.
+
+    Args:
+        num_embeddings (int): size of the dictionary of embeddings
+        embedding_dim (int): the size of each embedding vector
+
+    Attributes:
+        weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
+            initialized from :math:`\mathcal{N}(0, 1)`
+
+    Shape:
+        - Input: :math:`(*)`, IntTensor or LongTensor of arbitrary shape containing the indices to extract
+        - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`
+    """
+
+    def __init__(self, num_embeddings: int, embedding_dim: int, device=None) -> None:
+        factory_kwargs = {"device": device}
+        super().__init__()
+        self.num_embeddings = num_embeddings
+        self.embedding_dim = embedding_dim
+        self.weight = candle.randn((num_embeddings, embedding_dim), **factory_kwargs)
+
+    def forward(self, indexes: Tensor) -> Tensor:
+        final_dims = list(indexes.shape)
+        final_dims.append(self.embedding_dim)
+        indexes = indexes.flatten_all()
+        values = self.weight.index_select(indexes, 0)
+        return values.reshape(final_dims)
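For context, a short usage sketch of the new `Embedding` module (hypothetical values; it assumes `candle.Tensor` can be built from a nested list of integers, which this diff does not itself show):

```python
import candle
from candle.nn.sparse import Embedding

# A vocabulary of 10 entries, each mapped to a 4-dimensional vector.
emb = Embedding(num_embeddings=10, embedding_dim=4)

# Indices of arbitrary shape; the output gains a trailing `embedding_dim` axis,
# matching the Shape section of the docstring above.
indexes = candle.Tensor([[1, 2], [3, 4]])  # shape (2, 2), assumed constructor
out = emb.forward(indexes)                 # shape (2, 2, 4)
print(out.shape)
```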