 candle-pyo3/py_src/candle/__init__.py        |   6
 candle-pyo3/py_src/candle/__init__.pyi       | 361
 candle-pyo3/py_src/candle/nn/__init__.pyi    |   8
 candle-pyo3/py_src/candle/utils/__init__.py  |   1
 candle-pyo3/py_src/candle/utils/__init__.pyi |  25
 candle-pyo3/quant-llama.py                   |  31
 candle-pyo3/src/lib.rs                       | 270
 candle-pyo3/stub.py                          |  23
 candle-pyo3/test.py                          |   9
 9 files changed, 574 insertions, 160 deletions
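
The headline user-facing addition is the GGUF path: `Tensor.quantize`, `utils.save_gguf` and `utils.load_gguf` can now round-trip quantized tensors from Python. The snippet below is an illustrative sketch based only on the signatures added in this diff, not code taken from the patch itself; the file name "model.gguf" and the metadata key are made up.

    import candle
    from candle import utils

    t = candle.randn((16, 256))
    qt = t.quantize("q6k")  # Tensor -> QTensor, mirroring test.py

    # save_gguf takes a dict of quantized tensors and a dict of metadata (string keys)
    utils.save_gguf("model.gguf", {"weight": qt}, {"example.note": "hypothetical"})

    # load_gguf returns (tensors, metadata); dequantize() recovers a float Tensor
    tensors, metadata = utils.load_gguf("model.gguf")
    print(tensors["weight"].dequantize().shape, metadata["example.note"])
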
diff --git a/candle-pyo3/py_src/candle/__init__.py b/candle-pyo3/py_src/candle/__init__.py
index 49c96122..951609cc 100644
--- a/candle-pyo3/py_src/candle/__init__.py
+++ b/candle-pyo3/py_src/candle/__init__.py
@@ -1 +1,5 @@
-from .candle import *
\ No newline at end of file
+from .candle import *
+
+__doc__ = candle.__doc__
+if hasattr(candle, "__all__"):
+    __all__ = candle.__all__
\ No newline at end of file diff --git a/candle-pyo3/py_src/candle/__init__.pyi b/candle-pyo3/py_src/candle/__init__.pyi index c21e6738..414f0bc4 100644 --- a/candle-pyo3/py_src/candle/__init__.pyi +++ b/candle-pyo3/py_src/candle/__init__.pyi @@ -7,7 +7,7 @@ class bf16(DType): pass @staticmethod -def cat(tensors: List[Tensor], dim: int): +def cat(tensors: List[Tensor], dim: int) -> Tensor: """ Concatenate the tensors across one axis. """ @@ -26,31 +26,35 @@ class i64(DType): pass @staticmethod -def ones(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None): - """ """ +def ones(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor: + """ + Creates a new tensor filled with ones. + """ pass @staticmethod -def rand(shape: Sequence[int], device: Optional[Device] = None): +def rand(shape: Sequence[int], device: Optional[Device] = None) -> Tensor: """ Creates a new tensor with random values. """ pass @staticmethod -def randn(shape: Sequence[int], device: Optional[Device] = None): - """ """ +def randn(shape: Sequence[int], device: Optional[Device] = None) -> Tensor: + """ + Creates a new tensor with random values from a normal distribution. + """ pass @staticmethod -def stack(tensors: List[Tensor], dim: int): +def stack(tensors: List[Tensor], dim: int) -> Tensor: """ Stack the tensors along a new axis. """ pass @staticmethod -def tensor(data: _ArrayLike): +def tensor(data: _ArrayLike) -> Tensor: """ Creates a new tensor from a Python value. The value can be a scalar or array-like object. """ @@ -63,186 +67,309 @@ class u8(DType): pass @staticmethod -def zeros(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None): - """ """ +def zeros(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor: + """ + Creates a new tensor filled with zeros. + """ pass class DType: - pass + """ + A `candle` dtype. + """ class QTensor: - def dequantize(self): - """ """ + """ + A quantized tensor. + """ + + def dequantize(self) -> Tensor: + """ + Dequantizes the tensor. + """ pass @property - def ggml_dtype(self): - """ """ + def ggml_dtype(self) -> str: + """ + Gets the tensors quantized dtype. + """ pass - def matmul_t(self, lhs): - """ """ + def matmul_t(self, lhs: Tensor) -> Tensor: + """ + Performs a quantized matrix multiplication, with the quantized tensor as the right hand side. + """ pass @property - def rank(self): - """ """ + def rank(self) -> int: + """ + Gets the rank of the tensor. + """ pass @property - def shape(self): - """ """ + def shape(self) -> Tuple[int]: + """ + Gets the shape of the tensor. + """ pass class Tensor: - def __init__(data: _ArrayLike): + """ + A `candle` tensor. + """ + + def __init__(self, data: _ArrayLike): pass - def argmax_keepdim(self, dim): - """ """ + def argmax_keepdim(self, dim: int) -> Tensor: + """ + Returns the indices of the maximum value(s) across the selected dimension. + """ pass - def argmin_keepdim(self, dim): - """ """ + def argmin_keepdim(self, dim: int) -> Tensor: + """ + Returns the indices of the minimum value(s) across the selected dimension. + """ pass - def broadcast_add(self, rhs): - """ """ + def broadcast_add(self, rhs: Tensor) -> Tensor: + """ + Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor. 
+ """ pass - def broadcast_as(self, shape): - """ """ + def broadcast_as(self, shape: Sequence[int]) -> Tensor: + """ + Broadcasts the tensor to the given shape. + """ pass - def broadcast_div(self, rhs): - """ """ + def broadcast_div(self, rhs: Tensor) -> Tensor: + """ + Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor. + """ pass - def broadcast_left(self, shape): - """ """ + def broadcast_left(self, shape: Sequence[int]) -> Tensor: + """ + Broadcasts the tensor to the given shape, adding new dimensions on the left. + """ pass - def broadcast_mul(self, rhs): - """ """ + def broadcast_mul(self, rhs: Tensor) -> Tensor: + """ + Multiplies the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor. + """ pass - def broadcast_sub(self, rhs): - """ """ + def broadcast_sub(self, rhs: Tensor) -> Tensor: + """ + Subtracts the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor. + """ pass - def contiguous(self): - """ """ + def contiguous(self) -> Tensor: + """ + Makes the tensor contiguous in memory. + """ pass - def copy(self): - """ """ + def copy(self) -> Tensor: + """ + Returns a copy of the tensor. + """ pass - def cos(self): - """ """ + def cos(self) -> Tensor: + """ + Performs the `cos` operation on the tensor. + """ pass - def detach(self): - """ """ + def detach(self) -> Tensor: + """ + Detach the tensor from the computation graph. + """ pass @property - def device(self): - """ """ + def device(self) -> Device: + """ + Gets the tensor's device. + """ pass @property - def dtype(self): - """ """ + def dtype(self) -> DType: + """ + Gets the tensor's dtype. + """ pass - def exp(self): - """ """ + def exp(self) -> Tensor: + """ + Performs the `exp` operation on the tensor. + """ pass - def flatten_all(self): - """ """ + def flatten_all(self) -> Tensor: + """ + Flattens the tensor into a 1D tensor. + """ pass - def flatten_from(self, dim): - """ """ + def flatten_from(self, dim: int) -> Tensor: + """ + Flattens the tensor on the dimension indexes from `dim` (inclusive) to the last dimension. + """ pass - def flatten_to(self, dim): - """ """ + def flatten_to(self, dim: int) -> Tensor: + """ + Flattens the tensor on the dimension indexes from `0` to `dim` (inclusive). + """ pass - def get(self, index): - """ """ + def get(self, index: int) -> Tensor: + """ + Gets the value at the specified index. + """ pass - def index_select(self, rhs, dim): - """ """ + def index_select(self, rhs: Tensor, dim: int) -> Tensor: + """ + Select values for the input tensor at the target indexes across the specified dimension. + + The `indexes` is argument is an int tensor with a single dimension. + The output has the same number of dimension as the `self` input. The target dimension of + the output has length the length of `indexes` and the values are taken from `self` using + the index from `indexes`. Other dimensions have the same number of elements as the input + tensor. + """ pass - def is_contiguous(self): - """ """ + def is_contiguous(self) -> bool: + """ + Returns true if the tensor is contiguous in C order. + """ pass - def is_fortran_contiguous(self): - """ """ + def is_fortran_contiguous(self) -> bool: + """ + Returns true if the tensor is contiguous in Fortran order. + """ pass - def log(self): - """ """ + def log(self) -> Tensor: + """ + Performs the `log` operation on the tensor. 
+ """ pass - def matmul(self, rhs): - """ """ + def matmul(self, rhs: Tensor) -> Tensor: + """ + Performs a matrix multiplication between the two tensors. + """ pass - def max_keepdim(self, dim): - """ """ + def max_keepdim(self, dim: int) -> Tensor: + """ + Gathers the maximum value across the selected dimension. + """ pass - def mean_all(self): - """ """ + def mean_all(self) -> Tensor: + """ + Returns the mean of the tensor. + """ pass - def min_keepdim(self, dim): - """ """ + def min_keepdim(self, dim: int) -> Tensor: + """ + Gathers the minimum value across the selected dimension. + """ pass - def narrow(self, dim, start, len): - """ """ + def narrow(self, dim: int, start: int, len: int) -> Tensor: + """ + Returns a new tensor that is a narrowed version of the input, the dimension `dim` + ranges from `start` to `start + len`. + """ pass - def powf(self, p): - """ """ + def powf(self, p: float) -> Tensor: + """ + Performs the `pow` operation on the tensor with the given exponent. + """ pass - def quantize(self, quantized_dtype): - """ """ + def quantize(self, quantized_dtype: str) -> QTensor: + """ + Quantize the tensor. + """ pass @property - def rank(self): - """ """ + def rank(self) -> int: + """ + Gets the tensor's rank. + """ pass - def recip(self): - """ """ + def recip(self) -> Tensor: + """ + Get the `recip` of the tensor. + """ pass - def reshape(self, shape): - """ """ + def reshape(self, shape: Sequence[int]) -> Tensor: + """ + Reshapes the tensor to the given shape. + """ pass @property - def shape(self): + def shape(self) -> Tuple[int]: """ - Gets the tensor shape as a Python tuple. + Gets the tensor's shape. """ pass - def sin(self): - """ """ + def sin(self) -> Tensor: + """ + Performs the `sin` operation on the tensor. + """ pass - def sqr(self): - """ """ + def sqr(self) -> Tensor: + """ + Squares the tensor. + """ pass - def sqrt(self): - """ """ + def sqrt(self) -> Tensor: + """ + Calculates the square root of the tensor. + """ pass - def squeeze(self, dim): - """ """ + def squeeze(self, dim: int) -> Tensor: + """ + Creates a new tensor with the specified dimension removed if its size was one. + """ pass @property - def stride(self): - """ """ + def stride(self) -> Tuple[int]: + """ + Gets the tensor's strides. + """ pass - def sum_all(self): - """ """ + def sum_all(self) -> Tensor: + """ + Returns the sum of the tensor. + """ pass - def sum_keepdim(self, dims): - """ """ + def sum_keepdim(self, dim: Union[int, List[int]]) -> Tensor: + """ + Returns the sum of all elements in the input tensor. The sum is performed over all the input dimensions. + """ pass - def t(self): - """ """ + def t(self) -> Tensor: + """ + Transposes the tensor. + """ pass - def to_device(self, device): - """ """ + def to_device(self, device: Union[str, Device]) -> Tensor: + """ + Move the tensor to a new device. + """ pass - def to_dtype(self, dtype): - """ """ + def to_dtype(self, dtype: Union[str, DType]) -> Tensor: + """ + Convert the tensor to a new dtype. + """ pass - def transpose(self, dim1, dim2): - """ """ + def transpose(self, dim1: int, dim2: int) -> Tensor: + """ + Returns a tensor that is a transposed version of the input, the given dimensions are swapped. + """ pass - def unsqueeze(self, dim): - """ """ + def unsqueeze(self, dim: int) -> Tensor: + """ + Creates a new tensor with a dimension of size one inserted at the specified position. + """ pass - def values(self): + def values(self) -> _ArrayLike: """ Gets the tensor's data as a Python scalar or array-like object. 
""" pass - def where_cond(self, on_true, on_false): - """ """ + def where_cond(self, on_true: Tensor, on_false: Tensor) -> Tensor: + """ + Returns a tensor with the same shape as the input tensor, the values are taken from + `on_true` if the input tensor value is not zero, and `on_false` at the positions where the + input tensor is equal to zero. + """ pass diff --git a/candle-pyo3/py_src/candle/nn/__init__.pyi b/candle-pyo3/py_src/candle/nn/__init__.pyi index 821cd052..01b30fce 100644 --- a/candle-pyo3/py_src/candle/nn/__init__.pyi +++ b/candle-pyo3/py_src/candle/nn/__init__.pyi @@ -2,18 +2,18 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence from os import PathLike from candle.typing import _ArrayLike, Device -from candle import Tensor, DType +from candle import Tensor, DType, QTensor @staticmethod -def silu(tensor: Tensor): +def silu(tensor: Tensor) -> Tensor: """ Applies the Sigmoid Linear Unit (SiLU) function to a given tensor. """ pass @staticmethod -def softmax(tensor: Tensor, dim: int): +def softmax(tensor: Tensor, dim: int) -> Tensor: """ - Applies the Softmax function to a given tensor. + Applies the Softmax function to a given tensor.# """ pass diff --git a/candle-pyo3/py_src/candle/utils/__init__.py b/candle-pyo3/py_src/candle/utils/__init__.py index 2ead6d84..62d85dc9 100644 --- a/candle-pyo3/py_src/candle/utils/__init__.py +++ b/candle-pyo3/py_src/candle/utils/__init__.py @@ -8,4 +8,5 @@ has_mkl = utils.has_mkl load_ggml = utils.load_ggml load_gguf = utils.load_gguf load_safetensors = utils.load_safetensors +save_gguf = utils.save_gguf save_safetensors = utils.save_safetensors diff --git a/candle-pyo3/py_src/candle/utils/__init__.pyi b/candle-pyo3/py_src/candle/utils/__init__.pyi index 7a0a5231..61964ffc 100644 --- a/candle-pyo3/py_src/candle/utils/__init__.pyi +++ b/candle-pyo3/py_src/candle/utils/__init__.pyi @@ -2,38 +2,38 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence from os import PathLike from candle.typing import _ArrayLike, Device -from candle import Tensor, DType +from candle import Tensor, DType, QTensor @staticmethod -def cuda_is_available(): +def cuda_is_available() -> bool: """ Returns true if the 'cuda' backend is available. """ pass @staticmethod -def get_num_threads(): +def get_num_threads() -> int: """ Returns the number of threads used by the candle. """ pass @staticmethod -def has_accelerate(): +def has_accelerate() -> bool: """ Returns true if candle was compiled with 'accelerate' support. """ pass @staticmethod -def has_mkl(): +def has_mkl() -> bool: """ Returns true if candle was compiled with MKL support. """ pass @staticmethod -def load_ggml(path: Union[str, PathLike]): +def load_ggml(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]: """ Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors, a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary. @@ -41,7 +41,7 @@ def load_ggml(path: Union[str, PathLike]): pass @staticmethod -def load_gguf(path: Union[str, PathLike]): +def load_gguf(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str, Any]]: """ Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors, and the second maps metadata keys to metadata values. 
@@ -49,14 +49,21 @@ def load_gguf(path: Union[str, PathLike]): pass @staticmethod -def load_safetensors(path: Union[str, PathLike]): +def load_safetensors(path: Union[str, PathLike]) -> Dict[str, Tensor]: """ Loads a safetensors file. Returns a dictionary mapping tensor names to tensors. """ pass @staticmethod -def save_safetensors(path: Union[str, PathLike], tensors: Dict[str, Tensor]): +def save_gguf(path: Union[str, PathLike], tensors: Dict[str, QTensor], metadata: Dict[str, Any]): + """ + Save quanitzed tensors and metadata to a GGUF file. + """ + pass + +@staticmethod +def save_safetensors(path: Union[str, PathLike], tensors: Dict[str, Tensor]) -> None: """ Saves a dictionary of tensors to a safetensors file. """ diff --git a/candle-pyo3/quant-llama.py b/candle-pyo3/quant-llama.py index 020d525d..46d9ff62 100644 --- a/candle-pyo3/quant-llama.py +++ b/candle-pyo3/quant-llama.py @@ -1,27 +1,28 @@ # This example shows how the candle Python api can be used to replicate llama.cpp. import sys +from typing import Dict, Tuple, Any import candle -from candle.utils import load_ggml,load_gguf +from candle import Tensor, QTensor, utils, nn MAX_SEQ_LEN = 4096 -def masked_fill(on_false, mask, on_true): +def masked_fill(on_false:Tensor, mask:Tensor, on_true:Tensor): shape = mask.shape on_true = candle.tensor(on_true).broadcast_as(shape) return mask.where_cond(on_true, on_false) class RmsNorm: - def __init__(self, qtensor): + def __init__(self, qtensor:QTensor): self.weight = qtensor.dequantize() - def __call__(self, x): + def __call__(self, x:Tensor): b_size, seq_len, hidden_size = x.shape norm_x = x.sqr().sum_keepdim(2) / hidden_size x_normed = x.broadcast_div((norm_x + 1e-5).sqrt()) return x_normed.broadcast_mul(self.weight) class QuantizedLayer: - def __init__(self, layer_idx, hparams, all_tensors, cos_sin): + def __init__(self, layer_idx:int, hparams:Dict[str,Any], all_tensors:Dict[str,QTensor], cos_sin:Tuple[Tensor,Tensor]): p = f"layers.{layer_idx}" self.attention_wq = all_tensors[f"{p}.attention.wq.weight"] self.attention_wk = all_tensors[f"{p}.attention.wk.weight"] @@ -41,7 +42,7 @@ class QuantizedLayer: self.cos = cos_sin[0] self.sin = cos_sin[1] - def __call__(self, x, mask, index_pos): + def __call__(self, x:Tensor, mask:Tensor, index_pos:int): residual = x x = self.attn_norm(x) attn = self.forward_attn(x, mask, index_pos) @@ -51,11 +52,11 @@ class QuantizedLayer: x = self.ffn_norm(x) w1 = self.ffw1.matmul_t(x) w3 = self.ffw3.matmul_t(x) - mlp = self.ffw2.matmul_t(candle.nn.silu(w1) * w3) + mlp = self.ffw2.matmul_t(nn.silu(w1) * w3) return mlp + residual - def forward_attn(self, x, mask, index_pos): + def forward_attn(self, x:Tensor, mask:Tensor, index_pos:int): b_size, seq_len, n_embd = x.shape q = self.attention_wq.matmul_t(x) k = self.attention_wk.matmul_t(x) @@ -80,12 +81,12 @@ class QuantizedLayer: att = q.matmul(k.t()) / self.head_dim**0.5 mask = mask.broadcast_as(att.shape) att = masked_fill(att, mask, float("-inf")) - att = candle.nn.softmax(att, -1) + att = nn.softmax(att, -1) y = att.matmul(v.contiguous()) y = y.transpose(1, 2).reshape((b_size, seq_len, n_embd)) return self.attention_wo.matmul_t(y) - def apply_rotary_emb(self, x, index_pos): + def apply_rotary_emb(self, x:Tensor, index_pos:int): (b_size, n_head, seq_len, n_embd) = x.shape cos = self.cos.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd//2, 1)) sin = self.sin.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd//2, 1)) @@ -107,7 +108,7 @@ def precompute_freqs_cis(hparams, freq_base): return (m.cos(), 
m.sin()) class QuantizedLlama: - def __init__(self, hparams, all_tensors): + def __init__(self, hparams:Dict[str,Any], all_tensors:Dict[str,QTensor]): self.tok_embeddings = all_tensors["tok_embeddings.weight"].dequantize() self.norm = RmsNorm(all_tensors["norm.weight"]) self.output = all_tensors["output.weight"] @@ -118,7 +119,7 @@ class QuantizedLlama: layer = QuantizedLayer(layer_idx, hparams, all_tensors, cos_sin) self.layers.append(layer) - def __call__(self, token, index_pos): + def __call__(self, token:Tensor, index_pos:int): b_size, seq_len = token.shape vocab_size, hidden_size = self.tok_embeddings.shape token = token.reshape((b_size * seq_len,)) @@ -135,7 +136,7 @@ class QuantizedLlama: x = self.output.matmul_t(x) return x -def gguf_rename(tensor_name): +def gguf_rename(tensor_name:str): if tensor_name == 'token_embd.weight': return 'tok_embeddings.weight' if tensor_name == 'output_norm.weight': return 'norm.weight' tensor_name = tensor_name.replace('blk.', 'layers.') @@ -155,7 +156,7 @@ def main(): filename = sys.argv[1] print(f"reading model file {filename}") if filename.endswith("gguf"): - all_tensors, metadata = load_gguf(sys.argv[1]) + all_tensors, metadata = utils.load_gguf(sys.argv[1]) vocab = metadata["tokenizer.ggml.tokens"] for i, v in enumerate(vocab): vocab[i] = '\n' if v == '<0x0A>' else v.replace('▁', ' ') @@ -175,7 +176,7 @@ def main(): all_tensors = { gguf_rename(k): v for k, v in all_tensors.items() } else: - all_tensors, hparams, vocab = load_ggml(sys.argv[1]) + all_tensors, hparams, vocab = utils.load_ggml(sys.argv[1]) print(hparams) model = QuantizedLlama(hparams, all_tensors) print("model built, starting inference") diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 1df78ec6..55b7a888 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -1,7 +1,7 @@ #![allow(clippy::redundant_closure_call)] use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; -use pyo3::types::{IntoPyDict, PyTuple}; +use pyo3::types::{IntoPyDict, PyDict, PyTuple}; use pyo3::ToPyObject; use std::sync::Arc; @@ -31,6 +31,7 @@ impl From<PyShape> for ::candle::Shape { #[derive(Clone, Debug)] #[pyclass(name = "Tensor")] +/// A `candle` tensor. struct PyTensor(Tensor); impl std::ops::Deref for PyTensor { @@ -43,6 +44,7 @@ impl std::ops::Deref for PyTensor { #[derive(Clone, Copy, Debug, PartialEq, Eq)] #[pyclass(name = "DType")] +/// A `candle` dtype. struct PyDType(DType); #[pymethods] @@ -197,7 +199,7 @@ trait MapDType { #[pymethods] impl PyTensor { #[new] - #[pyo3(text_signature = "(data:_ArrayLike)")] + #[pyo3(text_signature = "(self, data:_ArrayLike)")] // TODO: Handle arbitrary input dtype and shape. /// Creates a new tensor from a Python value. The value can be a scalar or array-like object. fn new(py: Python<'_>, data: PyObject) -> PyResult<Self> { @@ -239,6 +241,7 @@ impl PyTensor { } /// Gets the tensor's data as a Python scalar or array-like object. + /// &RETURNS&: _ArrayLike fn values(&self, py: Python<'_>) -> PyResult<PyObject> { struct M<'a>(Python<'a>); impl<'a> MapDType for M<'a> { @@ -282,27 +285,36 @@ impl PyTensor { } #[getter] - /// Gets the tensor shape as a Python tuple. + /// Gets the tensor's shape. + /// &RETURNS&: Tuple[int] fn shape(&self, py: Python<'_>) -> PyObject { PyTuple::new(py, self.0.dims()).to_object(py) } #[getter] + /// Gets the tensor's strides. + /// &RETURNS&: Tuple[int] fn stride(&self, py: Python<'_>) -> PyObject { PyTuple::new(py, self.0.stride()).to_object(py) } #[getter] + /// Gets the tensor's dtype. 
+ /// &RETURNS&: DType fn dtype(&self) -> PyDType { PyDType(self.0.dtype()) } #[getter] + /// Gets the tensor's device. + /// &RETURNS&: Device fn device(&self, py: Python<'_>) -> PyObject { PyDevice::from_device(self.0.device()).to_object(py) } #[getter] + /// Gets the tensor's rank. + /// &RETURNS&: int fn rank(&self) -> usize { self.0.rank() } @@ -315,69 +327,117 @@ impl PyTensor { self.__repr__() } + /// Performs the `sin` operation on the tensor. + /// &RETURNS&: Tensor fn sin(&self) -> PyResult<Self> { Ok(PyTensor(self.0.sin().map_err(wrap_err)?)) } + /// Performs the `cos` operation on the tensor. + /// &RETURNS&: Tensor fn cos(&self) -> PyResult<Self> { Ok(PyTensor(self.0.cos().map_err(wrap_err)?)) } + /// Performs the `log` operation on the tensor. + /// &RETURNS&: Tensor fn log(&self) -> PyResult<Self> { Ok(PyTensor(self.0.log().map_err(wrap_err)?)) } + /// Squares the tensor. + /// &RETURNS&: Tensor fn sqr(&self) -> PyResult<Self> { Ok(PyTensor(self.0.sqr().map_err(wrap_err)?)) } + /// Calculates the square root of the tensor. + /// &RETURNS&: Tensor fn sqrt(&self) -> PyResult<Self> { Ok(PyTensor(self.0.sqrt().map_err(wrap_err)?)) } + /// Get the `recip` of the tensor. + /// &RETURNS&: Tensor fn recip(&self) -> PyResult<Self> { Ok(PyTensor(self.0.recip().map_err(wrap_err)?)) } + /// Performs the `exp` operation on the tensor. + /// &RETURNS&: Tensor fn exp(&self) -> PyResult<Self> { Ok(PyTensor(self.0.exp().map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, p:float)")] + /// Performs the `pow` operation on the tensor with the given exponent. + /// &RETURNS&: Tensor fn powf(&self, p: f64) -> PyResult<Self> { Ok(PyTensor(self.0.powf(p).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, rhs:Tensor, dim:int)")] + /// Select values for the input tensor at the target indexes across the specified dimension. + /// + /// The `indexes` is argument is an int tensor with a single dimension. + /// The output has the same number of dimension as the `self` input. The target dimension of + /// the output has length the length of `indexes` and the values are taken from `self` using + /// the index from `indexes`. Other dimensions have the same number of elements as the input + /// tensor. + /// &RETURNS&: Tensor fn index_select(&self, rhs: &Self, dim: i64) -> PyResult<Self> { let dim = actual_dim(self, dim).map_err(wrap_err)?; Ok(PyTensor(self.0.index_select(rhs, dim).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, rhs:Tensor)")] + /// Performs a matrix multiplication between the two tensors. + /// &RETURNS&: Tensor fn matmul(&self, rhs: &Self) -> PyResult<Self> { Ok(PyTensor(self.0.matmul(rhs).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, rhs:Tensor)")] + /// Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor. + /// &RETURNS&: Tensor fn broadcast_add(&self, rhs: &Self) -> PyResult<Self> { Ok(PyTensor(self.0.broadcast_add(rhs).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, rhs:Tensor)")] + /// Subtracts the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor. + /// &RETURNS&: Tensor fn broadcast_sub(&self, rhs: &Self) -> PyResult<Self> { Ok(PyTensor(self.0.broadcast_sub(rhs).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, rhs:Tensor)")] + /// Multiplies the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor. 
+ /// &RETURNS&: Tensor fn broadcast_mul(&self, rhs: &Self) -> PyResult<Self> { Ok(PyTensor(self.0.broadcast_mul(rhs).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, rhs:Tensor)")] + /// Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor. + /// &RETURNS&: Tensor fn broadcast_div(&self, rhs: &Self) -> PyResult<Self> { Ok(PyTensor(self.0.broadcast_div(rhs).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, on_true:Tensor, on_false:Tensor)")] + /// Returns a tensor with the same shape as the input tensor, the values are taken from + /// `on_true` if the input tensor value is not zero, and `on_false` at the positions where the + /// input tensor is equal to zero. + /// &RETURNS&: Tensor fn where_cond(&self, on_true: &Self, on_false: &Self) -> PyResult<Self> { Ok(PyTensor( self.0.where_cond(on_true, on_false).map_err(wrap_err)?, )) } + /// Add two tensors. + /// &RETURNS&: Tensor fn __add__(&self, rhs: &PyAny) -> PyResult<Self> { let tensor = if let Ok(rhs) = rhs.extract::<Self>() { (&self.0 + &rhs.0).map_err(wrap_err)? @@ -393,6 +453,8 @@ impl PyTensor { self.__add__(rhs) } + /// Multiply two tensors. + /// &RETURNS&: Tensor fn __mul__(&self, rhs: &PyAny) -> PyResult<Self> { let tensor = if let Ok(rhs) = rhs.extract::<Self>() { (&self.0 * &rhs.0).map_err(wrap_err)? @@ -408,6 +470,8 @@ impl PyTensor { self.__mul__(rhs) } + /// Subtract two tensors. + /// &RETURNS&: Tensor fn __sub__(&self, rhs: &PyAny) -> PyResult<Self> { let tensor = if let Ok(rhs) = rhs.extract::<Self>() { (&self.0 - &rhs.0).map_err(wrap_err)? @@ -419,6 +483,8 @@ impl PyTensor { Ok(Self(tensor)) } + /// Divide two tensors. + /// &RETURNS&: Tensor fn __truediv__(&self, rhs: &PyAny) -> PyResult<Self> { let tensor = if let Ok(rhs) = rhs.extract::<Self>() { (&self.0 / &rhs.0).map_err(wrap_err)? @@ -430,62 +496,102 @@ impl PyTensor { Ok(Self(tensor)) } + #[pyo3(text_signature = "(self, shape:Sequence[int])")] + /// Reshapes the tensor to the given shape. + /// &RETURNS&: Tensor fn reshape(&self, shape: PyShape) -> PyResult<Self> { Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, shape:Sequence[int])")] + /// Broadcasts the tensor to the given shape. + /// &RETURNS&: Tensor fn broadcast_as(&self, shape: PyShape) -> PyResult<Self> { Ok(PyTensor(self.0.broadcast_as(shape).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, shape:Sequence[int])")] + /// Broadcasts the tensor to the given shape, adding new dimensions on the left. + /// &RETURNS&: Tensor fn broadcast_left(&self, shape: PyShape) -> PyResult<Self> { Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, dim:int)")] + /// Creates a new tensor with the specified dimension removed if its size was one. + /// &RETURNS&: Tensor fn squeeze(&self, dim: i64) -> PyResult<Self> { let dim = actual_dim(self, dim).map_err(wrap_err)?; Ok(PyTensor(self.0.squeeze(dim).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, dim:int)")] + /// Creates a new tensor with a dimension of size one inserted at the specified position. + /// &RETURNS&: Tensor fn unsqueeze(&self, dim: usize) -> PyResult<Self> { Ok(PyTensor(self.0.unsqueeze(dim).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, index:int)")] + /// Gets the value at the specified index. 
+ /// &RETURNS&: Tensor fn get(&self, index: i64) -> PyResult<Self> { let index = actual_index(self, 0, index).map_err(wrap_err)?; Ok(PyTensor(self.0.get(index).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, dim1:int, dim2:int)")] + /// Returns a tensor that is a transposed version of the input, the given dimensions are swapped. + /// &RETURNS&: Tensor fn transpose(&self, dim1: usize, dim2: usize) -> PyResult<Self> { Ok(PyTensor(self.0.transpose(dim1, dim2).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, dim:int, start:int, len:int)")] + /// Returns a new tensor that is a narrowed version of the input, the dimension `dim` + /// ranges from `start` to `start + len`. + /// &RETURNS&: Tensor fn narrow(&self, dim: i64, start: i64, len: usize) -> PyResult<Self> { let dim = actual_dim(self, dim).map_err(wrap_err)?; let start = actual_index(self, dim, start).map_err(wrap_err)?; Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, dim:int)")] + /// Returns the indices of the maximum value(s) across the selected dimension. + /// &RETURNS&: Tensor fn argmax_keepdim(&self, dim: i64) -> PyResult<Self> { let dim = actual_dim(self, dim).map_err(wrap_err)?; Ok(PyTensor(self.0.argmax_keepdim(dim).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, dim:int)")] + /// Returns the indices of the minimum value(s) across the selected dimension. + /// &RETURNS&: Tensor fn argmin_keepdim(&self, dim: i64) -> PyResult<Self> { let dim = actual_dim(self, dim).map_err(wrap_err)?; Ok(PyTensor(self.0.argmin_keepdim(dim).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, dim:int)")] + /// Gathers the maximum value across the selected dimension. + /// &RETURNS&: Tensor fn max_keepdim(&self, dim: i64) -> PyResult<Self> { let dim = actual_dim(self, dim).map_err(wrap_err)?; Ok(PyTensor(self.0.max_keepdim(dim).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, dim:int)")] + /// Gathers the minimum value across the selected dimension. + /// &RETURNS&: Tensor fn min_keepdim(&self, dim: i64) -> PyResult<Self> { let dim = actual_dim(self, dim).map_err(wrap_err)?; Ok(PyTensor(self.0.min_keepdim(dim).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, dim:Union[int, List[int]])")] + /// Returns the sum of all elements in the input tensor. The sum is performed over all the input dimensions. + /// &RETURNS&: Tensor fn sum_keepdim(&self, dims: PyObject, py: Python<'_>) -> PyResult<Self> { let dims = if let Ok(dim) = dims.extract::<usize>(py) { vec![dim] @@ -497,10 +603,14 @@ impl PyTensor { )) } + /// Returns the sum of the tensor. + /// &RETURNS&: Tensor fn sum_all(&self) -> PyResult<Self> { Ok(PyTensor(self.0.sum_all().map_err(wrap_err)?)) } + /// Returns the mean of the tensor. + /// &RETURNS&: Tensor fn mean_all(&self) -> PyResult<Self> { let elements = self.0.elem_count(); let sum = self.0.sum_all().map_err(wrap_err)?; @@ -508,54 +618,83 @@ impl PyTensor { Ok(PyTensor(mean)) } + #[pyo3(text_signature = "(self, dim:int)")] + /// Flattens the tensor on the dimension indexes from `dim` (inclusive) to the last dimension. + /// &RETURNS&: Tensor fn flatten_from(&self, dim: i64) -> PyResult<Self> { let dim = actual_dim(self, dim).map_err(wrap_err)?; Ok(PyTensor(self.0.flatten_from(dim).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, dim:int)")] + ///Flattens the tensor on the dimension indexes from `0` to `dim` (inclusive). 
+ /// &RETURNS&: Tensor fn flatten_to(&self, dim: i64) -> PyResult<Self> { let dim = actual_dim(self, dim).map_err(wrap_err)?; Ok(PyTensor(self.0.flatten_to(dim).map_err(wrap_err)?)) } + /// Flattens the tensor into a 1D tensor. + /// &RETURNS&: Tensor fn flatten_all(&self) -> PyResult<Self> { Ok(PyTensor(self.0.flatten_all().map_err(wrap_err)?)) } + /// Transposes the tensor. + /// &RETURNS&: Tensor fn t(&self) -> PyResult<Self> { Ok(PyTensor(self.0.t().map_err(wrap_err)?)) } + /// Makes the tensor contiguous in memory. + /// &RETURNS&: Tensor fn contiguous(&self) -> PyResult<Self> { Ok(PyTensor(self.0.contiguous().map_err(wrap_err)?)) } + /// Returns true if the tensor is contiguous in C order. + /// &RETURNS&: bool fn is_contiguous(&self) -> bool { self.0.is_contiguous() } + /// Returns true if the tensor is contiguous in Fortran order. + /// &RETURNS&: bool fn is_fortran_contiguous(&self) -> bool { self.0.is_fortran_contiguous() } + /// Detach the tensor from the computation graph. + /// &RETURNS&: Tensor fn detach(&self) -> PyResult<Self> { Ok(PyTensor(self.0.detach().map_err(wrap_err)?)) } + /// Returns a copy of the tensor. + /// &RETURNS&: Tensor fn copy(&self) -> PyResult<Self> { Ok(PyTensor(self.0.copy().map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, dtype:Union[str,DType])")] + /// Convert the tensor to a new dtype. + /// &RETURNS&: Tensor fn to_dtype(&self, dtype: PyObject, py: Python<'_>) -> PyResult<Self> { let dtype = PyDType::from_pyobject(dtype, py)?; Ok(PyTensor(self.0.to_dtype(dtype.0).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, device:Union[str,Device])")] + /// Move the tensor to a new device. + /// &RETURNS&: Tensor fn to_device(&self, device: PyDevice) -> PyResult<Self> { let device = device.as_device()?; Ok(PyTensor(self.0.to_device(&device).map_err(wrap_err)?)) } + #[pyo3(text_signature = "(self, quantized_dtype:str)")] + /// Quantize the tensor. + /// &RETURNS&: QTensor fn quantize(&self, quantized_dtype: &str) -> PyResult<PyQTensor> { use ::candle::quantized; let res = match quantized_dtype { @@ -586,6 +725,7 @@ impl PyTensor { #[pyfunction] #[pyo3(text_signature = "(tensors:List[Tensor], dim:int )")] /// Concatenate the tensors across one axis. +/// &RETURNS&: Tensor fn cat(tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> { if tensors.is_empty() { return Err(PyErr::new::<PyValueError, _>("empty input to cat")); @@ -599,6 +739,7 @@ fn cat(tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> { #[pyfunction] #[pyo3(text_signature = "(tensors:List[Tensor], dim:int)")] /// Stack the tensors along a new axis. +/// &RETURNS&: Tensor fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> { let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>(); let tensor = Tensor::stack(&tensors, dim).map_err(wrap_err)?; @@ -608,6 +749,7 @@ fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> { #[pyfunction] #[pyo3(text_signature = "(data:_ArrayLike)")] /// Creates a new tensor from a Python value. The value can be a scalar or array-like object. +/// &RETURNS&: Tensor fn tensor(py: Python<'_>, data: PyObject) -> PyResult<PyTensor> { PyTensor::new(py, data) } @@ -615,6 +757,7 @@ fn tensor(py: Python<'_>, data: PyObject) -> PyResult<PyTensor> { #[pyfunction] #[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")] /// Creates a new tensor with random values. 
+/// &RETURNS&: Tensor fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> { let device = device.unwrap_or(PyDevice::Cpu).as_device()?; let tensor = Tensor::rand(0f32, 1f32, shape.0, &device).map_err(wrap_err)?; @@ -623,6 +766,8 @@ fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<P #[pyfunction] #[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")] +/// Creates a new tensor with random values from a normal distribution. +/// &RETURNS&: Tensor fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> { let device = device.unwrap_or(PyDevice::Cpu).as_device()?; let tensor = Tensor::randn(0f32, 1f32, shape.0, &device).map_err(wrap_err)?; @@ -631,6 +776,8 @@ fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult< #[pyfunction] #[pyo3(signature = (shape, *, dtype=None, device=None),text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")] +/// Creates a new tensor filled with ones. +/// &RETURNS&: Tensor fn ones( py: Python<'_>, shape: PyShape, @@ -648,6 +795,8 @@ fn ones( #[pyfunction] #[pyo3(signature = (shape, *, dtype=None, device=None), text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")] +/// Creates a new tensor filled with zeros. +/// &RETURNS&: Tensor fn zeros( py: Python<'_>, shape: PyShape, @@ -663,8 +812,9 @@ fn zeros( Ok(PyTensor(tensor)) } -#[derive(Debug)] +#[derive(Debug, Clone)] #[pyclass(name = "QTensor")] +/// A quantized tensor. struct PyQTensor(Arc<QTensor>); impl std::ops::Deref for PyQTensor { @@ -678,16 +828,22 @@ impl std::ops::Deref for PyQTensor { #[pymethods] impl PyQTensor { #[getter] + ///Gets the tensors quantized dtype. + /// &RETURNS&: str fn ggml_dtype(&self) -> String { format!("{:?}", self.0.dtype()) } #[getter] + ///Gets the rank of the tensor. + /// &RETURNS&: int fn rank(&self) -> usize { self.0.rank() } #[getter] + ///Gets the shape of the tensor. + /// &RETURNS&: Tuple[int] fn shape(&self, py: Python<'_>) -> PyObject { PyTuple::new(py, self.0.shape().dims()).to_object(py) } @@ -700,11 +856,16 @@ impl PyQTensor { self.__repr__() } + /// Dequantizes the tensor. + /// &RETURNS&: Tensor fn dequantize(&self) -> PyResult<PyTensor> { let tensor = self.0.dequantize(&Device::Cpu).map_err(wrap_err)?; Ok(PyTensor(tensor)) } + #[pyo3(text_signature = "(self, lhs:Tensor)")] + /// Performs a quantized matrix multiplication, with the quantized tensor as the right hand side. + /// &RETURNS&: Tensor fn matmul_t(&self, lhs: &PyTensor) -> PyResult<PyTensor> { let qmatmul = ::candle::quantized::QMatMul::from_arc(self.0.clone()); let res = qmatmul.forward(lhs).map_err(wrap_err)?; @@ -715,6 +876,7 @@ impl PyQTensor { #[pyfunction] #[pyo3(text_signature = "(path:Union[str,PathLike])")] /// Loads a safetensors file. Returns a dictionary mapping tensor names to tensors. +/// &RETURNS&: Dict[str,Tensor] fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> { let res = ::candle::safetensors::load(path, &Device::Cpu).map_err(wrap_err)?; let res = res @@ -727,6 +889,7 @@ fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> { #[pyfunction] #[pyo3(text_signature = "(path:Union[str,PathLike], tensors:Dict[str,Tensor])")] /// Saves a dictionary of tensors to a safetensors file. 
+/// &RETURNS&: None fn save_safetensors( path: &str, tensors: std::collections::HashMap<String, PyTensor>, @@ -742,6 +905,7 @@ fn save_safetensors( #[pyo3(text_signature = "(path:Union[str,PathLike])")] /// Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors, /// a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary. +/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any], List[str]] fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObject)> { let mut file = std::fs::File::open(path)?; let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?; @@ -776,6 +940,7 @@ fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObje #[pyo3(text_signature = "(path:Union[str,PathLike])")] /// Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors, /// and the second maps metadata keys to metadata values. +/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any]] fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> { use ::candle::quantized::gguf_file; fn gguf_value_to_pyobject(v: &gguf_file::Value, py: Python<'_>) -> PyResult<PyObject> { @@ -825,25 +990,117 @@ fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> { } #[pyfunction] +#[pyo3( + text_signature = "(path:Union[str,PathLike], tensors:Dict[str,QTensor], metadata:Dict[str,Any])" +)] +/// Save quanitzed tensors and metadata to a GGUF file. +fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>) -> PyResult<()> { + use ::candle::quantized::gguf_file; + + fn pyobject_to_gguf_value(v: &PyAny, py: Python<'_>) -> PyResult<gguf_file::Value> { + let v: gguf_file::Value = if let Ok(x) = v.extract::<u8>() { + gguf_file::Value::U8(x) + } else if let Ok(x) = v.extract::<i8>() { + gguf_file::Value::I8(x) + } else if let Ok(x) = v.extract::<u16>() { + gguf_file::Value::U16(x) + } else if let Ok(x) = v.extract::<i16>() { + gguf_file::Value::I16(x) + } else if let Ok(x) = v.extract::<u32>() { + gguf_file::Value::U32(x) + } else if let Ok(x) = v.extract::<i32>() { + gguf_file::Value::I32(x) + } else if let Ok(x) = v.extract::<u64>() { + gguf_file::Value::U64(x) + } else if let Ok(x) = v.extract::<i64>() { + gguf_file::Value::I64(x) + } else if let Ok(x) = v.extract::<f32>() { + gguf_file::Value::F32(x) + } else if let Ok(x) = v.extract::<f64>() { + gguf_file::Value::F64(x) + } else if let Ok(x) = v.extract::<bool>() { + gguf_file::Value::Bool(x) + } else if let Ok(x) = v.extract::<String>() { + gguf_file::Value::String(x) + } else if let Ok(x) = v.extract::<Vec<PyObject>>() { + let x = x + .into_iter() + .map(|f| pyobject_to_gguf_value(f.as_ref(py), py)) + .collect::<PyResult<Vec<_>>>()?; + gguf_file::Value::Array(x) + } else { + return Err(PyErr::new::<PyValueError, _>(format!( + "unsupported type {:?}", + v + ))); + }; + Ok(v) + } + let tensors = tensors + .extract::<&PyDict>(py) + .map_err(|_| PyErr::new::<PyValueError, _>("expected a dict"))? + .iter() + .map(|(key, value)| { + Ok(( + key.extract::<String>() + .map_err(|_| PyErr::new::<PyValueError, _>("keys must be strings"))?, + value.extract::<PyQTensor>()?.0, + )) + }) + .collect::<PyResult<Vec<_>>>()?; + + let metadata = metadata + .extract::<&PyDict>(py) + .map_err(|_| PyErr::new::<PyValueError, _>("expected a dict"))? 
+ .iter() + .map(|(key, value)| { + Ok(( + key.extract::<String>() + .map_err(|_| PyErr::new::<PyValueError, _>("keys must be strings"))?, + pyobject_to_gguf_value(value, py)?, + )) + }) + .collect::<PyResult<Vec<_>>>()?; + + let converted_metadata: Vec<_> = metadata + .iter() + .map(|(name, value)| (name.as_str(), value)) + .collect(); + + let converted_tensors: Vec<_> = tensors + .iter() + .map(|(name, tensor)| (name.as_str(), tensor.as_ref())) + .collect(); + + let mut file = std::fs::File::create(path)?; + + gguf_file::write(&mut file, &converted_metadata, &converted_tensors).map_err(wrap_err) +} + +#[pyfunction] /// Returns true if the 'cuda' backend is available. +/// &RETURNS&: bool fn cuda_is_available() -> bool { ::candle::utils::cuda_is_available() } #[pyfunction] /// Returns true if candle was compiled with 'accelerate' support. +/// &RETURNS&: bool fn has_accelerate() -> bool { ::candle::utils::has_accelerate() } #[pyfunction] /// Returns true if candle was compiled with MKL support. +/// &RETURNS&: bool fn has_mkl() -> bool { ::candle::utils::has_mkl() } #[pyfunction] /// Returns the number of threads used by the candle. +/// &RETURNS&: int fn get_num_threads() -> usize { ::candle::utils::get_num_threads() } @@ -855,6 +1112,7 @@ fn candle_utils(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(has_mkl, m)?)?; m.add_function(wrap_pyfunction!(load_ggml, m)?)?; m.add_function(wrap_pyfunction!(load_gguf, m)?)?; + m.add_function(wrap_pyfunction!(save_gguf, m)?)?; m.add_function(wrap_pyfunction!(load_safetensors, m)?)?; m.add_function(wrap_pyfunction!(save_safetensors, m)?)?; Ok(()) @@ -862,7 +1120,8 @@ fn candle_utils(_py: Python<'_>, m: &PyModule) -> PyResult<()> { #[pyfunction] #[pyo3(text_signature = "(tensor:Tensor, dim:int)")] -/// Applies the Softmax function to a given tensor. +/// Applies the Softmax function to a given tensor.# +/// &RETURNS&: Tensor fn softmax(tensor: PyTensor, dim: i64) -> PyResult<PyTensor> { let dim = actual_dim(&tensor, dim).map_err(wrap_err)?; let sm = candle_nn::ops::softmax(&tensor.0, dim).map_err(wrap_err)?; @@ -872,6 +1131,7 @@ fn softmax(tensor: PyTensor, dim: i64) -> PyResult<PyTensor> { #[pyfunction] #[pyo3(text_signature = "(tensor:Tensor)")] /// Applies the Sigmoid Linear Unit (SiLU) function to a given tensor. 
+/// &RETURNS&: Tensor fn silu(tensor: PyTensor) -> PyResult<PyTensor> { let s = candle_nn::ops::silu(&tensor.0).map_err(wrap_err)?; Ok(PyTensor(s)) diff --git a/candle-pyo3/stub.py b/candle-pyo3/stub.py index b5b9256f..149715c2 100644 --- a/candle-pyo3/stub.py +++ b/candle-pyo3/stub.py @@ -13,8 +13,8 @@ TYPING = """from typing import Any, Callable, Dict, List, Optional, Tuple, Union from os import PathLike """ CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device\n" -CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType\n" - +CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType,QTensor\n" +RETURN_TYPE_MARKER = "&RETURNS&: " def do_indent(text: Optional[str], indent: str): @@ -28,11 +28,26 @@ def function(obj, indent:str, text_signature:str=None): text_signature = obj.__text_signature__ text_signature = text_signature.replace("$self", "self").lstrip().rstrip() + doc_string = obj.__doc__ + if doc_string is None: + doc_string = "" + + # Check if we have a return type annotation in the docstring + return_type = None + doc_lines = doc_string.split("\n") + if doc_lines[-1].lstrip().startswith(RETURN_TYPE_MARKER): + # Extract the return type and remove it from the docstring + return_type = doc_lines[-1].lstrip()[len(RETURN_TYPE_MARKER):].strip() + doc_string = "\n".join(doc_lines[:-1]) + string = "" - string += f"{indent}def {obj.__name__}{text_signature}:\n" + if return_type: + string += f"{indent}def {obj.__name__}{text_signature} -> {return_type}:\n" + else: + string += f"{indent}def {obj.__name__}{text_signature}:\n" indent += INDENT string += f'{indent}"""\n' - string += f"{indent}{do_indent(obj.__doc__, indent)}\n" + string += f"{indent}{do_indent(doc_string, indent)}\n" string += f'{indent}"""\n' string += f"{indent}pass\n" string += "\n" diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py index c78ffc41..7f24b49d 100644 --- a/candle-pyo3/test.py +++ b/candle-pyo3/test.py @@ -1,5 +1,4 @@ import candle -from candle import Tensor, QTensor t = candle.Tensor(42.0) print(t) @@ -10,7 +9,7 @@ t = candle.Tensor([3.0, 1, 4, 1, 5, 9, 2, 6]) print(t) print(t+t) -t:Tensor = t.reshape([2, 4]) +t = t.reshape([2, 4]) print(t.matmul(t.t())) print(t.to_dtype(candle.u8)) @@ -21,7 +20,7 @@ print(t) print(t.dtype) t = candle.randn((16, 256)) -quant_t:QTensor = t.quantize("q6k") -dequant_t:Tensor = quant_t.dequantize() -diff2:Tensor = (t - dequant_t).sqr() +quant_t = t.quantize("q6k") +dequant_t = quant_t.dequantize() +diff2 = (t - dequant_t).sqr() print(diff2.mean_all()) |
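
The stub generator change above works off a small docstring convention: each Rust docstring may end with a "&RETURNS&: <type>" line, which stub.py strips from the emitted docstring and turns into a "-> <type>" return annotation. Below is a minimal sketch of that parsing step, simplified from the stub.py hunk (helper name shortened, not the verbatim implementation):

    RETURN_TYPE_MARKER = "&RETURNS&: "

    def split_return_type(doc_string: str):
        # Split a pyo3 docstring into (docstring, return annotation or None).
        lines = (doc_string or "").split("\n")
        if lines and lines[-1].lstrip().startswith(RETURN_TYPE_MARKER):
            return_type = lines[-1].lstrip()[len(RETURN_TYPE_MARKER):].strip()
            return "\n".join(lines[:-1]), return_type
        return doc_string, None

    # e.g. the docstring added to `cos` in src/lib.rs:
    doc, ret = split_return_type("Performs the `cos` operation on the tensor.\n&RETURNS&: Tensor")
    assert ret == "Tensor"  # emitted as "def cos(self) -> Tensor:" in the generated stub
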