author     Laurent Mazare <laurent.mazare@gmail.com>  2023-10-13 21:18:10 +0200
committer  GitHub <noreply@github.com>                2023-10-13 20:18:10 +0100
commit     2c110ac7d9bd44afc6e56e4becdd396d17d710cd (patch)
tree       19ae395219a8dbd0019c0e412d7f1ce95c051fc3 /candle-pyo3
parent     75989fc3b7ad06f6216b3aab62a2f3a7fcf4ebba (diff)
Add the pooling operators to the pyo3 layer. (#1086)
Diffstat (limited to 'candle-pyo3')
 candle-pyo3/py_src/candle/functional/__init__.py  |  2 ++
 candle-pyo3/py_src/candle/functional/__init__.pyi | 14 ++++++++++++++
 candle-pyo3/src/lib.rs                            | 24 ++++++++++++++++++++
 3 files changed, 40 insertions(+), 0 deletions(-)
diff --git a/candle-pyo3/py_src/candle/functional/__init__.py b/candle-pyo3/py_src/candle/functional/__init__.py
index efb246f0..095e54e4 100644
--- a/candle-pyo3/py_src/candle/functional/__init__.py
+++ b/candle-pyo3/py_src/candle/functional/__init__.py
@@ -1,7 +1,9 @@
 # Generated content DO NOT EDIT
 from .. import functional
 
+avg_pool2d = functional.avg_pool2d
 gelu = functional.gelu
+max_pool2d = functional.max_pool2d
 relu = functional.relu
 silu = functional.silu
 softmax = functional.softmax
diff --git a/candle-pyo3/py_src/candle/functional/__init__.pyi b/candle-pyo3/py_src/candle/functional/__init__.pyi
index a46b6137..6f206e40 100644
--- a/candle-pyo3/py_src/candle/functional/__init__.pyi
+++ b/candle-pyo3/py_src/candle/functional/__init__.pyi
@@ -5,6 +5,13 @@ from candle.typing import _ArrayLike, Device
 from candle import Tensor, DType, QTensor
 
 @staticmethod
+def avg_pool2d(tensor: Tensor, ksize: int, stride: int = 1) -> Tensor:
+    """
+    Applies the 2d avg-pool function to a given tensor.
+    """
+    pass
+
+@staticmethod
 def gelu(tensor: Tensor) -> Tensor:
     """
     Applies the Gaussian Error Linear Unit (GELU) function to a given tensor.
@@ -12,6 +19,13 @@ def gelu(tensor: Tensor) -> Tensor:
     pass
 
 @staticmethod
+def max_pool2d(tensor: Tensor, ksize: int, stride: int = 1) -> Tensor:
+    """
+    Applies the 2d max-pool function to a given tensor.
+    """
+    pass
+
+@staticmethod
 def relu(tensor: Tensor) -> Tensor:
     """
     Applies the Rectified Linear Unit (ReLU) function to a given tensor.
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 4d4b5200..02db05e5 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -1225,6 +1225,28 @@ fn softmax(tensor: PyTensor, dim: i64) -> PyResult<PyTensor> {
 }
 
 #[pyfunction]
+#[pyo3(signature = (tensor, ksize, *, stride=1), text_signature = "(tensor:Tensor, ksize:int, stride:int=1)")]
+/// Applies the 2d avg-pool function to a given tensor.
+/// &RETURNS&: Tensor
+fn avg_pool2d(tensor: PyTensor, ksize: usize, stride: usize) -> PyResult<PyTensor> {
+    let tensor = tensor
+        .avg_pool2d_with_stride(ksize, stride)
+        .map_err(wrap_err)?;
+    Ok(PyTensor(tensor))
+}
+
+#[pyfunction]
+#[pyo3(signature = (tensor, ksize, *, stride=1), text_signature = "(tensor:Tensor, ksize:int, stride:int=1)")]
+/// Applies the 2d max-pool function to a given tensor.
+/// &RETURNS&: Tensor
+fn max_pool2d(tensor: PyTensor, ksize: usize, stride: usize) -> PyResult<PyTensor> {
+    let tensor = tensor
+        .max_pool2d_with_stride(ksize, stride)
+        .map_err(wrap_err)?;
+    Ok(PyTensor(tensor))
+}
+
+#[pyfunction]
 #[pyo3(text_signature = "(tensor:Tensor)")]
 /// Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
 /// &RETURNS&: Tensor
@@ -1263,6 +1285,8 @@ fn tanh(tensor: PyTensor) -> PyResult<PyTensor> {
 fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
     m.add_function(wrap_pyfunction!(silu, m)?)?;
     m.add_function(wrap_pyfunction!(softmax, m)?)?;
+    m.add_function(wrap_pyfunction!(max_pool2d, m)?)?;
+    m.add_function(wrap_pyfunction!(avg_pool2d, m)?)?;
     m.add_function(wrap_pyfunction!(gelu, m)?)?;
     m.add_function(wrap_pyfunction!(relu, m)?)?;
     m.add_function(wrap_pyfunction!(tanh, m)?)?;
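
For reference, a minimal usage sketch of the two new bindings from the Python side. This assumes a candle-pyo3 build that includes this commit; `candle.randn` and the (batch, channels, height, width) input layout are assumptions based on the rest of the bindings, and `stride` is keyword-only per the `#[pyo3(signature = ...)]` attribute above:

    import candle
    import candle.functional as F

    # Hypothetical 4D input; candle's 2d pooling ops operate on
    # (batch, channels, height, width) tensors.
    t = candle.randn((1, 1, 4, 4))

    avg = F.avg_pool2d(t, 2)            # ksize=2, stride defaults to 1
    mx = F.max_pool2d(t, 2, stride=2)   # stride must be passed by keyword
    print(avg.shape, mx.shape)

Note that the text_signature shown to Python (`(tensor:Tensor, ksize:int, stride:int=1)`) does not surface the keyword-only marker, so the generated stubs list `stride` as an ordinary defaulted parameter even though the runtime rejects it positionally.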