path: root/candle-pyo3/py_src/candle/nn/sparse.py
author    Lukas Kreussel <65088241+LLukas22@users.noreply.github.com>    2023-10-06 20:01:07 +0200
committer GitHub <noreply@github.com>    2023-10-06 19:01:07 +0100
commit    904bbdae65d69aac0c54c29eef744ca5e69c6733 (patch)
tree      8e191c2cb8cac91d76d2bb9875a60d4ccfe9dbf5 /candle-pyo3/py_src/candle/nn/sparse.py
parent    b0442eff8a696d1faba10e23ba645eb11e385116 (diff)
download  candle-904bbdae65d69aac0c54c29eef744ca5e69c6733.tar.gz
          candle-904bbdae65d69aac0c54c29eef744ca5e69c6733.tar.bz2
          candle-904bbdae65d69aac0c54c29eef744ca5e69c6733.zip
Make the Python Wrapper more Hackable and simplify Quantization (#1010)
* Some first `Module` implementations
* Add `state_dict` and `load_state_dict` functionality
* Move modules around and create `candle.nn.Linear`
* Add `nn.Embedding` and `nn.LayerNorm`
* Add BERT implementation
* Batch q-matmul
* Automatically dequantize `QTensors` if a `Tensor` is expected
* Add Module `.to()`, `.cuda()`, `.cpu()` and `.type()` functionality
* Unittests for `Module`, `Tensor` and `candle.utils`
* Add `pytorch`-like slicing to `Tensor`
* Cleanup and BERT fixes
* `black` formatting + unit-test for `nn.Linear`
* Refactor slicing implementation
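As a rough illustration of the PyTorch-style workflow these changes enable, here is a minimal sketch (not part of the commit). The names `candle.nn.Linear`, `state_dict` and `load_state_dict` come from the bullet list above; the exact constructor signature and call sequence are assumptions:

    import candle
    from candle.nn import Linear

    # Assumed signature: Linear(in_features, out_features), PyTorch-style.
    layer = Linear(4, 2)
    x = candle.randn((3, 4))        # batch of 3 input vectors
    y = layer.forward(x)            # expected shape: (3, 2)

    # Modules can be serialized and restored via their state dict.
    sd = layer.state_dict()
    layer.load_state_dict(sd)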
Diffstat (limited to 'candle-pyo3/py_src/candle/nn/sparse.py')
-rw-r--r--    candle-pyo3/py_src/candle/nn/sparse.py    39
1 file changed, 39 insertions, 0 deletions
diff --git a/candle-pyo3/py_src/candle/nn/sparse.py b/candle-pyo3/py_src/candle/nn/sparse.py
new file mode 100644
index 00000000..386f8081
--- /dev/null
+++ b/candle-pyo3/py_src/candle/nn/sparse.py
@@ -0,0 +1,39 @@
+from .module import Module
+from typing import Optional, Tuple, Any
+from candle import Tensor
+import candle
+
+
+class Embedding(Module):
+ """A simple lookup table that stores embeddings of a fixed dictionary and size.
+
+ This module is often used to store word embeddings and retrieve them using indices.
+ The input to the module is a list of indices, and the output is the corresponding
+ word embeddings.
+
+ Args:
+ num_embeddings (int): size of the dictionary of embeddings
+ embedding_dim (int): the size of each embedding vector
+
+ Attributes:
+ weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
+ initialized from :math:`\mathcal{N}(0, 1)`
+
+ Shape:
+ - Input: :math:`(*)`, IntTensor or LongTensor of arbitrary shape containing the indices to extract
+ - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`
+ """
+
+ def __init__(self, num_embeddings: int, embedding_dim: int, device=None) -> None:
+ factory_kwargs = {"device": device}
+ super().__init__()
+ self.num_embeddings = num_embeddings
+ self.embedding_dim = embedding_dim
+ self.weight = candle.randn((num_embeddings, embedding_dim), **factory_kwargs)
+
+ def forward(self, indexes: Tensor) -> Tensor:
+ final_dims = list(indexes.shape)
+ final_dims.append(self.embedding_dim)
+ indexes = indexes.flatten_all()
+ values = self.weight.index_select(indexes, 0)
+ return values.reshape(final_dims)
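A minimal usage sketch of the `Embedding` module above (not part of the commit); it assumes `candle.Tensor` accepts a nested Python list of integers and produces an integer tensor that `index_select` accepts as indices:

    import candle
    from candle.nn import Embedding

    # Hypothetical example under the assumptions stated above.
    emb = Embedding(num_embeddings=10, embedding_dim=4)
    indexes = candle.Tensor([[1, 2], [3, 0]])    # index tensor of shape (2, 2)
    out = emb.forward(indexes)                   # expected shape: (2, 2, 4)

Flattening the indices before `index_select` and reshaping the result afterwards is what lets `forward` accept indices of arbitrary shape, matching the `(*) -> (*, H)` contract stated in the docstring.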