| field | value | timestamp |
|---|---|---|
| author | Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> | 2023-10-06 20:01:07 +0200 |
| committer | GitHub <noreply@github.com> | 2023-10-06 19:01:07 +0100 |
| commit | 904bbdae65d69aac0c54c29eef744ca5e69c6733 (patch) | |
| tree | 8e191c2cb8cac91d76d2bb9875a60d4ccfe9dbf5 /candle-pyo3/py_src/candle/nn/normalization.py | |
| parent | b0442eff8a696d1faba10e23ba645eb11e385116 (diff) | |
Make the Python Wrapper more Hackable and simplify Quantization (#1010)
* Some first `Module` implementations
* Add `state_dict` and `load_state_dict` functionality
* Move modules around and create `candle.nn.Linear`
* Add `nn.Embedding` and `nn.LayerNorm`
* Add BERT implementation
* Batch q-matmul
* Automatically dequantize `QTensors` if a `Tensor` is expected
* Add `Module` `.to()`, `.cuda()`, `.cpu()` and `.type()` functionality
* Unit tests for `Module`, `Tensor` and `candle.utils`
* Add `pytorch`-like slicing to `Tensor`
* Cleanup and BERT fixes
* `black` formatting + unit-test for `nn.Linear`
* Refactor slicing implementation (a usage sketch of the new Python API follows this list)
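
A minimal sketch of how these pieces are meant to fit together. The constructor signature `Linear(in_features, out_features)`, the availability of `candle.randn`, and tuple-of-slices indexing are assumptions for illustration, not guarantees from this commit:

```python
import candle
from candle.nn import Linear

# Assumed constructor signature: Linear(in_features, out_features).
layer = Linear(4, 2)

# Round-trip the parameters through the new state-dict API.
weights = layer.state_dict()
layer.load_state_dict(weights)

# Module device helpers added in this commit (PyTorch-style).
layer = layer.cpu()

# PyTorch-like slicing on a Tensor (assuming `candle.randn` is exposed).
t = candle.randn((3, 4))
row = t[0]          # first row
block = t[:, 1:3]   # columns 1 and 2

# Run the layer on a batch of inputs.
out = layer.forward(candle.randn((3, 4)))
```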
Diffstat (limited to 'candle-pyo3/py_src/candle/nn/normalization.py')
-rw-r--r-- | candle-pyo3/py_src/candle/nn/normalization.py | 54 |
1 file changed, 54 insertions, 0 deletions
```diff
diff --git a/candle-pyo3/py_src/candle/nn/normalization.py b/candle-pyo3/py_src/candle/nn/normalization.py
new file mode 100644
index 00000000..67510a24
--- /dev/null
+++ b/candle-pyo3/py_src/candle/nn/normalization.py
@@ -0,0 +1,54 @@
+import candle
+from candle import Tensor
+from .module import Module
+from typing import Union, List, Tuple, Optional, Any
+
+_shape_t = Union[int, List[int]]
+import numbers
+
+
+class LayerNorm(Module):
+    r"""Applies Layer Normalization over a mini-batch of inputs as described in
+    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`
+
+    math::
+        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+    """
+    __constants__ = ["normalized_shape", "eps"]
+    normalized_shape: Tuple[int, ...]
+    eps: float
+
+    def __init__(
+        self,
+        normalized_shape: _shape_t,
+        eps: float = 1e-5,
+        bias: bool = True,
+        device=None,
+        dtype=None,
+    ) -> None:
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        if isinstance(normalized_shape, numbers.Integral):
+            normalized_shape = (normalized_shape,)
+        self.normalized_shape = tuple(normalized_shape)
+        self.eps = eps
+
+        self.weight = candle.ones(normalized_shape, **factory_kwargs)
+        if bias:
+            self.bias = candle.zeros(normalized_shape, **factory_kwargs)
+        else:
+            self.bias = None
+
+    def forward(self, input: Tensor) -> Tensor:
+        mean_x = input.sum_keepdim(2) / float(self.normalized_shape[-1])
+        x = input.broadcast_sub(mean_x)
+        norm_x = x.sqr().sum_keepdim(2) / float(self.normalized_shape[-1])
+        x_normed = x.broadcast_div((norm_x + self.eps).sqrt())
+        x = x_normed.broadcast_mul(self.weight)
+
+        if self.bias:
+            x = x.broadcast_add(self.bias)
+        return x
+
+    def extra_repr(self) -> str:
+        return "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format(**self.__dict__)
```
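
For reference, a small usage sketch of the `LayerNorm` added above. Since `forward` reduces over dimension 2, the input is assumed to be 3-dimensional (batch, sequence, hidden); the re-export of `LayerNorm` from `candle.nn` and the availability of `candle.randn` are assumptions, not something this diff guarantees:

```python
import candle
from candle.nn import LayerNorm  # assumes normalization.py is re-exported via candle.nn

batch, seq_len, hidden = 2, 4, 8  # illustrative sizes only

# `forward` normalizes over dim 2, so inputs are expected as (batch, sequence, hidden)
# with `hidden` equal to `normalized_shape`.
ln = LayerNorm(hidden, eps=1e-5)

x = candle.randn((batch, seq_len, hidden))  # assumes `candle.randn` is exposed
y = ln.forward(x)

print(y.shape)  # expected: [2, 4, 8]
```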