author    Lukas Kreussel <65088241+LLukas22@users.noreply.github.com>    2023-10-06 20:01:07 +0200
committer GitHub <noreply@github.com>    2023-10-06 19:01:07 +0100
commit    904bbdae65d69aac0c54c29eef744ca5e69c6733 (patch)
tree      8e191c2cb8cac91d76d2bb9875a60d4ccfe9dbf5 /candle-pyo3/py_src/candle/functional
parent    b0442eff8a696d1faba10e23ba645eb11e385116 (diff)
Make the Python Wrapper more Hackable and simplify Quantization (#1010)
* Some first `Module` implementations
* Add `state_dict` and `load_state_dict` functionality
* Move modules around and create `candle.nn.Linear`
* Add `nn.Embedding` and `nn.LayerNorm`
* Add BERT implementation
* Batch q-matmul
* Automatically dequantize `QTensors` if a `Tensor` is expected
* Add Module `.to()`, `.cuda()`, `.cpu()` and `.type()` functionality
* Unittests for `Module`, `Tensor` and `candle.utils`
* Add `pytorch` like slicing to `Tensor`
* Cleanup and BERT fixes
* `black` formatting + unit-test for `nn.Linear`
* Refactor slicing implementation
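A rough sketch of the workflow these changes enable; the exact constructor signatures, the callable-module behaviour and the state-dict round-trip shown here are assumptions modeled on the PyTorch-style API the message describes, not taken from this diff:

    # Hedged usage sketch of the new Module/Linear API; signatures are
    # assumptions based on the commit message, not on this diff.
    import candle
    from candle import nn

    linear = nn.Linear(4, 2)                   # in_features=4, out_features=2
    x = candle.Tensor([[0.1, 0.2, 0.3, 0.4]])  # shape (1, 4)
    y = linear(x)                              # shape (1, 2)

    params = linear.state_dict()               # name -> Tensor mapping
    linear.load_state_dict(params)             # round-trips the parameters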
Diffstat (limited to 'candle-pyo3/py_src/candle/functional')
-rw-r--r--  candle-pyo3/py_src/candle/functional/__init__.py    8
-rw-r--r--  candle-pyo3/py_src/candle/functional/__init__.pyi  40
2 files changed, 48 insertions, 0 deletions
diff --git a/candle-pyo3/py_src/candle/functional/__init__.py b/candle-pyo3/py_src/candle/functional/__init__.py
new file mode 100644
index 00000000..efb246f0
--- /dev/null
+++ b/candle-pyo3/py_src/candle/functional/__init__.py
@@ -0,0 +1,8 @@
+# Generated content DO NOT EDIT
+from .. import functional
+
+gelu = functional.gelu
+relu = functional.relu
+silu = functional.silu
+softmax = functional.softmax
+tanh = functional.tanh
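The re-exports above make the activation ops importable under a short alias. A minimal usage sketch, assuming the `candle.Tensor` constructor accepts a plain Python list (as the pyo3 bindings generally do):

    import candle
    import candle.functional as F

    t = candle.Tensor([-1.0, 0.0, 1.0])
    print(F.relu(t))   # negative entries clamped to zero
    print(F.gelu(t))   # smooth GELU activation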
diff --git a/candle-pyo3/py_src/candle/functional/__init__.pyi b/candle-pyo3/py_src/candle/functional/__init__.pyi
new file mode 100644
index 00000000..a46b6137
--- /dev/null
+++ b/candle-pyo3/py_src/candle/functional/__init__.pyi
@@ -0,0 +1,40 @@
+# Generated content DO NOT EDIT
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
+from os import PathLike
+from candle.typing import _ArrayLike, Device
+from candle import Tensor, DType, QTensor
+
+@staticmethod
+def gelu(tensor: Tensor) -> Tensor:
+ """
+ Applies the Gaussian Error Linear Unit (GELU) function to a given tensor.
+ """
+ pass
+
+@staticmethod
+def relu(tensor: Tensor) -> Tensor:
+ """
+ Applies the Rectified Linear Unit (ReLU) function to a given tensor.
+ """
+ pass
+
+@staticmethod
+def silu(tensor: Tensor) -> Tensor:
+ """
+ Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
+ """
+ pass
+
+@staticmethod
+def softmax(tensor: Tensor, dim: int) -> Tensor:
+ """
+ Applies the Softmax function to a given tensor.
+ """
+ pass
+
+@staticmethod
+def tanh(tensor: Tensor) -> Tensor:
+ """
+ Applies the tanh function to a given tensor.
+ """
+ pass
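Of the stubs above, only `softmax` takes a `dim` argument; it normalizes along that dimension so the slices sum to one. A small sketch (the nested-list constructor is an assumption):

    import candle
    import candle.functional as F

    t = candle.Tensor([[1.0, 2.0], [3.0, 4.0]])
    probs = F.softmax(t, 1)   # normalize across each row
    # first row is roughly [0.269, 0.731]; each row sums to 1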