summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-pyo3/py_src/candle/__init__.py6
-rw-r--r--candle-pyo3/py_src/candle/__init__.pyi361
-rw-r--r--candle-pyo3/py_src/candle/nn/__init__.pyi8
-rw-r--r--candle-pyo3/py_src/candle/utils/__init__.py1
-rw-r--r--candle-pyo3/py_src/candle/utils/__init__.pyi25
-rw-r--r--candle-pyo3/quant-llama.py31
-rw-r--r--candle-pyo3/src/lib.rs270
-rw-r--r--candle-pyo3/stub.py23
-rw-r--r--candle-pyo3/test.py9
9 files changed, 574 insertions, 160 deletions
diff --git a/candle-pyo3/py_src/candle/__init__.py b/candle-pyo3/py_src/candle/__init__.py
index 49c96122..951609cc 100644
--- a/candle-pyo3/py_src/candle/__init__.py
+++ b/candle-pyo3/py_src/candle/__init__.py
@@ -1 +1,5 @@
-from .candle import * \ No newline at end of file
+from .candle import *
+
+__doc__ = candle.__doc__
+if hasattr(candle, "__all__"):
+ __all__ = candle.__all__ \ No newline at end of file
diff --git a/candle-pyo3/py_src/candle/__init__.pyi b/candle-pyo3/py_src/candle/__init__.pyi
index c21e6738..414f0bc4 100644
--- a/candle-pyo3/py_src/candle/__init__.pyi
+++ b/candle-pyo3/py_src/candle/__init__.pyi
@@ -7,7 +7,7 @@ class bf16(DType):
pass
@staticmethod
-def cat(tensors: List[Tensor], dim: int):
+def cat(tensors: List[Tensor], dim: int) -> Tensor:
"""
Concatenate the tensors across one axis.
"""
@@ -26,31 +26,35 @@ class i64(DType):
pass
@staticmethod
-def ones(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None):
- """ """
+def ones(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor:
+ """
+ Creates a new tensor filled with ones.
+ """
pass
@staticmethod
-def rand(shape: Sequence[int], device: Optional[Device] = None):
+def rand(shape: Sequence[int], device: Optional[Device] = None) -> Tensor:
"""
Creates a new tensor with random values.
"""
pass
@staticmethod
-def randn(shape: Sequence[int], device: Optional[Device] = None):
- """ """
+def randn(shape: Sequence[int], device: Optional[Device] = None) -> Tensor:
+ """
+ Creates a new tensor with random values from a normal distribution.
+ """
pass
@staticmethod
-def stack(tensors: List[Tensor], dim: int):
+def stack(tensors: List[Tensor], dim: int) -> Tensor:
"""
Stack the tensors along a new axis.
"""
pass
@staticmethod
-def tensor(data: _ArrayLike):
+def tensor(data: _ArrayLike) -> Tensor:
"""
Creates a new tensor from a Python value. The value can be a scalar or array-like object.
"""
@@ -63,186 +67,309 @@ class u8(DType):
pass
@staticmethod
-def zeros(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None):
- """ """
+def zeros(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor:
+ """
+ Creates a new tensor filled with zeros.
+ """
pass
class DType:
- pass
+ """
+ A `candle` dtype.
+ """
class QTensor:
- def dequantize(self):
- """ """
+ """
+ A quantized tensor.
+ """
+
+ def dequantize(self) -> Tensor:
+ """
+ Dequantizes the tensor.
+ """
pass
@property
- def ggml_dtype(self):
- """ """
+ def ggml_dtype(self) -> str:
+ """
+ Gets the tensors quantized dtype.
+ """
pass
- def matmul_t(self, lhs):
- """ """
+ def matmul_t(self, lhs: Tensor) -> Tensor:
+ """
+ Performs a quantized matrix multiplication, with the quantized tensor as the right hand side.
+ """
pass
@property
- def rank(self):
- """ """
+ def rank(self) -> int:
+ """
+ Gets the rank of the tensor.
+ """
pass
@property
- def shape(self):
- """ """
+ def shape(self) -> Tuple[int]:
+ """
+ Gets the shape of the tensor.
+ """
pass
class Tensor:
- def __init__(data: _ArrayLike):
+ """
+ A `candle` tensor.
+ """
+
+ def __init__(self, data: _ArrayLike):
pass
- def argmax_keepdim(self, dim):
- """ """
+ def argmax_keepdim(self, dim: int) -> Tensor:
+ """
+ Returns the indices of the maximum value(s) across the selected dimension.
+ """
pass
- def argmin_keepdim(self, dim):
- """ """
+ def argmin_keepdim(self, dim: int) -> Tensor:
+ """
+ Returns the indices of the minimum value(s) across the selected dimension.
+ """
pass
- def broadcast_add(self, rhs):
- """ """
+ def broadcast_add(self, rhs: Tensor) -> Tensor:
+ """
+ Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
+ """
pass
- def broadcast_as(self, shape):
- """ """
+ def broadcast_as(self, shape: Sequence[int]) -> Tensor:
+ """
+ Broadcasts the tensor to the given shape.
+ """
pass
- def broadcast_div(self, rhs):
- """ """
+ def broadcast_div(self, rhs: Tensor) -> Tensor:
+ """
+ Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
+ """
pass
- def broadcast_left(self, shape):
- """ """
+ def broadcast_left(self, shape: Sequence[int]) -> Tensor:
+ """
+ Broadcasts the tensor to the given shape, adding new dimensions on the left.
+ """
pass
- def broadcast_mul(self, rhs):
- """ """
+ def broadcast_mul(self, rhs: Tensor) -> Tensor:
+ """
+ Multiplies the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
+ """
pass
- def broadcast_sub(self, rhs):
- """ """
+ def broadcast_sub(self, rhs: Tensor) -> Tensor:
+ """
+ Subtracts the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
+ """
pass
- def contiguous(self):
- """ """
+ def contiguous(self) -> Tensor:
+ """
+ Makes the tensor contiguous in memory.
+ """
pass
- def copy(self):
- """ """
+ def copy(self) -> Tensor:
+ """
+ Returns a copy of the tensor.
+ """
pass
- def cos(self):
- """ """
+ def cos(self) -> Tensor:
+ """
+ Performs the `cos` operation on the tensor.
+ """
pass
- def detach(self):
- """ """
+ def detach(self) -> Tensor:
+ """
+ Detach the tensor from the computation graph.
+ """
pass
@property
- def device(self):
- """ """
+ def device(self) -> Device:
+ """
+ Gets the tensor's device.
+ """
pass
@property
- def dtype(self):
- """ """
+ def dtype(self) -> DType:
+ """
+ Gets the tensor's dtype.
+ """
pass
- def exp(self):
- """ """
+ def exp(self) -> Tensor:
+ """
+ Performs the `exp` operation on the tensor.
+ """
pass
- def flatten_all(self):
- """ """
+ def flatten_all(self) -> Tensor:
+ """
+ Flattens the tensor into a 1D tensor.
+ """
pass
- def flatten_from(self, dim):
- """ """
+ def flatten_from(self, dim: int) -> Tensor:
+ """
+ Flattens the tensor on the dimension indexes from `dim` (inclusive) to the last dimension.
+ """
pass
- def flatten_to(self, dim):
- """ """
+ def flatten_to(self, dim: int) -> Tensor:
+ """
+ Flattens the tensor on the dimension indexes from `0` to `dim` (inclusive).
+ """
pass
- def get(self, index):
- """ """
+ def get(self, index: int) -> Tensor:
+ """
+ Gets the value at the specified index.
+ """
pass
- def index_select(self, rhs, dim):
- """ """
+ def index_select(self, rhs: Tensor, dim: int) -> Tensor:
+ """
+ Select values for the input tensor at the target indexes across the specified dimension.
+
+ The `indexes` is argument is an int tensor with a single dimension.
+ The output has the same number of dimension as the `self` input. The target dimension of
+ the output has length the length of `indexes` and the values are taken from `self` using
+ the index from `indexes`. Other dimensions have the same number of elements as the input
+ tensor.
+ """
pass
- def is_contiguous(self):
- """ """
+ def is_contiguous(self) -> bool:
+ """
+ Returns true if the tensor is contiguous in C order.
+ """
pass
- def is_fortran_contiguous(self):
- """ """
+ def is_fortran_contiguous(self) -> bool:
+ """
+ Returns true if the tensor is contiguous in Fortran order.
+ """
pass
- def log(self):
- """ """
+ def log(self) -> Tensor:
+ """
+ Performs the `log` operation on the tensor.
+ """
pass
- def matmul(self, rhs):
- """ """
+ def matmul(self, rhs: Tensor) -> Tensor:
+ """
+ Performs a matrix multiplication between the two tensors.
+ """
pass
- def max_keepdim(self, dim):
- """ """
+ def max_keepdim(self, dim: int) -> Tensor:
+ """
+ Gathers the maximum value across the selected dimension.
+ """
pass
- def mean_all(self):
- """ """
+ def mean_all(self) -> Tensor:
+ """
+ Returns the mean of the tensor.
+ """
pass
- def min_keepdim(self, dim):
- """ """
+ def min_keepdim(self, dim: int) -> Tensor:
+ """
+ Gathers the minimum value across the selected dimension.
+ """
pass
- def narrow(self, dim, start, len):
- """ """
+ def narrow(self, dim: int, start: int, len: int) -> Tensor:
+ """
+ Returns a new tensor that is a narrowed version of the input, the dimension `dim`
+ ranges from `start` to `start + len`.
+ """
pass
- def powf(self, p):
- """ """
+ def powf(self, p: float) -> Tensor:
+ """
+ Performs the `pow` operation on the tensor with the given exponent.
+ """
pass
- def quantize(self, quantized_dtype):
- """ """
+ def quantize(self, quantized_dtype: str) -> QTensor:
+ """
+ Quantize the tensor.
+ """
pass
@property
- def rank(self):
- """ """
+ def rank(self) -> int:
+ """
+ Gets the tensor's rank.
+ """
pass
- def recip(self):
- """ """
+ def recip(self) -> Tensor:
+ """
+ Get the `recip` of the tensor.
+ """
pass
- def reshape(self, shape):
- """ """
+ def reshape(self, shape: Sequence[int]) -> Tensor:
+ """
+ Reshapes the tensor to the given shape.
+ """
pass
@property
- def shape(self):
+ def shape(self) -> Tuple[int]:
"""
- Gets the tensor shape as a Python tuple.
+ Gets the tensor's shape.
"""
pass
- def sin(self):
- """ """
+ def sin(self) -> Tensor:
+ """
+ Performs the `sin` operation on the tensor.
+ """
pass
- def sqr(self):
- """ """
+ def sqr(self) -> Tensor:
+ """
+ Squares the tensor.
+ """
pass
- def sqrt(self):
- """ """
+ def sqrt(self) -> Tensor:
+ """
+ Calculates the square root of the tensor.
+ """
pass
- def squeeze(self, dim):
- """ """
+ def squeeze(self, dim: int) -> Tensor:
+ """
+ Creates a new tensor with the specified dimension removed if its size was one.
+ """
pass
@property
- def stride(self):
- """ """
+ def stride(self) -> Tuple[int]:
+ """
+ Gets the tensor's strides.
+ """
pass
- def sum_all(self):
- """ """
+ def sum_all(self) -> Tensor:
+ """
+ Returns the sum of the tensor.
+ """
pass
- def sum_keepdim(self, dims):
- """ """
+ def sum_keepdim(self, dim: Union[int, List[int]]) -> Tensor:
+ """
+ Returns the sum of all elements in the input tensor. The sum is performed over all the input dimensions.
+ """
pass
- def t(self):
- """ """
+ def t(self) -> Tensor:
+ """
+ Transposes the tensor.
+ """
pass
- def to_device(self, device):
- """ """
+ def to_device(self, device: Union[str, Device]) -> Tensor:
+ """
+ Move the tensor to a new device.
+ """
pass
- def to_dtype(self, dtype):
- """ """
+ def to_dtype(self, dtype: Union[str, DType]) -> Tensor:
+ """
+ Convert the tensor to a new dtype.
+ """
pass
- def transpose(self, dim1, dim2):
- """ """
+ def transpose(self, dim1: int, dim2: int) -> Tensor:
+ """
+ Returns a tensor that is a transposed version of the input, the given dimensions are swapped.
+ """
pass
- def unsqueeze(self, dim):
- """ """
+ def unsqueeze(self, dim: int) -> Tensor:
+ """
+ Creates a new tensor with a dimension of size one inserted at the specified position.
+ """
pass
- def values(self):
+ def values(self) -> _ArrayLike:
"""
Gets the tensor's data as a Python scalar or array-like object.
"""
pass
- def where_cond(self, on_true, on_false):
- """ """
+ def where_cond(self, on_true: Tensor, on_false: Tensor) -> Tensor:
+ """
+ Returns a tensor with the same shape as the input tensor, the values are taken from
+ `on_true` if the input tensor value is not zero, and `on_false` at the positions where the
+ input tensor is equal to zero.
+ """
pass
diff --git a/candle-pyo3/py_src/candle/nn/__init__.pyi b/candle-pyo3/py_src/candle/nn/__init__.pyi
index 821cd052..01b30fce 100644
--- a/candle-pyo3/py_src/candle/nn/__init__.pyi
+++ b/candle-pyo3/py_src/candle/nn/__init__.pyi
@@ -2,18 +2,18 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
from candle.typing import _ArrayLike, Device
-from candle import Tensor, DType
+from candle import Tensor, DType, QTensor
@staticmethod
-def silu(tensor: Tensor):
+def silu(tensor: Tensor) -> Tensor:
"""
Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
"""
pass
@staticmethod
-def softmax(tensor: Tensor, dim: int):
+def softmax(tensor: Tensor, dim: int) -> Tensor:
"""
- Applies the Softmax function to a given tensor.
+ Applies the Softmax function to a given tensor.#
"""
pass
diff --git a/candle-pyo3/py_src/candle/utils/__init__.py b/candle-pyo3/py_src/candle/utils/__init__.py
index 2ead6d84..62d85dc9 100644
--- a/candle-pyo3/py_src/candle/utils/__init__.py
+++ b/candle-pyo3/py_src/candle/utils/__init__.py
@@ -8,4 +8,5 @@ has_mkl = utils.has_mkl
load_ggml = utils.load_ggml
load_gguf = utils.load_gguf
load_safetensors = utils.load_safetensors
+save_gguf = utils.save_gguf
save_safetensors = utils.save_safetensors
diff --git a/candle-pyo3/py_src/candle/utils/__init__.pyi b/candle-pyo3/py_src/candle/utils/__init__.pyi
index 7a0a5231..61964ffc 100644
--- a/candle-pyo3/py_src/candle/utils/__init__.pyi
+++ b/candle-pyo3/py_src/candle/utils/__init__.pyi
@@ -2,38 +2,38 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
from candle.typing import _ArrayLike, Device
-from candle import Tensor, DType
+from candle import Tensor, DType, QTensor
@staticmethod
-def cuda_is_available():
+def cuda_is_available() -> bool:
"""
Returns true if the 'cuda' backend is available.
"""
pass
@staticmethod
-def get_num_threads():
+def get_num_threads() -> int:
"""
Returns the number of threads used by the candle.
"""
pass
@staticmethod
-def has_accelerate():
+def has_accelerate() -> bool:
"""
Returns true if candle was compiled with 'accelerate' support.
"""
pass
@staticmethod
-def has_mkl():
+def has_mkl() -> bool:
"""
Returns true if candle was compiled with MKL support.
"""
pass
@staticmethod
-def load_ggml(path: Union[str, PathLike]):
+def load_ggml(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]:
"""
Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
@@ -41,7 +41,7 @@ def load_ggml(path: Union[str, PathLike]):
pass
@staticmethod
-def load_gguf(path: Union[str, PathLike]):
+def load_gguf(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str, Any]]:
"""
Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
and the second maps metadata keys to metadata values.
@@ -49,14 +49,21 @@ def load_gguf(path: Union[str, PathLike]):
pass
@staticmethod
-def load_safetensors(path: Union[str, PathLike]):
+def load_safetensors(path: Union[str, PathLike]) -> Dict[str, Tensor]:
"""
Loads a safetensors file. Returns a dictionary mapping tensor names to tensors.
"""
pass
@staticmethod
-def save_safetensors(path: Union[str, PathLike], tensors: Dict[str, Tensor]):
+def save_gguf(path: Union[str, PathLike], tensors: Dict[str, QTensor], metadata: Dict[str, Any]):
+ """
+ Save quanitzed tensors and metadata to a GGUF file.
+ """
+ pass
+
+@staticmethod
+def save_safetensors(path: Union[str, PathLike], tensors: Dict[str, Tensor]) -> None:
"""
Saves a dictionary of tensors to a safetensors file.
"""
diff --git a/candle-pyo3/quant-llama.py b/candle-pyo3/quant-llama.py
index 020d525d..46d9ff62 100644
--- a/candle-pyo3/quant-llama.py
+++ b/candle-pyo3/quant-llama.py
@@ -1,27 +1,28 @@
# This example shows how the candle Python api can be used to replicate llama.cpp.
import sys
+from typing import Dict, Tuple, Any
import candle
-from candle.utils import load_ggml,load_gguf
+from candle import Tensor, QTensor, utils, nn
MAX_SEQ_LEN = 4096
-def masked_fill(on_false, mask, on_true):
+def masked_fill(on_false:Tensor, mask:Tensor, on_true:Tensor):
shape = mask.shape
on_true = candle.tensor(on_true).broadcast_as(shape)
return mask.where_cond(on_true, on_false)
class RmsNorm:
- def __init__(self, qtensor):
+ def __init__(self, qtensor:QTensor):
self.weight = qtensor.dequantize()
- def __call__(self, x):
+ def __call__(self, x:Tensor):
b_size, seq_len, hidden_size = x.shape
norm_x = x.sqr().sum_keepdim(2) / hidden_size
x_normed = x.broadcast_div((norm_x + 1e-5).sqrt())
return x_normed.broadcast_mul(self.weight)
class QuantizedLayer:
- def __init__(self, layer_idx, hparams, all_tensors, cos_sin):
+ def __init__(self, layer_idx:int, hparams:Dict[str,Any], all_tensors:Dict[str,QTensor], cos_sin:Tuple[Tensor,Tensor]):
p = f"layers.{layer_idx}"
self.attention_wq = all_tensors[f"{p}.attention.wq.weight"]
self.attention_wk = all_tensors[f"{p}.attention.wk.weight"]
@@ -41,7 +42,7 @@ class QuantizedLayer:
self.cos = cos_sin[0]
self.sin = cos_sin[1]
- def __call__(self, x, mask, index_pos):
+ def __call__(self, x:Tensor, mask:Tensor, index_pos:int):
residual = x
x = self.attn_norm(x)
attn = self.forward_attn(x, mask, index_pos)
@@ -51,11 +52,11 @@ class QuantizedLayer:
x = self.ffn_norm(x)
w1 = self.ffw1.matmul_t(x)
w3 = self.ffw3.matmul_t(x)
- mlp = self.ffw2.matmul_t(candle.nn.silu(w1) * w3)
+ mlp = self.ffw2.matmul_t(nn.silu(w1) * w3)
return mlp + residual
- def forward_attn(self, x, mask, index_pos):
+ def forward_attn(self, x:Tensor, mask:Tensor, index_pos:int):
b_size, seq_len, n_embd = x.shape
q = self.attention_wq.matmul_t(x)
k = self.attention_wk.matmul_t(x)
@@ -80,12 +81,12 @@ class QuantizedLayer:
att = q.matmul(k.t()) / self.head_dim**0.5
mask = mask.broadcast_as(att.shape)
att = masked_fill(att, mask, float("-inf"))
- att = candle.nn.softmax(att, -1)
+ att = nn.softmax(att, -1)
y = att.matmul(v.contiguous())
y = y.transpose(1, 2).reshape((b_size, seq_len, n_embd))
return self.attention_wo.matmul_t(y)
- def apply_rotary_emb(self, x, index_pos):
+ def apply_rotary_emb(self, x:Tensor, index_pos:int):
(b_size, n_head, seq_len, n_embd) = x.shape
cos = self.cos.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd//2, 1))
sin = self.sin.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd//2, 1))
@@ -107,7 +108,7 @@ def precompute_freqs_cis(hparams, freq_base):
return (m.cos(), m.sin())
class QuantizedLlama:
- def __init__(self, hparams, all_tensors):
+ def __init__(self, hparams:Dict[str,Any], all_tensors:Dict[str,QTensor]):
self.tok_embeddings = all_tensors["tok_embeddings.weight"].dequantize()
self.norm = RmsNorm(all_tensors["norm.weight"])
self.output = all_tensors["output.weight"]
@@ -118,7 +119,7 @@ class QuantizedLlama:
layer = QuantizedLayer(layer_idx, hparams, all_tensors, cos_sin)
self.layers.append(layer)
- def __call__(self, token, index_pos):
+ def __call__(self, token:Tensor, index_pos:int):
b_size, seq_len = token.shape
vocab_size, hidden_size = self.tok_embeddings.shape
token = token.reshape((b_size * seq_len,))
@@ -135,7 +136,7 @@ class QuantizedLlama:
x = self.output.matmul_t(x)
return x
-def gguf_rename(tensor_name):
+def gguf_rename(tensor_name:str):
if tensor_name == 'token_embd.weight': return 'tok_embeddings.weight'
if tensor_name == 'output_norm.weight': return 'norm.weight'
tensor_name = tensor_name.replace('blk.', 'layers.')
@@ -155,7 +156,7 @@ def main():
filename = sys.argv[1]
print(f"reading model file {filename}")
if filename.endswith("gguf"):
- all_tensors, metadata = load_gguf(sys.argv[1])
+ all_tensors, metadata = utils.load_gguf(sys.argv[1])
vocab = metadata["tokenizer.ggml.tokens"]
for i, v in enumerate(vocab):
vocab[i] = '\n' if v == '<0x0A>' else v.replace('▁', ' ')
@@ -175,7 +176,7 @@ def main():
all_tensors = { gguf_rename(k): v for k, v in all_tensors.items() }
else:
- all_tensors, hparams, vocab = load_ggml(sys.argv[1])
+ all_tensors, hparams, vocab = utils.load_ggml(sys.argv[1])
print(hparams)
model = QuantizedLlama(hparams, all_tensors)
print("model built, starting inference")
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index 1df78ec6..55b7a888 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -1,7 +1,7 @@
#![allow(clippy::redundant_closure_call)]
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
-use pyo3::types::{IntoPyDict, PyTuple};
+use pyo3::types::{IntoPyDict, PyDict, PyTuple};
use pyo3::ToPyObject;
use std::sync::Arc;
@@ -31,6 +31,7 @@ impl From<PyShape> for ::candle::Shape {
#[derive(Clone, Debug)]
#[pyclass(name = "Tensor")]
+/// A `candle` tensor.
struct PyTensor(Tensor);
impl std::ops::Deref for PyTensor {
@@ -43,6 +44,7 @@ impl std::ops::Deref for PyTensor {
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[pyclass(name = "DType")]
+/// A `candle` dtype.
struct PyDType(DType);
#[pymethods]
@@ -197,7 +199,7 @@ trait MapDType {
#[pymethods]
impl PyTensor {
#[new]
- #[pyo3(text_signature = "(data:_ArrayLike)")]
+ #[pyo3(text_signature = "(self, data:_ArrayLike)")]
// TODO: Handle arbitrary input dtype and shape.
/// Creates a new tensor from a Python value. The value can be a scalar or array-like object.
fn new(py: Python<'_>, data: PyObject) -> PyResult<Self> {
@@ -239,6 +241,7 @@ impl PyTensor {
}
/// Gets the tensor's data as a Python scalar or array-like object.
+ /// &RETURNS&: _ArrayLike
fn values(&self, py: Python<'_>) -> PyResult<PyObject> {
struct M<'a>(Python<'a>);
impl<'a> MapDType for M<'a> {
@@ -282,27 +285,36 @@ impl PyTensor {
}
#[getter]
- /// Gets the tensor shape as a Python tuple.
+ /// Gets the tensor's shape.
+ /// &RETURNS&: Tuple[int]
fn shape(&self, py: Python<'_>) -> PyObject {
PyTuple::new(py, self.0.dims()).to_object(py)
}
#[getter]
+ /// Gets the tensor's strides.
+ /// &RETURNS&: Tuple[int]
fn stride(&self, py: Python<'_>) -> PyObject {
PyTuple::new(py, self.0.stride()).to_object(py)
}
#[getter]
+ /// Gets the tensor's dtype.
+ /// &RETURNS&: DType
fn dtype(&self) -> PyDType {
PyDType(self.0.dtype())
}
#[getter]
+ /// Gets the tensor's device.
+ /// &RETURNS&: Device
fn device(&self, py: Python<'_>) -> PyObject {
PyDevice::from_device(self.0.device()).to_object(py)
}
#[getter]
+ /// Gets the tensor's rank.
+ /// &RETURNS&: int
fn rank(&self) -> usize {
self.0.rank()
}
@@ -315,69 +327,117 @@ impl PyTensor {
self.__repr__()
}
+ /// Performs the `sin` operation on the tensor.
+ /// &RETURNS&: Tensor
fn sin(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.sin().map_err(wrap_err)?))
}
+ /// Performs the `cos` operation on the tensor.
+ /// &RETURNS&: Tensor
fn cos(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.cos().map_err(wrap_err)?))
}
+ /// Performs the `log` operation on the tensor.
+ /// &RETURNS&: Tensor
fn log(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.log().map_err(wrap_err)?))
}
+ /// Squares the tensor.
+ /// &RETURNS&: Tensor
fn sqr(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.sqr().map_err(wrap_err)?))
}
+ /// Calculates the square root of the tensor.
+ /// &RETURNS&: Tensor
fn sqrt(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.sqrt().map_err(wrap_err)?))
}
+ /// Get the `recip` of the tensor.
+ /// &RETURNS&: Tensor
fn recip(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.recip().map_err(wrap_err)?))
}
+ /// Performs the `exp` operation on the tensor.
+ /// &RETURNS&: Tensor
fn exp(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.exp().map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, p:float)")]
+ /// Performs the `pow` operation on the tensor with the given exponent.
+ /// &RETURNS&: Tensor
fn powf(&self, p: f64) -> PyResult<Self> {
Ok(PyTensor(self.0.powf(p).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, rhs:Tensor, dim:int)")]
+ /// Select values for the input tensor at the target indexes across the specified dimension.
+ ///
+ /// The `indexes` is argument is an int tensor with a single dimension.
+ /// The output has the same number of dimension as the `self` input. The target dimension of
+ /// the output has length the length of `indexes` and the values are taken from `self` using
+ /// the index from `indexes`. Other dimensions have the same number of elements as the input
+ /// tensor.
+ /// &RETURNS&: Tensor
fn index_select(&self, rhs: &Self, dim: i64) -> PyResult<Self> {
let dim = actual_dim(self, dim).map_err(wrap_err)?;
Ok(PyTensor(self.0.index_select(rhs, dim).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, rhs:Tensor)")]
+ /// Performs a matrix multiplication between the two tensors.
+ /// &RETURNS&: Tensor
fn matmul(&self, rhs: &Self) -> PyResult<Self> {
Ok(PyTensor(self.0.matmul(rhs).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, rhs:Tensor)")]
+ /// Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
+ /// &RETURNS&: Tensor
fn broadcast_add(&self, rhs: &Self) -> PyResult<Self> {
Ok(PyTensor(self.0.broadcast_add(rhs).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, rhs:Tensor)")]
+ /// Subtracts the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
+ /// &RETURNS&: Tensor
fn broadcast_sub(&self, rhs: &Self) -> PyResult<Self> {
Ok(PyTensor(self.0.broadcast_sub(rhs).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, rhs:Tensor)")]
+ /// Multiplies the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
+ /// &RETURNS&: Tensor
fn broadcast_mul(&self, rhs: &Self) -> PyResult<Self> {
Ok(PyTensor(self.0.broadcast_mul(rhs).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, rhs:Tensor)")]
+ /// Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
+ /// &RETURNS&: Tensor
fn broadcast_div(&self, rhs: &Self) -> PyResult<Self> {
Ok(PyTensor(self.0.broadcast_div(rhs).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, on_true:Tensor, on_false:Tensor)")]
+ /// Returns a tensor with the same shape as the input tensor, the values are taken from
+ /// `on_true` if the input tensor value is not zero, and `on_false` at the positions where the
+ /// input tensor is equal to zero.
+ /// &RETURNS&: Tensor
fn where_cond(&self, on_true: &Self, on_false: &Self) -> PyResult<Self> {
Ok(PyTensor(
self.0.where_cond(on_true, on_false).map_err(wrap_err)?,
))
}
+ /// Add two tensors.
+ /// &RETURNS&: Tensor
fn __add__(&self, rhs: &PyAny) -> PyResult<Self> {
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
(&self.0 + &rhs.0).map_err(wrap_err)?
@@ -393,6 +453,8 @@ impl PyTensor {
self.__add__(rhs)
}
+ /// Multiply two tensors.
+ /// &RETURNS&: Tensor
fn __mul__(&self, rhs: &PyAny) -> PyResult<Self> {
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
(&self.0 * &rhs.0).map_err(wrap_err)?
@@ -408,6 +470,8 @@ impl PyTensor {
self.__mul__(rhs)
}
+ /// Subtract two tensors.
+ /// &RETURNS&: Tensor
fn __sub__(&self, rhs: &PyAny) -> PyResult<Self> {
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
(&self.0 - &rhs.0).map_err(wrap_err)?
@@ -419,6 +483,8 @@ impl PyTensor {
Ok(Self(tensor))
}
+ /// Divide two tensors.
+ /// &RETURNS&: Tensor
fn __truediv__(&self, rhs: &PyAny) -> PyResult<Self> {
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
(&self.0 / &rhs.0).map_err(wrap_err)?
@@ -430,62 +496,102 @@ impl PyTensor {
Ok(Self(tensor))
}
+ #[pyo3(text_signature = "(self, shape:Sequence[int])")]
+ /// Reshapes the tensor to the given shape.
+ /// &RETURNS&: Tensor
fn reshape(&self, shape: PyShape) -> PyResult<Self> {
Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, shape:Sequence[int])")]
+ /// Broadcasts the tensor to the given shape.
+ /// &RETURNS&: Tensor
fn broadcast_as(&self, shape: PyShape) -> PyResult<Self> {
Ok(PyTensor(self.0.broadcast_as(shape).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, shape:Sequence[int])")]
+ /// Broadcasts the tensor to the given shape, adding new dimensions on the left.
+ /// &RETURNS&: Tensor
fn broadcast_left(&self, shape: PyShape) -> PyResult<Self> {
Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, dim:int)")]
+ /// Creates a new tensor with the specified dimension removed if its size was one.
+ /// &RETURNS&: Tensor
fn squeeze(&self, dim: i64) -> PyResult<Self> {
let dim = actual_dim(self, dim).map_err(wrap_err)?;
Ok(PyTensor(self.0.squeeze(dim).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, dim:int)")]
+ /// Creates a new tensor with a dimension of size one inserted at the specified position.
+ /// &RETURNS&: Tensor
fn unsqueeze(&self, dim: usize) -> PyResult<Self> {
Ok(PyTensor(self.0.unsqueeze(dim).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, index:int)")]
+ /// Gets the value at the specified index.
+ /// &RETURNS&: Tensor
fn get(&self, index: i64) -> PyResult<Self> {
let index = actual_index(self, 0, index).map_err(wrap_err)?;
Ok(PyTensor(self.0.get(index).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, dim1:int, dim2:int)")]
+ /// Returns a tensor that is a transposed version of the input, the given dimensions are swapped.
+ /// &RETURNS&: Tensor
fn transpose(&self, dim1: usize, dim2: usize) -> PyResult<Self> {
Ok(PyTensor(self.0.transpose(dim1, dim2).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, dim:int, start:int, len:int)")]
+ /// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
+ /// ranges from `start` to `start + len`.
+ /// &RETURNS&: Tensor
fn narrow(&self, dim: i64, start: i64, len: usize) -> PyResult<Self> {
let dim = actual_dim(self, dim).map_err(wrap_err)?;
let start = actual_index(self, dim, start).map_err(wrap_err)?;
Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, dim:int)")]
+ /// Returns the indices of the maximum value(s) across the selected dimension.
+ /// &RETURNS&: Tensor
fn argmax_keepdim(&self, dim: i64) -> PyResult<Self> {
let dim = actual_dim(self, dim).map_err(wrap_err)?;
Ok(PyTensor(self.0.argmax_keepdim(dim).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, dim:int)")]
+ /// Returns the indices of the minimum value(s) across the selected dimension.
+ /// &RETURNS&: Tensor
fn argmin_keepdim(&self, dim: i64) -> PyResult<Self> {
let dim = actual_dim(self, dim).map_err(wrap_err)?;
Ok(PyTensor(self.0.argmin_keepdim(dim).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, dim:int)")]
+ /// Gathers the maximum value across the selected dimension.
+ /// &RETURNS&: Tensor
fn max_keepdim(&self, dim: i64) -> PyResult<Self> {
let dim = actual_dim(self, dim).map_err(wrap_err)?;
Ok(PyTensor(self.0.max_keepdim(dim).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, dim:int)")]
+ /// Gathers the minimum value across the selected dimension.
+ /// &RETURNS&: Tensor
fn min_keepdim(&self, dim: i64) -> PyResult<Self> {
let dim = actual_dim(self, dim).map_err(wrap_err)?;
Ok(PyTensor(self.0.min_keepdim(dim).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, dim:Union[int, List[int]])")]
+ /// Returns the sum of all elements in the input tensor. The sum is performed over all the input dimensions.
+ /// &RETURNS&: Tensor
fn sum_keepdim(&self, dims: PyObject, py: Python<'_>) -> PyResult<Self> {
let dims = if let Ok(dim) = dims.extract::<usize>(py) {
vec![dim]
@@ -497,10 +603,14 @@ impl PyTensor {
))
}
+ /// Returns the sum of the tensor.
+ /// &RETURNS&: Tensor
fn sum_all(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.sum_all().map_err(wrap_err)?))
}
+ /// Returns the mean of the tensor.
+ /// &RETURNS&: Tensor
fn mean_all(&self) -> PyResult<Self> {
let elements = self.0.elem_count();
let sum = self.0.sum_all().map_err(wrap_err)?;
@@ -508,54 +618,83 @@ impl PyTensor {
Ok(PyTensor(mean))
}
+ #[pyo3(text_signature = "(self, dim:int)")]
+ /// Flattens the tensor on the dimension indexes from `dim` (inclusive) to the last dimension.
+ /// &RETURNS&: Tensor
fn flatten_from(&self, dim: i64) -> PyResult<Self> {
let dim = actual_dim(self, dim).map_err(wrap_err)?;
Ok(PyTensor(self.0.flatten_from(dim).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, dim:int)")]
+ ///Flattens the tensor on the dimension indexes from `0` to `dim` (inclusive).
+ /// &RETURNS&: Tensor
fn flatten_to(&self, dim: i64) -> PyResult<Self> {
let dim = actual_dim(self, dim).map_err(wrap_err)?;
Ok(PyTensor(self.0.flatten_to(dim).map_err(wrap_err)?))
}
+ /// Flattens the tensor into a 1D tensor.
+ /// &RETURNS&: Tensor
fn flatten_all(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.flatten_all().map_err(wrap_err)?))
}
+ /// Transposes the tensor.
+ /// &RETURNS&: Tensor
fn t(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.t().map_err(wrap_err)?))
}
+ /// Makes the tensor contiguous in memory.
+ /// &RETURNS&: Tensor
fn contiguous(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.contiguous().map_err(wrap_err)?))
}
+ /// Returns true if the tensor is contiguous in C order.
+ /// &RETURNS&: bool
fn is_contiguous(&self) -> bool {
self.0.is_contiguous()
}
+ /// Returns true if the tensor is contiguous in Fortran order.
+ /// &RETURNS&: bool
fn is_fortran_contiguous(&self) -> bool {
self.0.is_fortran_contiguous()
}
+ /// Detach the tensor from the computation graph.
+ /// &RETURNS&: Tensor
fn detach(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.detach().map_err(wrap_err)?))
}
+ /// Returns a copy of the tensor.
+ /// &RETURNS&: Tensor
fn copy(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.copy().map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, dtype:Union[str,DType])")]
+ /// Convert the tensor to a new dtype.
+ /// &RETURNS&: Tensor
fn to_dtype(&self, dtype: PyObject, py: Python<'_>) -> PyResult<Self> {
let dtype = PyDType::from_pyobject(dtype, py)?;
Ok(PyTensor(self.0.to_dtype(dtype.0).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, device:Union[str,Device])")]
+ /// Move the tensor to a new device.
+ /// &RETURNS&: Tensor
fn to_device(&self, device: PyDevice) -> PyResult<Self> {
let device = device.as_device()?;
Ok(PyTensor(self.0.to_device(&device).map_err(wrap_err)?))
}
+ #[pyo3(text_signature = "(self, quantized_dtype:str)")]
+ /// Quantize the tensor.
+ /// &RETURNS&: QTensor
fn quantize(&self, quantized_dtype: &str) -> PyResult<PyQTensor> {
use ::candle::quantized;
let res = match quantized_dtype {
@@ -586,6 +725,7 @@ impl PyTensor {
#[pyfunction]
#[pyo3(text_signature = "(tensors:List[Tensor], dim:int )")]
/// Concatenate the tensors across one axis.
+/// &RETURNS&: Tensor
fn cat(tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
if tensors.is_empty() {
return Err(PyErr::new::<PyValueError, _>("empty input to cat"));
@@ -599,6 +739,7 @@ fn cat(tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
#[pyfunction]
#[pyo3(text_signature = "(tensors:List[Tensor], dim:int)")]
/// Stack the tensors along a new axis.
+/// &RETURNS&: Tensor
fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
let tensor = Tensor::stack(&tensors, dim).map_err(wrap_err)?;
@@ -608,6 +749,7 @@ fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
#[pyfunction]
#[pyo3(text_signature = "(data:_ArrayLike)")]
/// Creates a new tensor from a Python value. The value can be a scalar or array-like object.
+/// &RETURNS&: Tensor
fn tensor(py: Python<'_>, data: PyObject) -> PyResult<PyTensor> {
PyTensor::new(py, data)
}
@@ -615,6 +757,7 @@ fn tensor(py: Python<'_>, data: PyObject) -> PyResult<PyTensor> {
#[pyfunction]
#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")]
/// Creates a new tensor with random values.
+/// &RETURNS&: Tensor
fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
let tensor = Tensor::rand(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
@@ -623,6 +766,8 @@ fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<P
#[pyfunction]
#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")]
+/// Creates a new tensor with random values from a normal distribution.
+/// &RETURNS&: Tensor
fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
let tensor = Tensor::randn(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
@@ -631,6 +776,8 @@ fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<
#[pyfunction]
#[pyo3(signature = (shape, *, dtype=None, device=None),text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")]
+/// Creates a new tensor filled with ones.
+/// &RETURNS&: Tensor
fn ones(
py: Python<'_>,
shape: PyShape,
@@ -648,6 +795,8 @@ fn ones(
#[pyfunction]
#[pyo3(signature = (shape, *, dtype=None, device=None), text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")]
+/// Creates a new tensor filled with zeros.
+/// &RETURNS&: Tensor
fn zeros(
py: Python<'_>,
shape: PyShape,
@@ -663,8 +812,9 @@ fn zeros(
Ok(PyTensor(tensor))
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
#[pyclass(name = "QTensor")]
+/// A quantized tensor.
struct PyQTensor(Arc<QTensor>);
impl std::ops::Deref for PyQTensor {
@@ -678,16 +828,22 @@ impl std::ops::Deref for PyQTensor {
#[pymethods]
impl PyQTensor {
#[getter]
+ ///Gets the tensors quantized dtype.
+ /// &RETURNS&: str
fn ggml_dtype(&self) -> String {
format!("{:?}", self.0.dtype())
}
#[getter]
+ ///Gets the rank of the tensor.
+ /// &RETURNS&: int
fn rank(&self) -> usize {
self.0.rank()
}
#[getter]
+ ///Gets the shape of the tensor.
+ /// &RETURNS&: Tuple[int]
fn shape(&self, py: Python<'_>) -> PyObject {
PyTuple::new(py, self.0.shape().dims()).to_object(py)
}
@@ -700,11 +856,16 @@ impl PyQTensor {
self.__repr__()
}
+ /// Dequantizes the tensor.
+ /// &RETURNS&: Tensor
fn dequantize(&self) -> PyResult<PyTensor> {
let tensor = self.0.dequantize(&Device::Cpu).map_err(wrap_err)?;
Ok(PyTensor(tensor))
}
+ #[pyo3(text_signature = "(self, lhs:Tensor)")]
+ /// Performs a quantized matrix multiplication, with the quantized tensor as the right hand side.
+ /// &RETURNS&: Tensor
fn matmul_t(&self, lhs: &PyTensor) -> PyResult<PyTensor> {
let qmatmul = ::candle::quantized::QMatMul::from_arc(self.0.clone());
let res = qmatmul.forward(lhs).map_err(wrap_err)?;
@@ -715,6 +876,7 @@ impl PyQTensor {
#[pyfunction]
#[pyo3(text_signature = "(path:Union[str,PathLike])")]
/// Loads a safetensors file. Returns a dictionary mapping tensor names to tensors.
+/// &RETURNS&: Dict[str,Tensor]
fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
let res = ::candle::safetensors::load(path, &Device::Cpu).map_err(wrap_err)?;
let res = res
@@ -727,6 +889,7 @@ fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
#[pyfunction]
#[pyo3(text_signature = "(path:Union[str,PathLike], tensors:Dict[str,Tensor])")]
/// Saves a dictionary of tensors to a safetensors file.
+/// &RETURNS&: None
fn save_safetensors(
path: &str,
tensors: std::collections::HashMap<String, PyTensor>,
@@ -742,6 +905,7 @@ fn save_safetensors(
#[pyo3(text_signature = "(path:Union[str,PathLike])")]
/// Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
/// a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
+/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any], List[str]]
fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObject)> {
let mut file = std::fs::File::open(path)?;
let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?;
@@ -776,6 +940,7 @@ fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObje
#[pyo3(text_signature = "(path:Union[str,PathLike])")]
/// Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
/// and the second maps metadata keys to metadata values.
+/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any]]
fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
use ::candle::quantized::gguf_file;
fn gguf_value_to_pyobject(v: &gguf_file::Value, py: Python<'_>) -> PyResult<PyObject> {
@@ -825,25 +990,117 @@ fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
}
#[pyfunction]
+#[pyo3(
+ text_signature = "(path:Union[str,PathLike], tensors:Dict[str,QTensor], metadata:Dict[str,Any])"
+)]
+/// Save quanitzed tensors and metadata to a GGUF file.
+fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>) -> PyResult<()> {
+ use ::candle::quantized::gguf_file;
+
+ fn pyobject_to_gguf_value(v: &PyAny, py: Python<'_>) -> PyResult<gguf_file::Value> {
+ let v: gguf_file::Value = if let Ok(x) = v.extract::<u8>() {
+ gguf_file::Value::U8(x)
+ } else if let Ok(x) = v.extract::<i8>() {
+ gguf_file::Value::I8(x)
+ } else if let Ok(x) = v.extract::<u16>() {
+ gguf_file::Value::U16(x)
+ } else if let Ok(x) = v.extract::<i16>() {
+ gguf_file::Value::I16(x)
+ } else if let Ok(x) = v.extract::<u32>() {
+ gguf_file::Value::U32(x)
+ } else if let Ok(x) = v.extract::<i32>() {
+ gguf_file::Value::I32(x)
+ } else if let Ok(x) = v.extract::<u64>() {
+ gguf_file::Value::U64(x)
+ } else if let Ok(x) = v.extract::<i64>() {
+ gguf_file::Value::I64(x)
+ } else if let Ok(x) = v.extract::<f32>() {
+ gguf_file::Value::F32(x)
+ } else if let Ok(x) = v.extract::<f64>() {
+ gguf_file::Value::F64(x)
+ } else if let Ok(x) = v.extract::<bool>() {
+ gguf_file::Value::Bool(x)
+ } else if let Ok(x) = v.extract::<String>() {
+ gguf_file::Value::String(x)
+ } else if let Ok(x) = v.extract::<Vec<PyObject>>() {
+ let x = x
+ .into_iter()
+ .map(|f| pyobject_to_gguf_value(f.as_ref(py), py))
+ .collect::<PyResult<Vec<_>>>()?;
+ gguf_file::Value::Array(x)
+ } else {
+ return Err(PyErr::new::<PyValueError, _>(format!(
+ "unsupported type {:?}",
+ v
+ )));
+ };
+ Ok(v)
+ }
+ let tensors = tensors
+ .extract::<&PyDict>(py)
+ .map_err(|_| PyErr::new::<PyValueError, _>("expected a dict"))?
+ .iter()
+ .map(|(key, value)| {
+ Ok((
+ key.extract::<String>()
+ .map_err(|_| PyErr::new::<PyValueError, _>("keys must be strings"))?,
+ value.extract::<PyQTensor>()?.0,
+ ))
+ })
+ .collect::<PyResult<Vec<_>>>()?;
+
+ let metadata = metadata
+ .extract::<&PyDict>(py)
+ .map_err(|_| PyErr::new::<PyValueError, _>("expected a dict"))?
+ .iter()
+ .map(|(key, value)| {
+ Ok((
+ key.extract::<String>()
+ .map_err(|_| PyErr::new::<PyValueError, _>("keys must be strings"))?,
+ pyobject_to_gguf_value(value, py)?,
+ ))
+ })
+ .collect::<PyResult<Vec<_>>>()?;
+
+ let converted_metadata: Vec<_> = metadata
+ .iter()
+ .map(|(name, value)| (name.as_str(), value))
+ .collect();
+
+ let converted_tensors: Vec<_> = tensors
+ .iter()
+ .map(|(name, tensor)| (name.as_str(), tensor.as_ref()))
+ .collect();
+
+ let mut file = std::fs::File::create(path)?;
+
+ gguf_file::write(&mut file, &converted_metadata, &converted_tensors).map_err(wrap_err)
+}
+
+#[pyfunction]
/// Returns true if the 'cuda' backend is available.
+/// &RETURNS&: bool
fn cuda_is_available() -> bool {
::candle::utils::cuda_is_available()
}
#[pyfunction]
/// Returns true if candle was compiled with 'accelerate' support.
+/// &RETURNS&: bool
fn has_accelerate() -> bool {
::candle::utils::has_accelerate()
}
#[pyfunction]
/// Returns true if candle was compiled with MKL support.
+/// &RETURNS&: bool
fn has_mkl() -> bool {
::candle::utils::has_mkl()
}
#[pyfunction]
/// Returns the number of threads used by the candle.
+/// &RETURNS&: int
fn get_num_threads() -> usize {
::candle::utils::get_num_threads()
}
@@ -855,6 +1112,7 @@ fn candle_utils(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(has_mkl, m)?)?;
m.add_function(wrap_pyfunction!(load_ggml, m)?)?;
m.add_function(wrap_pyfunction!(load_gguf, m)?)?;
+ m.add_function(wrap_pyfunction!(save_gguf, m)?)?;
m.add_function(wrap_pyfunction!(load_safetensors, m)?)?;
m.add_function(wrap_pyfunction!(save_safetensors, m)?)?;
Ok(())
@@ -862,7 +1120,8 @@ fn candle_utils(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
#[pyfunction]
#[pyo3(text_signature = "(tensor:Tensor, dim:int)")]
-/// Applies the Softmax function to a given tensor.
+/// Applies the Softmax function to a given tensor.#
+/// &RETURNS&: Tensor
fn softmax(tensor: PyTensor, dim: i64) -> PyResult<PyTensor> {
let dim = actual_dim(&tensor, dim).map_err(wrap_err)?;
let sm = candle_nn::ops::softmax(&tensor.0, dim).map_err(wrap_err)?;
@@ -872,6 +1131,7 @@ fn softmax(tensor: PyTensor, dim: i64) -> PyResult<PyTensor> {
#[pyfunction]
#[pyo3(text_signature = "(tensor:Tensor)")]
/// Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
+/// &RETURNS&: Tensor
fn silu(tensor: PyTensor) -> PyResult<PyTensor> {
let s = candle_nn::ops::silu(&tensor.0).map_err(wrap_err)?;
Ok(PyTensor(s))
diff --git a/candle-pyo3/stub.py b/candle-pyo3/stub.py
index b5b9256f..149715c2 100644
--- a/candle-pyo3/stub.py
+++ b/candle-pyo3/stub.py
@@ -13,8 +13,8 @@ TYPING = """from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from os import PathLike
"""
CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device\n"
-CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType\n"
-
+CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType,QTensor\n"
+RETURN_TYPE_MARKER = "&RETURNS&: "
def do_indent(text: Optional[str], indent: str):
@@ -28,11 +28,26 @@ def function(obj, indent:str, text_signature:str=None):
text_signature = obj.__text_signature__
text_signature = text_signature.replace("$self", "self").lstrip().rstrip()
+ doc_string = obj.__doc__
+ if doc_string is None:
+ doc_string = ""
+
+ # Check if we have a return type annotation in the docstring
+ return_type = None
+ doc_lines = doc_string.split("\n")
+ if doc_lines[-1].lstrip().startswith(RETURN_TYPE_MARKER):
+ # Extract the return type and remove it from the docstring
+ return_type = doc_lines[-1].lstrip()[len(RETURN_TYPE_MARKER):].strip()
+ doc_string = "\n".join(doc_lines[:-1])
+
string = ""
- string += f"{indent}def {obj.__name__}{text_signature}:\n"
+ if return_type:
+ string += f"{indent}def {obj.__name__}{text_signature} -> {return_type}:\n"
+ else:
+ string += f"{indent}def {obj.__name__}{text_signature}:\n"
indent += INDENT
string += f'{indent}"""\n'
- string += f"{indent}{do_indent(obj.__doc__, indent)}\n"
+ string += f"{indent}{do_indent(doc_string, indent)}\n"
string += f'{indent}"""\n'
string += f"{indent}pass\n"
string += "\n"
diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py
index c78ffc41..7f24b49d 100644
--- a/candle-pyo3/test.py
+++ b/candle-pyo3/test.py
@@ -1,5 +1,4 @@
import candle
-from candle import Tensor, QTensor
t = candle.Tensor(42.0)
print(t)
@@ -10,7 +9,7 @@ t = candle.Tensor([3.0, 1, 4, 1, 5, 9, 2, 6])
print(t)
print(t+t)
-t:Tensor = t.reshape([2, 4])
+t = t.reshape([2, 4])
print(t.matmul(t.t()))
print(t.to_dtype(candle.u8))
@@ -21,7 +20,7 @@ print(t)
print(t.dtype)
t = candle.randn((16, 256))
-quant_t:QTensor = t.quantize("q6k")
-dequant_t:Tensor = quant_t.dequantize()
-diff2:Tensor = (t - dequant_t).sqr()
+quant_t = t.quantize("q6k")
+dequant_t = quant_t.dequantize()
+diff2 = (t - dequant_t).sqr()
print(diff2.mean_all())