summary | refs | log | tree | commit | diff
path: root/candle-pyo3/py_src/candle/utils
diff options
context:
space:
mode:
author: Lukas Kreussel <65088241+LLukas22@users.noreply.github.com>, 2023-09-17 23:11:01 +0200
committer: GitHub <noreply@github.com>, 2023-09-17 22:11:01 +0100
commit: 03e194123d743ca73b10797a315b8d47734735e8 (patch)
tree: 66553a66f204fbfa9b85df0923a370966573ce7c /candle-pyo3/py_src/candle/utils
parent: c2b866172abaf1d4b8d75273c4f4e28a16d872f0 (diff)
download: candle-03e194123d743ca73b10797a315b8d47734735e8.tar.gz
candle-03e194123d743ca73b10797a315b8d47734735e8.tar.bz2
candle-03e194123d743ca73b10797a315b8d47734735e8.zip
Add return types to `*.pyi` stubs (#880)
* Start generating return types
* Finish tensor type hinting
* Add `save_gguf` to `utils`
* Typehint `quant-llama.py`
Diffstat (limited to 'candle-pyo3/py_src/candle/utils')
-rw-r--r-- candle-pyo3/py_src/candle/utils/__init__.py | 1
-rw-r--r-- candle-pyo3/py_src/candle/utils/__init__.pyi | 25
2 files changed, 17 insertions, 9 deletions
diff --git a/candle-pyo3/py_src/candle/utils/__init__.py b/candle-pyo3/py_src/candle/utils/__init__.py
index 2ead6d84..62d85dc9 100644
--- a/candle-pyo3/py_src/candle/utils/__init__.py
+++ b/candle-pyo3/py_src/candle/utils/__init__.py
@@ -8,4 +8,5 @@ has_mkl = utils.has_mkl
load_ggml = utils.load_ggml
load_gguf = utils.load_gguf
load_safetensors = utils.load_safetensors
+save_gguf = utils.save_gguf
save_safetensors = utils.save_safetensors
diff --git a/candle-pyo3/py_src/candle/utils/__init__.pyi b/candle-pyo3/py_src/candle/utils/__init__.pyi
index 7a0a5231..61964ffc 100644
--- a/candle-pyo3/py_src/candle/utils/__init__.pyi
+++ b/candle-pyo3/py_src/candle/utils/__init__.pyi
@@ -2,38 +2,38 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
from candle.typing import _ArrayLike, Device
-from candle import Tensor, DType
+from candle import Tensor, DType, QTensor
@staticmethod
-def cuda_is_available():
+def cuda_is_available() -> bool:
"""
Returns true if the 'cuda' backend is available.
"""
pass
@staticmethod
-def get_num_threads():
+def get_num_threads() -> int:
"""
Returns the number of threads used by candle.
"""
pass
@staticmethod
-def has_accelerate():
+def has_accelerate() -> bool:
"""
Returns true if candle was compiled with 'accelerate' support.
"""
pass
@staticmethod
-def has_mkl():
+def has_mkl() -> bool:
"""
Returns true if candle was compiled with MKL support.
"""
pass
@staticmethod
-def load_ggml(path: Union[str, PathLike]):
+def load_ggml(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]:
"""
Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
@@ -41,7 +41,7 @@ def load_ggml(path: Union[str, PathLike]):
pass
@staticmethod
-def load_gguf(path: Union[str, PathLike]):
+def load_gguf(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str, Any]]:
"""
Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
and the second maps metadata keys to metadata values.
@@ -49,14 +49,21 @@ def load_gguf(path: Union[str, PathLike]):
pass
@staticmethod
-def load_safetensors(path: Union[str, PathLike]):
+def load_safetensors(path: Union[str, PathLike]) -> Dict[str, Tensor]:
"""
Loads a safetensors file. Returns a dictionary mapping tensor names to tensors.
"""
pass
@staticmethod
-def save_safetensors(path: Union[str, PathLike], tensors: Dict[str, Tensor]):
+def save_gguf(path: Union[str, PathLike], tensors: Dict[str, QTensor], metadata: Dict[str, Any]):
+ """
+ Save quantized tensors and metadata to a GGUF file.
+ """
+ pass
+
+@staticmethod
+def save_safetensors(path: Union[str, PathLike], tensors: Dict[str, Tensor]) -> None:
"""
Saves a dictionary of tensors to a safetensors file.
"""