path: root/candle-pyo3/py_src/candle/nn/normalization.py
author Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> 2023-10-06 20:01:07 +0200
committer GitHub <noreply@github.com> 2023-10-06 19:01:07 +0100
commit 904bbdae65d69aac0c54c29eef744ca5e69c6733 (patch)
tree 8e191c2cb8cac91d76d2bb9875a60d4ccfe9dbf5 /candle-pyo3/py_src/candle/nn/normalization.py
parent b0442eff8a696d1faba10e23ba645eb11e385116 (diff)
Make the Python Wrapper more Hackable and simplify Quantization (#1010)
* Some first `Module` implementations
* Add `state_dict` and `load_state_dict` functionality
* Move modules around and create `candle.nn.Linear`
* Add `nn.Embedding` and `nn.LayerNorm`
* Add BERT implementation
* Batch q-matmul
* Automatically dequantize `QTensors` if a `Tensor` is expected
* Add Module `.to()`, `.cuda()`, `.cpu()` and `.type()` functionality
* Unittests for `Module`, `Tensor` and `candle.utils`
* Add `pytorch` like slicing to `Tensor`
* Cleanup and BERT fixes
* `black` formatting + unit-test for `nn.Linear`
* Refactor slicing implementation
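The new modules are meant to mirror the PyTorch API, so a rough usage sketch of the pieces listed above might look like the following (not part of this commit; the import path, the tensor shape, and the exact return type of `state_dict` are assumptions based on the commit description):

import candle
from candle.nn import LayerNorm

ln = LayerNorm(8)                 # normalize over a hidden size of 8
x = candle.ones((2, 5, 8))        # (batch, seq, hidden); the forward pass added below assumes a 3D input
y = ln.forward(x)

# PyTorch-like Module helpers added by this PR
weights = ln.state_dict()         # expected to map "weight"/"bias" names to Tensors
ln.load_state_dict(weights)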
Diffstat (limited to 'candle-pyo3/py_src/candle/nn/normalization.py')
-rw-r--r-- candle-pyo3/py_src/candle/nn/normalization.py | 54
1 file changed, 54 insertions, 0 deletions
diff --git a/candle-pyo3/py_src/candle/nn/normalization.py b/candle-pyo3/py_src/candle/nn/normalization.py
new file mode 100644
index 00000000..67510a24
--- /dev/null
+++ b/candle-pyo3/py_src/candle/nn/normalization.py
@@ -0,0 +1,54 @@
+import candle
+from candle import Tensor
+from .module import Module
+from typing import Union, List, Tuple, Optional, Any
+import numbers
+
+_shape_t = Union[int, List[int]]
+
+
+class LayerNorm(Module):
+ r"""Applies Layer Normalization over a mini-batch of inputs as described in
+ the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`
+
+ math::
+ y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+ """
+ __constants__ = ["normalized_shape", "eps"]
+ normalized_shape: Tuple[int, ...]
+ eps: float
+
+    def __init__(
+        self,
+        normalized_shape: _shape_t,
+        eps: float = 1e-5,
+        bias: bool = True,
+        device=None,
+        dtype=None,
+    ) -> None:
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        if isinstance(normalized_shape, numbers.Integral):
+            normalized_shape = (normalized_shape,)
+        self.normalized_shape = tuple(normalized_shape)
+        self.eps = eps
+
+        self.weight = candle.ones(normalized_shape, **factory_kwargs)
+        if bias:
+            self.bias = candle.zeros(normalized_shape, **factory_kwargs)
+        else:
+            self.bias = None
+
+    def forward(self, input: Tensor) -> Tensor:
+        mean_x = input.sum_keepdim(2) / float(self.normalized_shape[-1])  # mean over dim 2 (assumes a 3D input)
+        x = input.broadcast_sub(mean_x)
+        norm_x = x.sqr().sum_keepdim(2) / float(self.normalized_shape[-1])  # biased variance of the centered input
+        x_normed = x.broadcast_div((norm_x + self.eps).sqrt())
+        x = x_normed.broadcast_mul(self.weight)
+
+        if self.bias is not None:
+            x = x.broadcast_add(self.bias)
+        return x
+
+    def extra_repr(self) -> str:
+        return "{normalized_shape}, eps={eps}".format(**self.__dict__)