diff options
Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/backprop.rs | 2 | ||||
-rw-r--r-- | candle-core/src/tensor.rs | 6 | ||||
-rw-r--r-- | candle-core/src/variable.rs | 4 |
3 files changed, 8 insertions, 4 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index c152f31f..e7e3e129 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -175,7 +175,7 @@ impl Tensor { // the backprop graph of the backprop itself. This would be an issue for second order // derivatives but these are out of scope at the moment. let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b); - let grad = if do_not_detach { grad } else { grad.detach()? }; + let grad = if do_not_detach { grad } else { grad.detach() }; if let Some(op) = node.op() { match op { Op::Binary(lhs, rhs, BinaryOp::Add) => { diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 5f0b6df9..8596c957 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1882,9 +1882,9 @@ impl Tensor { /// this new node. The storage of this tensor is shared with the initial tensor. /// /// If the tensor is already detached from the computation graph, the same tensor is returned. - pub fn detach(&self) -> Result<Tensor> { + pub fn detach(&self) -> Tensor { if self.op.is_none() && !self.is_variable { - Ok(self.clone()) + self.clone() } else { let tensor_ = Tensor_ { id: TensorId::new(), @@ -1895,7 +1895,7 @@ impl Tensor { dtype: self.dtype, device: self.device.clone(), }; - Ok(Tensor(Arc::new(tensor_))) + Tensor(Arc::new(tensor_)) } } diff --git a/candle-core/src/variable.rs b/candle-core/src/variable.rs index 61800bf3..bdf8da4a 100644 --- a/candle-core/src/variable.rs +++ b/candle-core/src/variable.rs @@ -107,6 +107,10 @@ impl Var { Ok(Self(inner)) } + pub fn as_detached_tensor(&self) -> Tensor { + self.0.detach() + } + pub fn as_tensor(&self) -> &Tensor { &self.0 } |