diff options
author | Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> | 2023-09-17 23:11:01 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-17 22:11:01 +0100 |
commit | 03e194123d743ca73b10797a315b8d47734735e8 (patch) | |
tree | 66553a66f204fbfa9b85df0923a370966573ce7c /candle-pyo3/py_src/candle/utils | |
parent | c2b866172abaf1d4b8d75273c4f4e28a16d872f0 (diff) | |
download | candle-03e194123d743ca73b10797a315b8d47734735e8.tar.gz candle-03e194123d743ca73b10797a315b8d47734735e8.tar.bz2 candle-03e194123d743ca73b10797a315b8d47734735e8.zip |
Add return types to `*.pyi` stubs (#880)
* Start generating return types
* Finish tensor type hinting
* Add `save_gguf` to `utils`
* Typehint `quant-llama.py`
Diffstat (limited to 'candle-pyo3/py_src/candle/utils')
-rw-r--r-- | candle-pyo3/py_src/candle/utils/__init__.py | 1 | ||||
-rw-r--r-- | candle-pyo3/py_src/candle/utils/__init__.pyi | 25 |
2 files changed, 17 insertions, 9 deletions
diff --git a/candle-pyo3/py_src/candle/utils/__init__.py b/candle-pyo3/py_src/candle/utils/__init__.py index 2ead6d84..62d85dc9 100644 --- a/candle-pyo3/py_src/candle/utils/__init__.py +++ b/candle-pyo3/py_src/candle/utils/__init__.py @@ -8,4 +8,5 @@ has_mkl = utils.has_mkl load_ggml = utils.load_ggml load_gguf = utils.load_gguf load_safetensors = utils.load_safetensors +save_gguf = utils.save_gguf save_safetensors = utils.save_safetensors diff --git a/candle-pyo3/py_src/candle/utils/__init__.pyi b/candle-pyo3/py_src/candle/utils/__init__.pyi index 7a0a5231..61964ffc 100644 --- a/candle-pyo3/py_src/candle/utils/__init__.pyi +++ b/candle-pyo3/py_src/candle/utils/__init__.pyi @@ -2,38 +2,38 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence from os import PathLike from candle.typing import _ArrayLike, Device -from candle import Tensor, DType +from candle import Tensor, DType, QTensor @staticmethod -def cuda_is_available(): +def cuda_is_available() -> bool: """ Returns true if the 'cuda' backend is available. """ pass @staticmethod -def get_num_threads(): +def get_num_threads() -> int: """ Returns the number of threads used by the candle. """ pass @staticmethod -def has_accelerate(): +def has_accelerate() -> bool: """ Returns true if candle was compiled with 'accelerate' support. """ pass @staticmethod -def has_mkl(): +def has_mkl() -> bool: """ Returns true if candle was compiled with MKL support. """ pass @staticmethod -def load_ggml(path: Union[str, PathLike]): +def load_ggml(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]: """ Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors, a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary. @@ -41,7 +41,7 @@ def load_ggml(path: Union[str, PathLike]): pass @staticmethod -def load_gguf(path: Union[str, PathLike]): +def load_gguf(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str, Any]]: """ Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors, and the second maps metadata keys to metadata values. @@ -49,14 +49,21 @@ def load_gguf(path: Union[str, PathLike]): pass @staticmethod -def load_safetensors(path: Union[str, PathLike]): +def load_safetensors(path: Union[str, PathLike]) -> Dict[str, Tensor]: """ Loads a safetensors file. Returns a dictionary mapping tensor names to tensors. """ pass @staticmethod -def save_safetensors(path: Union[str, PathLike], tensors: Dict[str, Tensor]): +def save_gguf(path: Union[str, PathLike], tensors: Dict[str, QTensor], metadata: Dict[str, Any]): + """ + Save quanitzed tensors and metadata to a GGUF file. + """ + pass + +@staticmethod +def save_safetensors(path: Union[str, PathLike], tensors: Dict[str, Tensor]) -> None: """ Saves a dictionary of tensors to a safetensors file. """ |