-rw-r--r--  candle-pyo3/.gitignore                        | 160
-rw-r--r--  candle-pyo3/Cargo.toml                        |   1
-rw-r--r--  candle-pyo3/README.md                         |  21
-rw-r--r--  candle-pyo3/py_src/candle/__init__.py         |   1
-rw-r--r--  candle-pyo3/py_src/candle/__init__.pyi        | 248
-rw-r--r--  candle-pyo3/py_src/candle/nn/__init__.py      |   5
-rw-r--r--  candle-pyo3/py_src/candle/nn/__init__.pyi     |  19
-rw-r--r--  candle-pyo3/py_src/candle/typing/__init__.py  |  16
-rw-r--r--  candle-pyo3/py_src/candle/utils/__init__.py   |  11
-rw-r--r--  candle-pyo3/py_src/candle/utils/__init__.pyi  |  63
-rw-r--r--  candle-pyo3/pyproject.toml                    |  30
-rw-r--r--  candle-pyo3/quant-llama.py                    |   7
-rw-r--r--  candle-pyo3/src/lib.rs                        |  89
-rw-r--r--  candle-pyo3/stub.py                           | 217
-rw-r--r--  candle-pyo3/test.py                           |   9
15 files changed, 857 insertions(+), 40 deletions(-)
diff --git a/candle-pyo3/.gitignore b/candle-pyo3/.gitignore
new file mode 100644
index 00000000..68bc17f9
--- /dev/null
+++ b/candle-pyo3/.gitignore
@@ -0,0 +1,160 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml
index 1d431cfb..c96681bd 100644
--- a/candle-pyo3/Cargo.toml
+++ b/candle-pyo3/Cargo.toml
@@ -12,7 +12,6 @@ readme = "README.md"
[lib]
name = "candle"
crate-type = ["cdylib"]
-doc = false
[dependencies]
candle = { path = "../candle-core", version = "0.2.2", package = "candle-core" }
diff --git a/candle-pyo3/README.md b/candle-pyo3/README.md
index 07dff468..be6d4f68 100644
--- a/candle-pyo3/README.md
+++ b/candle-pyo3/README.md
@@ -1,7 +1,26 @@
+## Installation
+
From the `candle-pyo3` directory, activate a virtual env in which you want the
candle package to be installed, then run:
```bash
-maturin develop
+maturin develop -r
python test.py
```
+
+## Generating Stub Files for Type Hinting
+
+For type hinting support, the `candle-pyo3` package requires `*.pyi` files. You can automatically generate these files using the `stub.py` script.
+
+### Steps:
+1. Install the package using `maturin`.
+2. Generate the stub files by running:
+ ```
+ python stub.py
+ ```
+
+### Validation:
+To ensure that the stub files match the current implementation, execute:
+```
+python stub.py --check
+```
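
With the package built and the stubs generated, editors can resolve the native API. A minimal sketch, assuming `maturin develop -r` was run in the active venv (the ops mirror `test.py` from this same change):

```python
import candle
from candle import Tensor

# The generated *.pyi files give these calls completion and type checking.
t: Tensor = candle.Tensor([3.0, 1, 4, 1, 5, 9, 2, 6]).reshape([2, 4])
print(t.shape)          # exposed as a property in the stubs
print(t.matmul(t.t()))  # methods documented in candle/__init__.pyi
```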
diff --git a/candle-pyo3/py_src/candle/__init__.py b/candle-pyo3/py_src/candle/__init__.py
new file mode 100644
index 00000000..49c96122
--- /dev/null
+++ b/candle-pyo3/py_src/candle/__init__.py
@@ -0,0 +1 @@
+from .candle import * \ No newline at end of file
diff --git a/candle-pyo3/py_src/candle/__init__.pyi b/candle-pyo3/py_src/candle/__init__.pyi
new file mode 100644
index 00000000..c21e6738
--- /dev/null
+++ b/candle-pyo3/py_src/candle/__init__.pyi
@@ -0,0 +1,248 @@
+# Generated content DO NOT EDIT
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
+from os import PathLike
+from candle.typing import _ArrayLike, Device
+
+class bf16(DType):
+ pass
+
+@staticmethod
+def cat(tensors: List[Tensor], dim: int):
+ """
+ Concatenate the tensors across one axis.
+ """
+ pass
+
+class f16(DType):
+ pass
+
+class f32(DType):
+ pass
+
+class f64(DType):
+ pass
+
+class i64(DType):
+ pass
+
+@staticmethod
+def ones(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None):
+ """ """
+ pass
+
+@staticmethod
+def rand(shape: Sequence[int], device: Optional[Device] = None):
+ """
+ Creates a new tensor with random values.
+ """
+ pass
+
+@staticmethod
+def randn(shape: Sequence[int], device: Optional[Device] = None):
+ """ """
+ pass
+
+@staticmethod
+def stack(tensors: List[Tensor], dim: int):
+ """
+ Stack the tensors along a new axis.
+ """
+ pass
+
+@staticmethod
+def tensor(data: _ArrayLike):
+ """
+ Creates a new tensor from a Python value. The value can be a scalar or array-like object.
+ """
+ pass
+
+class u32(DType):
+ pass
+
+class u8(DType):
+ pass
+
+@staticmethod
+def zeros(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None):
+ """ """
+ pass
+
+class DType:
+ pass
+
+class QTensor:
+ def dequantize(self):
+ """ """
+ pass
+ @property
+ def ggml_dtype(self):
+ """ """
+ pass
+ def matmul_t(self, lhs):
+ """ """
+ pass
+ @property
+ def rank(self):
+ """ """
+ pass
+ @property
+ def shape(self):
+ """ """
+ pass
+
+class Tensor:
+ def __init__(data: _ArrayLike):
+ pass
+ def argmax_keepdim(self, dim):
+ """ """
+ pass
+ def argmin_keepdim(self, dim):
+ """ """
+ pass
+ def broadcast_add(self, rhs):
+ """ """
+ pass
+ def broadcast_as(self, shape):
+ """ """
+ pass
+ def broadcast_div(self, rhs):
+ """ """
+ pass
+ def broadcast_left(self, shape):
+ """ """
+ pass
+ def broadcast_mul(self, rhs):
+ """ """
+ pass
+ def broadcast_sub(self, rhs):
+ """ """
+ pass
+ def contiguous(self):
+ """ """
+ pass
+ def copy(self):
+ """ """
+ pass
+ def cos(self):
+ """ """
+ pass
+ def detach(self):
+ """ """
+ pass
+ @property
+ def device(self):
+ """ """
+ pass
+ @property
+ def dtype(self):
+ """ """
+ pass
+ def exp(self):
+ """ """
+ pass
+ def flatten_all(self):
+ """ """
+ pass
+ def flatten_from(self, dim):
+ """ """
+ pass
+ def flatten_to(self, dim):
+ """ """
+ pass
+ def get(self, index):
+ """ """
+ pass
+ def index_select(self, rhs, dim):
+ """ """
+ pass
+ def is_contiguous(self):
+ """ """
+ pass
+ def is_fortran_contiguous(self):
+ """ """
+ pass
+ def log(self):
+ """ """
+ pass
+ def matmul(self, rhs):
+ """ """
+ pass
+ def max_keepdim(self, dim):
+ """ """
+ pass
+ def mean_all(self):
+ """ """
+ pass
+ def min_keepdim(self, dim):
+ """ """
+ pass
+ def narrow(self, dim, start, len):
+ """ """
+ pass
+ def powf(self, p):
+ """ """
+ pass
+ def quantize(self, quantized_dtype):
+ """ """
+ pass
+ @property
+ def rank(self):
+ """ """
+ pass
+ def recip(self):
+ """ """
+ pass
+ def reshape(self, shape):
+ """ """
+ pass
+ @property
+ def shape(self):
+ """
+ Gets the tensor shape as a Python tuple.
+ """
+ pass
+ def sin(self):
+ """ """
+ pass
+ def sqr(self):
+ """ """
+ pass
+ def sqrt(self):
+ """ """
+ pass
+ def squeeze(self, dim):
+ """ """
+ pass
+ @property
+ def stride(self):
+ """ """
+ pass
+ def sum_all(self):
+ """ """
+ pass
+ def sum_keepdim(self, dims):
+ """ """
+ pass
+ def t(self):
+ """ """
+ pass
+ def to_device(self, device):
+ """ """
+ pass
+ def to_dtype(self, dtype):
+ """ """
+ pass
+ def transpose(self, dim1, dim2):
+ """ """
+ pass
+ def unsqueeze(self, dim):
+ """ """
+ pass
+ def values(self):
+ """
+ Gets the tensor's data as a Python scalar or array-like object.
+ """
+ pass
+ def where_cond(self, on_true, on_false):
+ """ """
+ pass
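
A hedged sketch exercising the `Tensor`/`QTensor` surface described by this stub; the `"q6k"` format string is the one used in `test.py` later in this change:

```python
import candle

t = candle.randn((16, 256))
q = t.quantize("q6k")            # returns a QTensor
d = q.dequantize()               # back to a regular Tensor
print(q.ggml_dtype, q.shape)
print((t - d).sqr().mean_all())  # rough quantization error
```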
diff --git a/candle-pyo3/py_src/candle/nn/__init__.py b/candle-pyo3/py_src/candle/nn/__init__.py
new file mode 100644
index 00000000..b8c5cfb7
--- /dev/null
+++ b/candle-pyo3/py_src/candle/nn/__init__.py
@@ -0,0 +1,5 @@
+# Generated content DO NOT EDIT
+from .. import nn
+
+silu = nn.silu
+softmax = nn.softmax
diff --git a/candle-pyo3/py_src/candle/nn/__init__.pyi b/candle-pyo3/py_src/candle/nn/__init__.pyi
new file mode 100644
index 00000000..821cd052
--- /dev/null
+++ b/candle-pyo3/py_src/candle/nn/__init__.pyi
@@ -0,0 +1,19 @@
+# Generated content DO NOT EDIT
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
+from os import PathLike
+from candle.typing import _ArrayLike, Device
+from candle import Tensor, DType
+
+@staticmethod
+def silu(tensor: Tensor):
+ """
+ Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
+ """
+ pass
+
+@staticmethod
+def softmax(tensor: Tensor, dim: int):
+ """
+ Applies the Softmax function to a given tensor.
+ """
+ pass
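
A hedged sketch of the two ops re-exported here, assuming the `candle.nn` package layout added in this change:

```python
import candle
from candle.nn import silu, softmax

t = candle.randn((2, 8))
probs = softmax(t, -1)  # follows the (tensor, dim) signature above
print(probs.values())
print(silu(t).shape)
```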
diff --git a/candle-pyo3/py_src/candle/typing/__init__.py b/candle-pyo3/py_src/candle/typing/__init__.py
new file mode 100644
index 00000000..ea85d2a3
--- /dev/null
+++ b/candle-pyo3/py_src/candle/typing/__init__.py
@@ -0,0 +1,16 @@
+from typing import TypeVar, Union, Sequence
+
+_T = TypeVar("_T")
+
+_ArrayLike = Union[
+ _T,
+ Sequence[_T],
+ Sequence[Sequence[_T]],
+ Sequence[Sequence[Sequence[_T]]],
+ Sequence[Sequence[Sequence[Sequence[_T]]]],
+]
+
+CPU:str = "cpu"
+CUDA:str = "cuda"
+
+Device = TypeVar("Device", CPU, CUDA) \ No newline at end of file
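
A hedged sketch of how these aliases are meant to be used in annotations, mirroring the generated stubs; the helper itself is illustrative and not part of the package:

```python
import candle
from candle.typing import _ArrayLike, Device, CPU

def describe(data: _ArrayLike, device: Device = CPU) -> str:
    """Illustrative helper: accepts scalars or nested sequences."""
    t = candle.tensor(data)
    return f"tensor of shape {t.shape} targeted at {device}"

print(describe([[1.0, 2.0], [3.0, 4.0]]))
```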
diff --git a/candle-pyo3/py_src/candle/utils/__init__.py b/candle-pyo3/py_src/candle/utils/__init__.py
new file mode 100644
index 00000000..2ead6d84
--- /dev/null
+++ b/candle-pyo3/py_src/candle/utils/__init__.py
@@ -0,0 +1,11 @@
+# Generated content DO NOT EDIT
+from .. import utils
+
+cuda_is_available = utils.cuda_is_available
+get_num_threads = utils.get_num_threads
+has_accelerate = utils.has_accelerate
+has_mkl = utils.has_mkl
+load_ggml = utils.load_ggml
+load_gguf = utils.load_gguf
+load_safetensors = utils.load_safetensors
+save_safetensors = utils.save_safetensors
diff --git a/candle-pyo3/py_src/candle/utils/__init__.pyi b/candle-pyo3/py_src/candle/utils/__init__.pyi
new file mode 100644
index 00000000..7a0a5231
--- /dev/null
+++ b/candle-pyo3/py_src/candle/utils/__init__.pyi
@@ -0,0 +1,63 @@
+# Generated content DO NOT EDIT
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
+from os import PathLike
+from candle.typing import _ArrayLike, Device
+from candle import Tensor, DType
+
+@staticmethod
+def cuda_is_available():
+ """
+ Returns true if the 'cuda' backend is available.
+ """
+ pass
+
+@staticmethod
+def get_num_threads():
+ """
+ Returns the number of threads used by candle.
+ """
+ pass
+
+@staticmethod
+def has_accelerate():
+ """
+ Returns true if candle was compiled with 'accelerate' support.
+ """
+ pass
+
+@staticmethod
+def has_mkl():
+ """
+ Returns true if candle was compiled with MKL support.
+ """
+ pass
+
+@staticmethod
+def load_ggml(path: Union[str, PathLike]):
+ """
+ Loads a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
+ a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
+ """
+ pass
+
+@staticmethod
+def load_gguf(path: Union[str, PathLike]):
+ """
+ Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
+ and the second maps metadata keys to metadata values.
+ """
+ pass
+
+@staticmethod
+def load_safetensors(path: Union[str, PathLike]):
+ """
+ Loads a safetensors file. Returns a dictionary mapping tensor names to tensors.
+ """
+ pass
+
+@staticmethod
+def save_safetensors(path: Union[str, PathLike], tensors: Dict[str, Tensor]):
+ """
+ Saves a dictionary of tensors to a safetensors file.
+ """
+ pass
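
A hedged round-trip sketch for the safetensors helpers above; the file name is illustrative:

```python
import candle
from candle.utils import save_safetensors, load_safetensors

t = candle.randn((4, 4))
save_safetensors("example.safetensors", {"weight": t})
loaded = load_safetensors("example.safetensors")  # dict of name -> Tensor
print(loaded["weight"].shape)
```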
diff --git a/candle-pyo3/pyproject.toml b/candle-pyo3/pyproject.toml
new file mode 100644
index 00000000..b4e372d7
--- /dev/null
+++ b/candle-pyo3/pyproject.toml
@@ -0,0 +1,30 @@
+[project]
+name = 'candle-pyo3'
+requires-python = '>=3.7'
+authors = [
+ {name = 'Laurent Mazare', email = ''},
+]
+
+dynamic = [
+ 'description',
+ 'license',
+ 'readme',
+]
+
+[project.urls]
+Homepage = 'https://github.com/huggingface/candle'
+Source = 'https://github.com/huggingface/candle'
+
+[build-system]
+requires = ["maturin>=1.0,<2.0"]
+build-backend = "maturin"
+
+[tool.maturin]
+python-source = "py_src"
+module-name = "candle.candle"
+bindings = 'pyo3'
+features = ["pyo3/extension-module"]
+
+[tool.black]
+line-length = 119
+target-version = ['py35'] \ No newline at end of file
diff --git a/candle-pyo3/quant-llama.py b/candle-pyo3/quant-llama.py
index 0f7a51c6..020d525d 100644
--- a/candle-pyo3/quant-llama.py
+++ b/candle-pyo3/quant-llama.py
@@ -1,6 +1,7 @@
# This example shows how the candle Python api can be used to replicate llama.cpp.
import sys
import candle
+from candle.utils import load_ggml,load_gguf
MAX_SEQ_LEN = 4096
@@ -154,7 +155,7 @@ def main():
filename = sys.argv[1]
print(f"reading model file {filename}")
if filename.endswith("gguf"):
- all_tensors, metadata = candle.load_gguf(sys.argv[1])
+ all_tensors, metadata = load_gguf(sys.argv[1])
vocab = metadata["tokenizer.ggml.tokens"]
for i, v in enumerate(vocab):
vocab[i] = '\n' if v == '<0x0A>' else v.replace('▁', ' ')
@@ -168,13 +169,13 @@ def main():
'n_head_kv': metadata['llama.attention.head_count_kv'],
'n_layer': metadata['llama.block_count'],
'n_rot': metadata['llama.rope.dimension_count'],
- 'rope_freq': metadata['llama.rope.freq_base'],
+ 'rope_freq': metadata.get('llama.rope.freq_base', 10000.),
'ftype': metadata['general.file_type'],
}
all_tensors = { gguf_rename(k): v for k, v in all_tensors.items() }
else:
- all_tensors, hparams, vocab = candle.load_ggml(sys.argv[1])
+ all_tensors, hparams, vocab = load_ggml(sys.argv[1])
print(hparams)
model = QuantizedLlama(hparams, all_tensors)
print("model built, starting inference")
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index eddc0fda..1df78ec6 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -197,38 +197,40 @@ trait MapDType {
#[pymethods]
impl PyTensor {
#[new]
+ #[pyo3(text_signature = "(data:_ArrayLike)")]
// TODO: Handle arbitrary input dtype and shape.
- fn new(py: Python<'_>, vs: PyObject) -> PyResult<Self> {
+ /// Creates a new tensor from a Python value. The value can be a scalar or array-like object.
+ fn new(py: Python<'_>, data: PyObject) -> PyResult<Self> {
use Device::Cpu;
- let tensor = if let Ok(vs) = vs.extract::<u32>(py) {
+ let tensor = if let Ok(vs) = data.extract::<u32>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
- } else if let Ok(vs) = vs.extract::<i64>(py) {
+ } else if let Ok(vs) = data.extract::<i64>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
- } else if let Ok(vs) = vs.extract::<f32>(py) {
+ } else if let Ok(vs) = data.extract::<f32>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
- } else if let Ok(vs) = vs.extract::<Vec<u32>>(py) {
+ } else if let Ok(vs) = data.extract::<Vec<u32>>(py) {
let len = vs.len();
Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
- } else if let Ok(vs) = vs.extract::<Vec<i64>>(py) {
+ } else if let Ok(vs) = data.extract::<Vec<i64>>(py) {
let len = vs.len();
Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
- } else if let Ok(vs) = vs.extract::<Vec<f32>>(py) {
+ } else if let Ok(vs) = data.extract::<Vec<f32>>(py) {
let len = vs.len();
Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
- } else if let Ok(vs) = vs.extract::<Vec<Vec<u32>>>(py) {
+ } else if let Ok(vs) = data.extract::<Vec<Vec<u32>>>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
- } else if let Ok(vs) = vs.extract::<Vec<Vec<i64>>>(py) {
+ } else if let Ok(vs) = data.extract::<Vec<Vec<i64>>>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
- } else if let Ok(vs) = vs.extract::<Vec<Vec<f32>>>(py) {
+ } else if let Ok(vs) = data.extract::<Vec<Vec<f32>>>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
- } else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<u32>>>>(py) {
+ } else if let Ok(vs) = data.extract::<Vec<Vec<Vec<u32>>>>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
- } else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<i64>>>>(py) {
+ } else if let Ok(vs) = data.extract::<Vec<Vec<Vec<i64>>>>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
- } else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<f32>>>>(py) {
+ } else if let Ok(vs) = data.extract::<Vec<Vec<Vec<f32>>>>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
} else {
- let ty = vs.as_ref(py).get_type();
+ let ty = data.as_ref(py).get_type();
Err(PyTypeError::new_err(format!(
"incorrect type {ty} for tensor"
)))?
@@ -236,7 +238,7 @@ impl PyTensor {
Ok(Self(tensor))
}
- /// Gets the tensor data as a Python value/array/array of array/...
+ /// Gets the tensor's data as a Python scalar or array-like object.
fn values(&self, py: Python<'_>) -> PyResult<PyObject> {
struct M<'a>(Python<'a>);
impl<'a> MapDType for M<'a> {
@@ -280,6 +282,7 @@ impl PyTensor {
}
#[getter]
+ /// Gets the tensor shape as a Python tuple.
fn shape(&self, py: Python<'_>) -> PyObject {
PyTuple::new(py, self.0.dims()).to_object(py)
}
@@ -580,8 +583,9 @@ impl PyTensor {
}
}
-/// Concatenate the tensors across one axis.
#[pyfunction]
+#[pyo3(text_signature = "(tensors:List[Tensor], dim:int )")]
+/// Concatenate the tensors across one axis.
fn cat(tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
if tensors.is_empty() {
return Err(PyErr::new::<PyValueError, _>("empty input to cat"));
@@ -593,6 +597,8 @@ fn cat(tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
}
#[pyfunction]
+#[pyo3(text_signature = "(tensors:List[Tensor], dim:int)")]
+/// Stack the tensors along a new axis.
fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
let tensor = Tensor::stack(&tensors, dim).map_err(wrap_err)?;
@@ -600,12 +606,15 @@ fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
}
#[pyfunction]
-fn tensor(py: Python<'_>, vs: PyObject) -> PyResult<PyTensor> {
- PyTensor::new(py, vs)
+#[pyo3(text_signature = "(data:_ArrayLike)")]
+/// Creates a new tensor from a Python value. The value can be a scalar or array-like object.
+fn tensor(py: Python<'_>, data: PyObject) -> PyResult<PyTensor> {
+ PyTensor::new(py, data)
}
#[pyfunction]
-#[pyo3(signature = (shape, *, device=None))]
+#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")]
+/// Creates a new tensor with random values.
fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
let tensor = Tensor::rand(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
@@ -613,7 +622,7 @@ fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<P
}
#[pyfunction]
-#[pyo3(signature = (shape, *, device=None))]
+#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")]
fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
let tensor = Tensor::randn(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
@@ -621,7 +630,7 @@ fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<
}
#[pyfunction]
-#[pyo3(signature = (shape, *, dtype=None, device=None))]
+#[pyo3(signature = (shape, *, dtype=None, device=None),text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")]
fn ones(
py: Python<'_>,
shape: PyShape,
@@ -638,7 +647,7 @@ fn ones(
}
#[pyfunction]
-#[pyo3(signature = (shape, *, dtype=None, device=None))]
+#[pyo3(signature = (shape, *, dtype=None, device=None), text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")]
fn zeros(
py: Python<'_>,
shape: PyShape,
@@ -704,6 +713,8 @@ impl PyQTensor {
}
#[pyfunction]
+#[pyo3(text_signature = "(path:Union[str,PathLike])")]
+/// Loads a safetensors file. Returns a dictionary mapping tensor names to tensors.
fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
let res = ::candle::safetensors::load(path, &Device::Cpu).map_err(wrap_err)?;
let res = res
@@ -714,6 +725,8 @@ fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
}
#[pyfunction]
+#[pyo3(text_signature = "(path:Union[str,PathLike], tensors:Dict[str,Tensor])")]
+/// Saves a dictionary of tensors to a safetensors file.
fn save_safetensors(
path: &str,
tensors: std::collections::HashMap<String, PyTensor>,
@@ -726,6 +739,9 @@ fn save_safetensors(
}
#[pyfunction]
+#[pyo3(text_signature = "(path:Union[str,PathLike])")]
+/// Loads a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
+/// a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObject)> {
let mut file = std::fs::File::open(path)?;
let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?;
@@ -757,6 +773,9 @@ fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObje
}
#[pyfunction]
+#[pyo3(text_signature = "(path:Union[str,PathLike])")]
+/// Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
+/// and the second maps metadata keys to metadata values.
fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
use ::candle::quantized::gguf_file;
fn gguf_value_to_pyobject(v: &gguf_file::Value, py: Python<'_>) -> PyResult<PyObject> {
@@ -806,21 +825,25 @@ fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
}
#[pyfunction]
+/// Returns true if the 'cuda' backend is available.
fn cuda_is_available() -> bool {
::candle::utils::cuda_is_available()
}
#[pyfunction]
+/// Returns true if candle was compiled with 'accelerate' support.
fn has_accelerate() -> bool {
::candle::utils::has_accelerate()
}
#[pyfunction]
+/// Returns true if candle was compiled with MKL support.
fn has_mkl() -> bool {
::candle::utils::has_mkl()
}
#[pyfunction]
+/// Returns the number of threads used by candle.
fn get_num_threads() -> usize {
::candle::utils::get_num_threads()
}
@@ -830,19 +853,27 @@ fn candle_utils(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(get_num_threads, m)?)?;
m.add_function(wrap_pyfunction!(has_accelerate, m)?)?;
m.add_function(wrap_pyfunction!(has_mkl, m)?)?;
+ m.add_function(wrap_pyfunction!(load_ggml, m)?)?;
+ m.add_function(wrap_pyfunction!(load_gguf, m)?)?;
+ m.add_function(wrap_pyfunction!(load_safetensors, m)?)?;
+ m.add_function(wrap_pyfunction!(save_safetensors, m)?)?;
Ok(())
}
#[pyfunction]
-fn softmax(t: PyTensor, dim: i64) -> PyResult<PyTensor> {
- let dim = actual_dim(&t, dim).map_err(wrap_err)?;
- let sm = candle_nn::ops::softmax(&t.0, dim).map_err(wrap_err)?;
+#[pyo3(text_signature = "(tensor:Tensor, dim:int)")]
+/// Applies the Softmax function to a given tensor.
+fn softmax(tensor: PyTensor, dim: i64) -> PyResult<PyTensor> {
+ let dim = actual_dim(&tensor, dim).map_err(wrap_err)?;
+ let sm = candle_nn::ops::softmax(&tensor.0, dim).map_err(wrap_err)?;
Ok(PyTensor(sm))
}
#[pyfunction]
-fn silu(t: PyTensor) -> PyResult<PyTensor> {
- let s = candle_nn::ops::silu(&t.0).map_err(wrap_err)?;
+#[pyo3(text_signature = "(tensor:Tensor)")]
+/// Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
+fn silu(tensor: PyTensor) -> PyResult<PyTensor> {
+ let s = candle_nn::ops::silu(&tensor.0).map_err(wrap_err)?;
Ok(PyTensor(s))
}
@@ -871,14 +902,10 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add("f32", PyDType(DType::F32))?;
m.add("f64", PyDType(DType::F64))?;
m.add_function(wrap_pyfunction!(cat, m)?)?;
- m.add_function(wrap_pyfunction!(load_ggml, m)?)?;
- m.add_function(wrap_pyfunction!(load_gguf, m)?)?;
- m.add_function(wrap_pyfunction!(load_safetensors, m)?)?;
m.add_function(wrap_pyfunction!(ones, m)?)?;
m.add_function(wrap_pyfunction!(rand, m)?)?;
m.add_function(wrap_pyfunction!(randn, m)?)?;
m.add_function(wrap_pyfunction!(tensor, m)?)?;
- m.add_function(wrap_pyfunction!(save_safetensors, m)?)?;
m.add_function(wrap_pyfunction!(stack, m)?)?;
m.add_function(wrap_pyfunction!(zeros, m)?)?;
Ok(())
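
The `text_signature` attributes and `///` doc comments added above are what `stub.py` (next file) reads from the built extension. A hedged sketch of how they surface at runtime:

```python
import candle

# stub.py walks the module and reads exactly these attributes per function.
print(candle.tensor.__text_signature__)  # e.g. "(data:_ArrayLike)"
print(candle.tensor.__doc__)             # the doc comment attached in lib.rs
```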
diff --git a/candle-pyo3/stub.py b/candle-pyo3/stub.py
new file mode 100644
index 00000000..b5b9256f
--- /dev/null
+++ b/candle-pyo3/stub.py
@@ -0,0 +1,217 @@
+#See: https://raw.githubusercontent.com/huggingface/tokenizers/main/bindings/python/stub.py
+import argparse
+import inspect
+import os
+from typing import Optional
+import black
+from pathlib import Path
+
+
+INDENT = " " * 4
+GENERATED_COMMENT = "# Generated content DO NOT EDIT\n"
+TYPING = """from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
+from os import PathLike
+"""
+CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device\n"
+CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType\n"
+
+
+
+def do_indent(text: Optional[str], indent: str):
+ if text is None:
+ return ""
+ return text.replace("\n", f"\n{indent}")
+
+
+def function(obj, indent:str, text_signature:str=None):
+ if text_signature is None:
+ text_signature = obj.__text_signature__
+
+ text_signature = text_signature.replace("$self", "self").lstrip().rstrip()
+ string = ""
+ string += f"{indent}def {obj.__name__}{text_signature}:\n"
+ indent += INDENT
+ string += f'{indent}"""\n'
+ string += f"{indent}{do_indent(obj.__doc__, indent)}\n"
+ string += f'{indent}"""\n'
+ string += f"{indent}pass\n"
+ string += "\n"
+ string += "\n"
+ return string
+
+
+def member_sort(member):
+ if inspect.isclass(member):
+ value = 10 + len(inspect.getmro(member))
+ else:
+ value = 1
+ return value
+
+
+def fn_predicate(obj):
+ value = inspect.ismethoddescriptor(obj) or inspect.isbuiltin(obj)
+ if value:
+ return obj.__text_signature__ and not obj.__name__.startswith("_")
+ if inspect.isgetsetdescriptor(obj):
+ return not obj.__name__.startswith("_")
+ return False
+
+
+def get_module_members(module):
+ members = [
+ member
+ for name, member in inspect.getmembers(module)
+ if not name.startswith("_") and not inspect.ismodule(member)
+ ]
+ members.sort(key=member_sort)
+ return members
+
+
+def pyi_file(obj, indent=""):
+ string = ""
+ if inspect.ismodule(obj):
+ string += GENERATED_COMMENT
+ string += TYPING
+ string += CANDLE_SPECIFIC_TYPING
+ if obj.__name__ != "candle.candle":
+ string += CANDLE_TENSOR_IMPORTS
+ members = get_module_members(obj)
+ for member in members:
+ string += pyi_file(member, indent)
+
+ elif inspect.isclass(obj):
+ indent += INDENT
+ mro = inspect.getmro(obj)
+ if len(mro) > 2:
+ inherit = f"({mro[1].__name__})"
+ else:
+ inherit = ""
+ string += f"class {obj.__name__}{inherit}:\n"
+
+ body = ""
+ if obj.__doc__:
+ body += f'{indent}"""\n{indent}{do_indent(obj.__doc__, indent)}\n{indent}"""\n'
+
+ fns = inspect.getmembers(obj, fn_predicate)
+
+ # Init
+ if obj.__text_signature__:
+ body += f"{indent}def __init__{obj.__text_signature__}:\n"
+ body += f"{indent+INDENT}pass\n"
+ body += "\n"
+
+ for (name, fn) in fns:
+ body += pyi_file(fn, indent=indent)
+
+ if not body:
+ body += f"{indent}pass\n"
+
+ string += body
+ string += "\n\n"
+
+ elif inspect.isbuiltin(obj):
+ string += f"{indent}@staticmethod\n"
+ string += function(obj, indent)
+
+ elif inspect.ismethoddescriptor(obj):
+ string += function(obj, indent)
+
+ elif inspect.isgetsetdescriptor(obj):
+ # TODO: it would be interesting to also add the setter?
+ string += f"{indent}@property\n"
+ string += function(obj, indent, text_signature="(self)")
+
+ elif obj.__class__.__name__ == "DType":
+ string += f"class {str(obj).lower()}(DType):\n"
+ string += f"{indent+INDENT}pass\n"
+ else:
+ raise Exception(f"Object {obj} is not supported")
+ return string
+
+
+def py_file(module, origin):
+ members = get_module_members(module)
+
+ string = GENERATED_COMMENT
+ string += f"from .. import {origin}\n"
+ string += "\n"
+ for member in members:
+ if hasattr(member, "__name__"):
+ name = member.__name__
+ else:
+ name = str(member)
+ string += f"{name} = {origin}.{name}\n"
+ return string
+
+
+def do_black(content, is_pyi):
+ mode = black.Mode(
+ target_versions={black.TargetVersion.PY35},
+ line_length=119,
+ is_pyi=is_pyi,
+ string_normalization=True,
+ experimental_string_processing=False,
+ )
+ try:
+ return black.format_file_contents(content, fast=True, mode=mode)
+ except black.NothingChanged:
+ return content
+
+
+def write(module, directory, origin, check=False):
+ submodules = [(name, member) for name, member in inspect.getmembers(module) if inspect.ismodule(member)]
+
+ filename = os.path.join(directory, "__init__.pyi")
+ pyi_content = pyi_file(module)
+ pyi_content = do_black(pyi_content, is_pyi=True)
+ os.makedirs(directory, exist_ok=True)
+ if check:
+ with open(filename, "r") as f:
+ data = f.read()
+ assert data == pyi_content, f"The content of {filename} seems outdated, please run `python stub.py`"
+ else:
+ with open(filename, "w") as f:
+ f.write(pyi_content)
+
+ filename = os.path.join(directory, "__init__.py")
+ py_content = py_file(module, origin)
+ py_content = do_black(py_content, is_pyi=False)
+ os.makedirs(directory, exist_ok=True)
+
+ is_auto = False
+ if not os.path.exists(filename):
+ is_auto = True
+ else:
+ with open(filename, "r") as f:
+ line = f.readline()
+ if line == GENERATED_COMMENT:
+ is_auto = True
+
+ if is_auto:
+ if check:
+ with open(filename, "r") as f:
+ data = f.read()
+ assert data == py_content, f"The content of {filename} seems outdated, please run `python stub.py`"
+ else:
+ with open(filename, "w") as f:
+ f.write(py_content)
+
+ for name, submodule in submodules:
+ write(submodule, os.path.join(directory, name), f"{name}", check=check)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--check", action="store_true")
+
+ args = parser.parse_args()
+
+ #Enable execution from the candle and candle-pyo3 directories
+ cwd = Path.cwd()
+ directory = "py_src/candle/"
+ if cwd.name != "candle-pyo3":
+ directory = f"candle-pyo3/{directory}"
+
+ import candle
+
+ write(candle.candle, directory, "candle", check=args.check)
diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py
index 7f24b49d..c78ffc41 100644
--- a/candle-pyo3/test.py
+++ b/candle-pyo3/test.py
@@ -1,4 +1,5 @@
import candle
+from candle import Tensor, QTensor
t = candle.Tensor(42.0)
print(t)
@@ -9,7 +10,7 @@ t = candle.Tensor([3.0, 1, 4, 1, 5, 9, 2, 6])
print(t)
print(t+t)
-t = t.reshape([2, 4])
+t:Tensor = t.reshape([2, 4])
print(t.matmul(t.t()))
print(t.to_dtype(candle.u8))
@@ -20,7 +21,7 @@ print(t)
print(t.dtype)
t = candle.randn((16, 256))
-quant_t = t.quantize("q6k")
-dequant_t = quant_t.dequantize()
-diff2 = (t - dequant_t).sqr()
+quant_t:QTensor = t.quantize("q6k")
+dequant_t:Tensor = quant_t.dequantize()
+diff2:Tensor = (t - dequant_t).sqr()
print(diff2.mean_all())