-rw-r--r--  candle-core/src/backprop.rs                                          2
-rw-r--r--  candle-core/src/tensor.rs                                            6
-rw-r--r--  candle-core/src/variable.rs                                          4
-rw-r--r--  candle-examples/examples/reinforcement-learning/ddpg.rs              2
-rw-r--r--  candle-examples/examples/reinforcement-learning/policy_gradient.rs   8
-rw-r--r--  candle-nn/src/batch_norm.rs                                         14
-rw-r--r--  candle-pyo3/py_src/candle/__init__.pyi                              70
-rw-r--r--  candle-pyo3/py_src/candle/nn/container.py                            6
-rw-r--r--  candle-pyo3/py_src/candle/nn/module.py                              12
-rw-r--r--  candle-pyo3/py_src/candle/nn/normalization.py                        1
-rw-r--r--  candle-pyo3/py_src/candle/onnx/__init__.pyi                         11
-rw-r--r--  candle-pyo3/src/lib.rs                                               4
-rw-r--r--  candle-pyo3/stub.py                                                  1
-rw-r--r--  candle-wasm-examples/segment-anything/README.md                      3
14 files changed, 117 insertions, 27 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index c152f31f..e7e3e129 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -175,7 +175,7 @@ impl Tensor {
// the backprop graph of the backprop itself. This would be an issue for second-order
// derivatives but these are out of scope at the moment.
let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
- let grad = if do_not_detach { grad } else { grad.detach()? };
+ let grad = if do_not_detach { grad } else { grad.detach() };
if let Some(op) = node.op() {
match op {
Op::Binary(lhs, rhs, BinaryOp::Add) => {
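
Note on the hunk above: gradients returned by backward() are detached unless CANDLE_GRAD_DO_NOT_DETACH is set, so they carry no op graph of their own. A minimal sketch of that behavior, assuming only the candle-core public API:

use candle_core::{Device, Result, Var};

fn main() -> Result<()> {
    let x = Var::new(3f32, &Device::Cpu)?;
    let y = x.as_tensor().sqr()?; // y = x^2, recorded in the graph
    let grads = y.backward()?;
    let dx = grads.get(x.as_tensor()).expect("grad for x");
    // dx was detached inside backward(): it has no op graph of its own,
    // which is fine for first-order derivatives (second-order is out of
    // scope, per the comment in the hunk).
    println!("dy/dx = {}", dx.to_scalar::<f32>()?); // prints 6
    Ok(())
}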
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 5f0b6df9..8596c957 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1882,9 +1882,9 @@ impl Tensor {
/// this new node. The storage of this tensor is shared with the initial tensor.
///
/// If the tensor is already detached from the computation graph, the same tensor is returned.
- pub fn detach(&self) -> Result<Tensor> {
+ pub fn detach(&self) -> Tensor {
if self.op.is_none() && !self.is_variable {
- Ok(self.clone())
+ self.clone()
} else {
let tensor_ = Tensor_ {
id: TensorId::new(),
@@ -1895,7 +1895,7 @@ impl Tensor {
dtype: self.dtype,
device: self.device.clone(),
};
- Ok(Tensor(Arc::new(tensor_)))
+ Tensor(Arc::new(tensor_))
}
}
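
At call sites the change reads as dropping the trailing `?`: detach is now infallible. A minimal sketch, assuming only candle-core:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    // Before this change: let d = t.detach()?;
    let d = t.detach(); // no Result to unwrap anymore
    assert_eq!(d.to_vec1::<f32>()?, vec![1., 2., 3.]);
    Ok(())
}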
diff --git a/candle-core/src/variable.rs b/candle-core/src/variable.rs
index 61800bf3..bdf8da4a 100644
--- a/candle-core/src/variable.rs
+++ b/candle-core/src/variable.rs
@@ -107,6 +107,10 @@ impl Var {
Ok(Self(inner))
}
+ pub fn as_detached_tensor(&self) -> Tensor {
+ self.0.detach()
+ }
+
pub fn as_tensor(&self) -> &Tensor {
&self.0
}
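
A short usage sketch of the new accessor next to the existing one (Var constructor as in candle-core): as_tensor borrows the graph-attached tensor, while as_detached_tensor returns an owned handle that shares storage but tracks no gradients.

use candle_core::{DType, Device, Result, Var};

fn main() -> Result<()> {
    let w = Var::zeros((2, 2), DType::F32, &Device::Cpu)?;
    let attached = w.as_tensor();        // &Tensor, still a trainable variable
    let frozen = w.as_detached_tensor(); // owned Tensor, outside the graph
    assert_eq!(attached.dims(), frozen.dims());
    Ok(())
}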
diff --git a/candle-examples/examples/reinforcement-learning/ddpg.rs b/candle-examples/examples/reinforcement-learning/ddpg.rs
index 1ce4889e..5309eaf6 100644
--- a/candle-examples/examples/reinforcement-learning/ddpg.rs
+++ b/candle-examples/examples/reinforcement-learning/ddpg.rs
@@ -411,7 +411,7 @@ impl DDPG<'_> {
pub fn actions(&mut self, state: &Tensor) -> Result<f32> {
let actions = self
.actor
- .forward(&state.detach()?.unsqueeze(0)?)?
+ .forward(&state.detach().unsqueeze(0)?)?
.squeeze(0)?;
let actions = if self.train {
(actions + self.ou_noise.sample()?)?
diff --git a/candle-examples/examples/reinforcement-learning/policy_gradient.rs b/candle-examples/examples/reinforcement-learning/policy_gradient.rs
index 044cbfcd..6c355fe6 100644
--- a/candle-examples/examples/reinforcement-learning/policy_gradient.rs
+++ b/candle-examples/examples/reinforcement-learning/policy_gradient.rs
@@ -74,7 +74,7 @@ pub fn run() -> Result<()> {
loop {
let action = {
let action_probs: Vec<f32> =
- softmax(&model.forward(&state.detach()?.unsqueeze(0)?)?, 1)?
+ softmax(&model.forward(&state.detach().unsqueeze(0)?)?, 1)?
.squeeze(0)?
.to_vec1()?;
weighted_sample(action_probs, &mut rng)? as i64
@@ -109,7 +109,7 @@ pub fn run() -> Result<()> {
let rewards = Tensor::from_vec(accumulate_rewards(&steps), batch_size, &Device::Cpu)?
.to_dtype(DType::F32)?
- .detach()?;
+ .detach();
let actions_mask = {
let actions: Vec<i64> = steps.iter().map(|s| s.action).collect();
@@ -126,12 +126,12 @@ pub fn run() -> Result<()> {
.unwrap()
})
.collect();
- Tensor::stack(&actions_mask, 0)?.detach()?
+ Tensor::stack(&actions_mask, 0)?.detach()
};
let states = {
let states: Vec<Tensor> = steps.into_iter().map(|s| s.state).collect();
- Tensor::stack(&states, 0)?.detach()?
+ Tensor::stack(&states, 0)?.detach()
};
let log_probs = actions_mask
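
The recurring pattern in this example is stack-then-detach for tensors used as fixed targets; only the stacking can still fail. A hedged sketch (the helper name is hypothetical):

use candle_core::{Result, Tensor};

// Hypothetical helper mirroring the call sites above: stack can fail on a
// shape mismatch, detach no longer returns a Result.
fn stack_detached(parts: &[Tensor]) -> Result<Tensor> {
    Ok(Tensor::stack(parts, 0)?.detach())
}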
diff --git a/candle-nn/src/batch_norm.rs b/candle-nn/src/batch_norm.rs
index 856c2c7a..4c67961d 100644
--- a/candle-nn/src/batch_norm.rs
+++ b/candle-nn/src/batch_norm.rs
@@ -262,9 +262,19 @@ impl BatchNorm {
let target_shape = target_shape.as_slice();
let x = x
- .broadcast_sub(&self.running_mean.as_tensor().reshape(target_shape)?)?
+ .broadcast_sub(
+ &self
+ .running_mean
+ .as_detached_tensor()
+ .reshape(target_shape)?,
+ )?
.broadcast_div(
- &(self.running_var.as_tensor().reshape(target_shape)? + self.eps)?.sqrt()?,
+ &(self
+ .running_var
+ .as_detached_tensor()
+ .reshape(target_shape)?
+ + self.eps)?
+ .sqrt()?,
)?;
match &self.weight_and_bias {
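
The switch to as_detached_tensor here keeps gradients from flowing into the running statistics during a training forward pass; they are maintained by the moving-average update instead. A minimal sketch of the normalization step (function and parameter names are hypothetical, and the stats are assumed already reshaped to the broadcast shape):

use candle_core::{Result, Tensor, Var};

fn normalize(x: &Tensor, running_mean: &Var, running_var: &Var, eps: f64) -> Result<Tensor> {
    // Detached copies behave as constants w.r.t. the backprop graph.
    let mean = running_mean.as_detached_tensor();
    let var = running_var.as_detached_tensor();
    x.broadcast_sub(&mean)?.broadcast_div(&(var + eps)?.sqrt()?)
}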
diff --git a/candle-pyo3/py_src/candle/__init__.pyi b/candle-pyo3/py_src/candle/__init__.pyi
index 37b8fe8c..aef0707d 100644
--- a/candle-pyo3/py_src/candle/__init__.pyi
+++ b/candle-pyo3/py_src/candle/__init__.pyi
@@ -88,23 +88,27 @@ class QTensor:
Dequantizes the tensor.
"""
pass
+
@property
def ggml_dtype(self) -> str:
"""
Gets the tensor's quantized dtype.
"""
pass
+
def matmul_t(self, lhs: Tensor) -> Tensor:
"""
Performs a quantized matrix multiplication, with the quantized tensor as the right-hand side.
"""
pass
+
@property
def rank(self) -> int:
"""
Gets the rank of the tensor.
"""
pass
+
@property
def shape(self) -> Tuple[int]:
"""
@@ -119,178 +123,213 @@ class Tensor:
def __init__(self, data: _ArrayLike):
pass
+
def __add__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Add a scalar to a tensor or two tensors together.
"""
pass
+
def __eq__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
+
def __ge__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
+
def __getitem__(self, index: Union[Index, Tensor, Sequence[Index]]) -> "Tensor":
"""
Return a slice of a tensor.
"""
pass
+
def __gt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
+
def __le__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
+
def __lt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
+
def __mul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Multiply a tensor by a scalar or one tensor by another.
"""
pass
+
def __ne__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
+
def __radd__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Add a scalar to a tensor or two tensors together.
"""
pass
+
def __richcmp__(self, rhs: Union[Tensor, Scalar], op) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
+
def __rmul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Multiply a tensor by a scalar or one tensor by another.
"""
pass
+
def __sub__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Subtract a scalar from a tensor or one tensor from another.
"""
pass
+
def __truediv__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Divide a tensor by a scalar or one tensor by another.
"""
pass
+
def abs(self) -> Tensor:
"""
Performs the `abs` operation on the tensor.
"""
pass
+
def argmax_keepdim(self, dim: int) -> Tensor:
"""
Returns the indices of the maximum value(s) across the selected dimension.
"""
pass
+
def argmin_keepdim(self, dim: int) -> Tensor:
"""
Returns the indices of the minimum value(s) across the selected dimension.
"""
pass
+
def broadcast_add(self, rhs: Tensor) -> Tensor:
"""
Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
"""
pass
+
def broadcast_as(self, *shape: Shape) -> Tensor:
"""
Broadcasts the tensor to the given shape.
"""
pass
+
def broadcast_div(self, rhs: Tensor) -> Tensor:
"""
Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
"""
pass
+
def broadcast_left(self, *shape: Shape) -> Tensor:
"""
Broadcasts the tensor to the given shape, adding new dimensions on the left.
"""
pass
+
def broadcast_mul(self, rhs: Tensor) -> Tensor:
"""
Multiplies the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
"""
pass
+
def broadcast_sub(self, rhs: Tensor) -> Tensor:
"""
Subtracts the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
"""
pass
+
def contiguous(self) -> Tensor:
"""
Makes the tensor contiguous in memory.
"""
pass
+
def copy(self) -> Tensor:
"""
Returns a copy of the tensor.
"""
pass
+
def cos(self) -> Tensor:
"""
Performs the `cos` operation on the tensor.
"""
pass
+
def detach(self) -> Tensor:
"""
Detach the tensor from the computation graph.
"""
pass
+
@property
def device(self) -> Device:
"""
Gets the tensor's device.
"""
pass
+
@property
def dtype(self) -> DType:
"""
Gets the tensor's dtype.
"""
pass
+
def exp(self) -> Tensor:
"""
Performs the `exp` operation on the tensor.
"""
pass
+
def flatten_all(self) -> Tensor:
"""
Flattens the tensor into a 1D tensor.
"""
pass
+
def flatten_from(self, dim: int) -> Tensor:
"""
Flattens the tensor on the dimension indexes from `dim` (inclusive) to the last dimension.
"""
pass
+
def flatten_to(self, dim: int) -> Tensor:
"""
Flattens the tensor on the dimension indexes from `0` to `dim` (inclusive).
"""
pass
+
def get(self, index: int) -> Tensor:
"""
Gets the value at the specified index.
"""
pass
+
def index_select(self, rhs: Tensor, dim: int) -> Tensor:
"""
Select values for the input tensor at the target indexes across the specified dimension.
@@ -302,161 +341,192 @@ class Tensor:
tensor.
"""
pass
+
def is_contiguous(self) -> bool:
"""
Returns true if the tensor is contiguous in C order.
"""
pass
+
def is_fortran_contiguous(self) -> bool:
"""
Returns true if the tensor is contiguous in Fortran order.
"""
pass
+
def log(self) -> Tensor:
"""
Performs the `log` operation on the tensor.
"""
pass
+
def matmul(self, rhs: Tensor) -> Tensor:
"""
Performs a matrix multiplication between the two tensors.
"""
pass
+
def max_keepdim(self, dim: int) -> Tensor:
"""
Gathers the maximum value across the selected dimension.
"""
pass
+
def mean_all(self) -> Tensor:
"""
Returns the mean of the tensor.
"""
pass
+
def min_keepdim(self, dim: int) -> Tensor:
"""
Gathers the minimum value across the selected dimension.
"""
pass
+
def narrow(self, dim: int, start: int, len: int) -> Tensor:
"""
Returns a new tensor that is a narrowed version of the input; the dimension `dim`
ranges from `start` to `start + len`.
"""
pass
+
@property
def nelement(self) -> int:
"""
Gets the tensor's element count.
"""
pass
+
def powf(self, p: float) -> Tensor:
"""
Performs the `pow` operation on the tensor with the given exponent.
"""
pass
+
def quantize(self, quantized_dtype: str) -> QTensor:
"""
Quantize the tensor.
"""
pass
+
@property
def rank(self) -> int:
"""
Gets the tensor's rank.
"""
pass
+
def recip(self) -> Tensor:
"""
Get the `recip` of the tensor.
"""
pass
+
def reshape(self, *shape: Shape) -> Tensor:
"""
Reshapes the tensor to the given shape.
"""
pass
+
@property
def shape(self) -> Tuple[int]:
"""
Gets the tensor's shape.
"""
pass
+
def sin(self) -> Tensor:
"""
Performs the `sin` operation on the tensor.
"""
pass
+
def sqr(self) -> Tensor:
"""
Squares the tensor.
"""
pass
+
def sqrt(self) -> Tensor:
"""
Calculates the square root of the tensor.
"""
pass
+
def squeeze(self, dim: int) -> Tensor:
"""
Creates a new tensor with the specified dimension removed if its size was one.
"""
pass
+
@property
def stride(self) -> Tuple[int]:
"""
Gets the tensor's strides.
"""
pass
+
def sum_all(self) -> Tensor:
"""
Returns the sum of the tensor.
"""
pass
+
def sum_keepdim(self, dim: Union[int, List[int]]) -> Tensor:
"""
Returns the sum of elements over the given dimension(s), keeping the reduced dimension(s) with a size of one.
"""
pass
+
def t(self) -> Tensor:
"""
Transposes the tensor.
"""
pass
+
def to(self, *args, **kwargs) -> Tensor:
"""
Performs Tensor dtype and/or device conversion.
"""
pass
+
def to_device(self, device: Union[str, Device]) -> Tensor:
"""
Move the tensor to a new device.
"""
pass
+
def to_dtype(self, dtype: Union[str, DType]) -> Tensor:
"""
Convert the tensor to a new dtype.
"""
pass
+
def to_torch(self) -> torch.Tensor:
"""
Converts candle's tensor to a PyTorch tensor.
"""
pass
+
def transpose(self, dim1: int, dim2: int) -> Tensor:
"""
Returns a tensor that is a transposed version of the input; the given dimensions are swapped.
"""
pass
+
def unsqueeze(self, dim: int) -> Tensor:
"""
Creates a new tensor with a dimension of size one inserted at the specified position.
"""
pass
+
def values(self) -> _ArrayLike:
"""
Gets the tensor's data as a Python scalar or array-like object.
"""
pass
+
def where_cond(self, on_true: Tensor, on_false: Tensor) -> Tensor:
"""
Returns a tensor with the same shape as the input tensor, the values are taken from
diff --git a/candle-pyo3/py_src/candle/nn/container.py b/candle-pyo3/py_src/candle/nn/container.py
index 6ece31b6..963a8a4a 100644
--- a/candle-pyo3/py_src/candle/nn/container.py
+++ b/candle-pyo3/py_src/candle/nn/container.py
@@ -57,12 +57,10 @@ class Sequential(Module):
_modules: Dict[str, Module] # type: ignore[assignment]
@overload
- def __init__(self, *args: Module) -> None:
- ...
+ def __init__(self, *args: Module) -> None: ...
@overload
- def __init__(self, arg: "OrderedDict[str, Module]") -> None:
- ...
+ def __init__(self, arg: "OrderedDict[str, Module]") -> None: ...
def __init__(self, *args):
super().__init__()
diff --git a/candle-pyo3/py_src/candle/nn/module.py b/candle-pyo3/py_src/candle/nn/module.py
index 514d92b8..972d9a91 100644
--- a/candle-pyo3/py_src/candle/nn/module.py
+++ b/candle-pyo3/py_src/candle/nn/module.py
@@ -204,12 +204,10 @@ class Module:
T_destination = TypeVar("T_destination", bound=Dict[str, Any])
@overload
- def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination:
- ...
+ def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination: ...
@overload
- def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]:
- ...
+ def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]: ...
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
r"""Returns a dictionary containing references to the whole state of the module.
@@ -586,12 +584,10 @@ class Module:
self: T,
device: str = ...,
dtype: Optional[Union[DType, str]] = ...,
- ) -> T:
- ...
+ ) -> T: ...
@overload
- def to(self: T, dtype: Union[DType, str]) -> T:
- ...
+ def to(self: T, dtype: Union[DType, str]) -> T: ...
def to(self, *args, **kwargs):
r"""Moves and/or casts the parameters and buffers.
diff --git a/candle-pyo3/py_src/candle/nn/normalization.py b/candle-pyo3/py_src/candle/nn/normalization.py
index 67510a24..61d29c51 100644
--- a/candle-pyo3/py_src/candle/nn/normalization.py
+++ b/candle-pyo3/py_src/candle/nn/normalization.py
@@ -14,6 +14,7 @@ class LayerNorm(Module):
math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
"""
+
__constants__ = ["normalized_shape", "eps"]
normalized_shape: Tuple[int, ...]
eps: float
diff --git a/candle-pyo3/py_src/candle/onnx/__init__.pyi b/candle-pyo3/py_src/candle/onnx/__init__.pyi
index 8ce1b3aa..a23cd2f0 100644
--- a/candle-pyo3/py_src/candle/onnx/__init__.pyi
+++ b/candle-pyo3/py_src/candle/onnx/__init__.pyi
@@ -11,59 +11,69 @@ class ONNXModel:
def __init__(self, path: str):
pass
+
@property
def doc_string(self) -> str:
"""
The doc string of the model.
"""
pass
+
@property
def domain(self) -> str:
"""
The domain of the operator set of the model.
"""
pass
+
def initializers(self) -> Dict[str, Tensor]:
"""
Get the weights of the model.
"""
pass
+
@property
def inputs(self) -> Optional[Dict[str, ONNXTensorDescription]]:
"""
The inputs of the model.
"""
pass
+
@property
def ir_version(self) -> int:
"""
The version of the IR this model targets.
"""
pass
+
@property
def model_version(self) -> int:
"""
The version of the model.
"""
pass
+
@property
def outputs(self) -> Optional[Dict[str, ONNXTensorDescription]]:
"""
The outputs of the model.
"""
pass
+
@property
def producer_name(self) -> str:
"""
The producer of the model.
"""
pass
+
@property
def producer_version(self) -> str:
"""
The version of the producer of the model.
"""
pass
+
def run(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""
Run the model on the given inputs.
@@ -81,6 +91,7 @@ class ONNXTensorDescription:
The data type of the tensor.
"""
pass
+
@property
def shape(self) -> Tuple[Union[int, str, Any]]:
"""
diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs
index ca406876..7b9a7413 100644
--- a/candle-pyo3/src/lib.rs
+++ b/candle-pyo3/src/lib.rs
@@ -938,8 +938,8 @@ impl PyTensor {
/// Detach the tensor from the computation graph.
/// &RETURNS&: Tensor
- fn detach(&self) -> PyResult<Self> {
- Ok(PyTensor(self.0.detach().map_err(wrap_err)?))
+ fn detach(&self) -> Self {
+ PyTensor(self.0.detach())
}
/// Returns a copy of the tensor.
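
The same simplification shows up on the pyo3 side: once the wrapped call is infallible, a #[pymethods] function can return Self directly instead of PyResult<Self>. A standalone sketch of that pattern with a toy type (not candle's):

use pyo3::prelude::*;

#[pyclass]
struct Wrapper(f64);

#[pymethods]
impl Wrapper {
    // Infallible: return Self directly; no PyResult or wrap_err needed.
    fn doubled(&self) -> Self {
        Wrapper(self.0 * 2.0)
    }
}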
diff --git a/candle-pyo3/stub.py b/candle-pyo3/stub.py
index c459ebb3..165941bd 100644
--- a/candle-pyo3/stub.py
+++ b/candle-pyo3/stub.py
@@ -189,7 +189,6 @@ def do_black(content, is_pyi):
line_length=119,
is_pyi=is_pyi,
string_normalization=True,
- experimental_string_processing=False,
)
try:
return black.format_file_contents(content, fast=True, mode=mode)
diff --git a/candle-wasm-examples/segment-anything/README.md b/candle-wasm-examples/segment-anything/README.md
index 04ff2033..f8d8ad9d 100644
--- a/candle-wasm-examples/segment-anything/README.md
+++ b/candle-wasm-examples/segment-anything/README.md
@@ -1,6 +1,7 @@
## Running Segment Anything Example
-Here, we provide two examples of how to run Whisper using a Candle-compiled WASM binary and runtimes.
+Here, we provide an example showing how to run the Segment Anything model in the
+browser.
### Vanilla JS and WebWorkers