Diffstat (limited to 'candle-pyo3/py_src/candle/functional/__init__.pyi')
-rw-r--r--  candle-pyo3/py_src/candle/functional/__init__.pyi  14
1 file changed, 14 insertions, 0 deletions
diff --git a/candle-pyo3/py_src/candle/functional/__init__.pyi b/candle-pyo3/py_src/candle/functional/__init__.pyi
index a46b6137..6f206e40 100644
--- a/candle-pyo3/py_src/candle/functional/__init__.pyi
+++ b/candle-pyo3/py_src/candle/functional/__init__.pyi
@@ -5,6 +5,13 @@ from candle.typing import _ArrayLike, Device
from candle import Tensor, DType, QTensor
@staticmethod
+def avg_pool2d(tensor: Tensor, ksize: int, stride: int = 1) -> Tensor:
+ """
+ Applies the 2D average-pooling function to a given tensor.
+ """
+ pass
+
+@staticmethod
def gelu(tensor: Tensor) -> Tensor:
"""
Applies the Gaussian Error Linear Unit (GELU) function to a given tensor.
@@ -12,6 +19,13 @@ def gelu(tensor: Tensor) -> Tensor:
pass
@staticmethod
+def max_pool2d(tensor: Tensor, ksize: int, stride: int = 1) -> Tensor:
+ """
+ Applies the 2D max-pooling function to a given tensor.
+ """
+ pass
+
+@staticmethod
def relu(tensor: Tensor) -> Tensor:
"""
Applies the Rectified Linear Unit (ReLU) function to a given tensor.
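
For reference, a minimal usage sketch of the two pooling stubs added in this diff. It assumes the candle Python package provides a random-tensor constructor (here candle.randn) and that candle.functional exposes the functions exactly as declared above; the input shape and kernel sizes are illustrative only, not taken from the diff.

import candle
import candle.functional as F

# Assumed random-tensor constructor; substitute however you normally build a Tensor.
# NCHW layout: batch=1, channels=3, 8x8 spatial dims (illustrative).
x = candle.randn((1, 3, 8, 8))

# 2x2 average pooling with the default stride of 1, per the stub signature.
y_avg = F.avg_pool2d(x, 2)

# 2x2 max pooling with an explicit stride of 2 (passed positionally).
y_max = F.max_pool2d(x, 2, 2)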