diff options
author | Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> | 2023-10-06 20:01:07 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-06 19:01:07 +0100 |
commit | 904bbdae65d69aac0c54c29eef744ca5e69c6733 (patch) | |
tree | 8e191c2cb8cac91d76d2bb9875a60d4ccfe9dbf5 /candle-pyo3/py_src/candle/functional | |
parent | b0442eff8a696d1faba10e23ba645eb11e385116 (diff) | |
download | candle-904bbdae65d69aac0c54c29eef744ca5e69c6733.tar.gz candle-904bbdae65d69aac0c54c29eef744ca5e69c6733.tar.bz2 candle-904bbdae65d69aac0c54c29eef744ca5e69c6733.zip |
Make the Python Wrapper more Hackable and simplify Quantization (#1010)
* Some first `Module` implementations
* Add `state_dict` and `load_state_dict` functionality
* Move modules around and create `candle.nn.Linear`
* Add `nn.Embedding` and `nn.LayerNorm`
* Add BERT implementation
* Batch q-matmul
* Automatically dequantize `QTensors` if a `Tensor` is expected
* Add Module `.to()`, `.cuda()`, `cpu()` and `.type()` functionality
* Unittests for `Module`, `Tensor` and `candle.utils`
* Add `pytorch` like slicing to `Tensor`
* Cleanup and BERT fixes
* `black` formatting + unit-test for `nn.Linear`
* Refactor slicing implementation
Diffstat (limited to 'candle-pyo3/py_src/candle/functional')
-rw-r--r-- | candle-pyo3/py_src/candle/functional/__init__.py | 8 | ||||
-rw-r--r-- | candle-pyo3/py_src/candle/functional/__init__.pyi | 40 |
2 files changed, 48 insertions, 0 deletions
diff --git a/candle-pyo3/py_src/candle/functional/__init__.py b/candle-pyo3/py_src/candle/functional/__init__.py new file mode 100644 index 00000000..efb246f0 --- /dev/null +++ b/candle-pyo3/py_src/candle/functional/__init__.py @@ -0,0 +1,8 @@ +# Generated content DO NOT EDIT +from .. import functional + +gelu = functional.gelu +relu = functional.relu +silu = functional.silu +softmax = functional.softmax +tanh = functional.tanh diff --git a/candle-pyo3/py_src/candle/functional/__init__.pyi b/candle-pyo3/py_src/candle/functional/__init__.pyi new file mode 100644 index 00000000..a46b6137 --- /dev/null +++ b/candle-pyo3/py_src/candle/functional/__init__.pyi @@ -0,0 +1,40 @@ +# Generated content DO NOT EDIT +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence +from os import PathLike +from candle.typing import _ArrayLike, Device +from candle import Tensor, DType, QTensor + +@staticmethod +def gelu(tensor: Tensor) -> Tensor: + """ + Applies the Gaussian Error Linear Unit (GELU) function to a given tensor. + """ + pass + +@staticmethod +def relu(tensor: Tensor) -> Tensor: + """ + Applies the Rectified Linear Unit (ReLU) function to a given tensor. + """ + pass + +@staticmethod +def silu(tensor: Tensor) -> Tensor: + """ + Applies the Sigmoid Linear Unit (SiLU) function to a given tensor. + """ + pass + +@staticmethod +def softmax(tensor: Tensor, dim: int) -> Tensor: + """ + Applies the Softmax function to a given tensor.# + """ + pass + +@staticmethod +def tanh(tensor: Tensor) -> Tensor: + """ + Applies the tanh function to a given tensor. + """ + pass |