Diffstat (limited to 'candle-pyo3')
 candle-pyo3/py_src/candle/functional/__init__.py  |  2 ++
 candle-pyo3/py_src/candle/functional/__init__.pyi | 14 ++++++++++++++
 candle-pyo3/src/lib.rs                            | 24 ++++++++++++++++++++++++
 3 files changed, 40 insertions(+), 0 deletions(-)
diff --git a/candle-pyo3/py_src/candle/functional/__init__.py b/candle-pyo3/py_src/candle/functional/__init__.py
index efb246f0..095e54e4 100644
--- a/candle-pyo3/py_src/candle/functional/__init__.py
+++ b/candle-pyo3/py_src/candle/functional/__init__.py
@@ -1,7 +1,9 @@
# Generated content DO NOT EDIT
from .. import functional
+avg_pool2d = functional.avg_pool2d
gelu = functional.gelu
+max_pool2d = functional.max_pool2d
relu = functional.relu
silu = functional.silu
softmax = functional.softmax
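
With the re-exports in place, the pooling ops are importable next to the existing activations. A minimal smoke test, assuming a wheel built from this revision is installed:

    import candle.functional as F

    # Both new ops should now be attributes of the functional module.
    assert hasattr(F, "avg_pool2d")
    assert hasattr(F, "max_pool2d")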
diff --git a/candle-pyo3/py_src/candle/functional/__init__.pyi b/candle-pyo3/py_src/candle/functional/__init__.pyi
index a46b6137..6f206e40 100644
--- a/candle-pyo3/py_src/candle/functional/__init__.pyi
+++ b/candle-pyo3/py_src/candle/functional/__init__.pyi
@@ -5,6 +5,13 @@ from candle.typing import _ArrayLike, Device
from candle import Tensor, DType, QTensor
@staticmethod
+def avg_pool2d(tensor: Tensor, ksize: int, *, stride: int = 1) -> Tensor:
+ """
+ Applies the 2d avg-pool function to a given tensor.
+ """
+ pass
+
+@staticmethod
def gelu(tensor: Tensor) -> Tensor:
"""
Applies the Gaussian Error Linear Unit (GELU) function to a given tensor.
@@ -12,6 +19,13 @@ def gelu(tensor: Tensor) -> Tensor:
pass
@staticmethod
+def max_pool2d(tensor: Tensor, ksize: int, *, stride: int = 1) -> Tensor:
+ """
+ Applies the 2d max-pool function to a given tensor.
+ """
+ pass
+
+@staticmethod
def relu(tensor: Tensor) -> Tensor:
"""
Applies the Rectified Linear Unit (ReLU) function to a given tensor.
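
The stubs above describe the call shape a type checker will see: ksize is required and stride defaults to 1 and is keyword-only, matching the Rust signature below. A short usage sketch; candle.randn and the NCHW layout are assumptions based on the rest of the bindings, not part of this diff:

    import candle
    import candle.functional as F

    t = candle.randn((1, 1, 4, 4))    # NCHW input (assumed layout)
    y = F.avg_pool2d(t, 2)            # ksize=2, stride defaults to 1
    z = F.max_pool2d(t, 2, stride=2)  # stride passed by keyword
    print(y.shape, z.shape)           # expected: (1, 1, 3, 3) (1, 1, 2, 2)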
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 4d4b5200..02db05e5 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -1225,6 +1225,28 @@ fn softmax(tensor: PyTensor, dim: i64) -> PyResult<PyTensor> {
}
#[pyfunction]
+#[pyo3(signature = (tensor, ksize, *, stride=1), text_signature = "(tensor:Tensor, ksize:int, *, stride:int=1)")]
+/// Applies the 2d avg-pool function to a given tensor.
+/// &RETURNS&: Tensor
+fn avg_pool2d(tensor: PyTensor, ksize: usize, stride: usize) -> PyResult<PyTensor> {
+ let tensor = tensor
+ .avg_pool2d_with_stride(ksize, stride)
+ .map_err(wrap_err)?;
+ Ok(PyTensor(tensor))
+}
+
+#[pyfunction]
+#[pyo3(signature = (tensor, ksize, *, stride=1), text_signature = "(tensor:Tensor, ksize:int, *, stride:int=1)")]
+/// Applies the 2d max-pool function to a given tensor.
+/// &RETURNS&: Tensor
+fn max_pool2d(tensor: PyTensor, ksize: usize, stride: usize) -> PyResult<PyTensor> {
+ let tensor = tensor
+ .max_pool2d_with_stride(ksize, stride)
+ .map_err(wrap_err)?;
+ Ok(PyTensor(tensor))
+}
+
+#[pyfunction]
#[pyo3(text_signature = "(tensor:Tensor)")]
/// Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
/// &RETURNS&: Tensor
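
Both bindings delegate to candle's *_pool2d_with_stride with a square kernel, so the expected output size per spatial dimension follows the usual no-padding floor rule. A tiny Python check of that rule (the formula is an assumption about candle's pooling semantics, not something stated in this diff):

    def pool_out(size: int, ksize: int, stride: int = 1) -> int:
        # No-padding pooling: floor((size - ksize) / stride) + 1.
        return (size - ksize) // stride + 1

    assert pool_out(8, 2) == 7            # avg_pool2d(t, 2) on an 8x8 input
    assert pool_out(8, 2, stride=2) == 4  # max_pool2d(t, 2, stride=2)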
@@ -1263,6 +1285,8 @@ fn tanh(tensor: PyTensor) -> PyResult<PyTensor> {
fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(silu, m)?)?;
m.add_function(wrap_pyfunction!(softmax, m)?)?;
+ m.add_function(wrap_pyfunction!(max_pool2d, m)?)?;
+ m.add_function(wrap_pyfunction!(avg_pool2d, m)?)?;
m.add_function(wrap_pyfunction!(gelu, m)?)?;
m.add_function(wrap_pyfunction!(relu, m)?)?;
m.add_function(wrap_pyfunction!(tanh, m)?)?;
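
Once registered here, both functions are callable through candle.functional. An end-to-end sketch, again assuming candle.randn as in the earlier examples; because the pyo3 signature makes stride keyword-only, a positional third argument should be rejected at the binding layer:

    import candle
    import candle.functional as F

    t = candle.randn((1, 3, 8, 8))
    print(F.max_pool2d(t, 2, stride=2).shape)  # expected: (1, 3, 4, 4)

    try:
        F.max_pool2d(t, 2, 2)                  # positional stride
    except TypeError as err:
        print("stride must be passed by keyword:", err)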