diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-02-13 14:26:32 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-13 14:26:32 +0100 |
commit | ad73e93da2cf7311cb5c5bc39250aa335c5f9b76 (patch) | |
tree | 5b5ea591d0fda870f4499869e3a8feb9718cfebf /candle-pyo3/py_src/candle/nn/module.py | |
parent | 13c67226e68de216b731707067f7e68af0438821 (diff) | |
download | candle-ad73e93da2cf7311cb5c5bc39250aa335c5f9b76.tar.gz candle-ad73e93da2cf7311cb5c5bc39250aa335c5f9b76.tar.bz2 candle-ad73e93da2cf7311cb5c5bc39250aa335c5f9b76.zip |
Detach the tensors on batch-norm eval. (#1702)
* Detach the tensors on batch-norm eval.
* Fix pyo3 bindings.
* Black tweak.
* Formatting.
* Also update the pyo3-onnx formatting.
* Apply black.
Diffstat (limited to 'candle-pyo3/py_src/candle/nn/module.py')
-rw-r--r-- | candle-pyo3/py_src/candle/nn/module.py | 12 |
1 files changed, 4 insertions, 8 deletions
diff --git a/candle-pyo3/py_src/candle/nn/module.py b/candle-pyo3/py_src/candle/nn/module.py index 514d92b8..972d9a91 100644 --- a/candle-pyo3/py_src/candle/nn/module.py +++ b/candle-pyo3/py_src/candle/nn/module.py @@ -204,12 +204,10 @@ class Module: T_destination = TypeVar("T_destination", bound=Dict[str, Any]) @overload - def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination: - ... + def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination: ... @overload - def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]: - ... + def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]: ... def state_dict(self, *args, destination=None, prefix="", keep_vars=False): r"""Returns a dictionary containing references to the whole state of the module. @@ -586,12 +584,10 @@ class Module: self: T, device: str = ..., dtype: Optional[Union[DType, str]] = ..., - ) -> T: - ... + ) -> T: ... @overload - def to(self: T, dtype: Union[DType, str]) -> T: - ... + def to(self: T, dtype: Union[DType, str]) -> T: ... def to(self, *args, **kwargs): r"""Moves and/or casts the parameters and buffers. |